fix(llm): add streaming retry with exponential backoff
- Retry failed streams for up to 5 attempts (MAX_RETRIES) with exponential backoff (8s min, 64s max)
- Reset chunks on failure and retry the full request
- Use litellm._should_retry() for retryable error detection
- Switch to async acompletion() for streaming
- Refactor generate() into smaller, focused methods
This commit is contained in:
192
strix/llm/llm.py
192
strix/llm/llm.py
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
@@ -24,6 +25,23 @@ from strix.prompts import load_prompt_modules
|
||||
from strix.tools import get_tools_prompt
|
||||
|
||||
|
||||
# Total number of attempts generate() makes for one streamed request; it loops
# over range(MAX_RETRIES), so the very first try is included in this count.
MAX_RETRIES = 5
|
||||
# Backoff base in seconds: the wait before the next attempt is computed as
# RETRY_MULTIPLIER * 2**attempt (attempt starts at 0) before clamping.
RETRY_MULTIPLIER = 8
|
||||
# Lower clamp, in seconds, applied to the computed wait passed to asyncio.sleep().
RETRY_MIN = 8
|
||||
# Upper clamp, in seconds, applied to the computed wait passed to asyncio.sleep().
RETRY_MAX = 64
|
||||
|
||||
|
||||
def _should_retry(exception: Exception) -> bool:
|
||||
status_code = None
|
||||
if hasattr(exception, "status_code"):
|
||||
status_code = exception.status_code
|
||||
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
|
||||
status_code = exception.response.status_code
|
||||
if status_code is not None:
|
||||
return bool(litellm._should_retry(status_code))
|
||||
return True
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
litellm.drop_params = True
|
||||
@@ -272,12 +290,7 @@ class LLM:
|
||||
|
||||
return cached_messages
|
||||
|
||||
async def generate( # noqa: PLR0912, PLR0915
|
||||
self,
|
||||
conversation_history: list[dict[str, Any]],
|
||||
scan_id: str | None = None,
|
||||
step_number: int = 1,
|
||||
) -> AsyncIterator[LLMResponse]:
|
||||
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
messages = [{"role": "system", "content": self.system_prompt}]
|
||||
|
||||
identity_message = self._build_identity_message()
|
||||
@@ -290,93 +303,104 @@ class LLM:
|
||||
conversation_history.extend(compressed_history)
|
||||
messages.extend(compressed_history)
|
||||
|
||||
cached_messages = self._prepare_cached_messages(messages)
|
||||
return self._prepare_cached_messages(messages)
|
||||
|
||||
try:
|
||||
accumulated_content = ""
|
||||
chunks: list[Any] = []
|
||||
async def _stream_and_accumulate(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
scan_id: str | None,
|
||||
step_number: int,
|
||||
) -> AsyncIterator[LLMResponse]:
|
||||
accumulated_content = ""
|
||||
chunks: list[Any] = []
|
||||
|
||||
async for chunk in self._stream_request(cached_messages):
|
||||
chunks.append(chunk)
|
||||
delta = self._extract_chunk_delta(chunk)
|
||||
if delta:
|
||||
accumulated_content += delta
|
||||
async for chunk in self._stream_request(messages):
|
||||
chunks.append(chunk)
|
||||
delta = self._extract_chunk_delta(chunk)
|
||||
if delta:
|
||||
accumulated_content += delta
|
||||
|
||||
if "</function>" in accumulated_content:
|
||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||
accumulated_content = accumulated_content[:function_end]
|
||||
if "</function>" in accumulated_content:
|
||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||
accumulated_content = accumulated_content[:function_end]
|
||||
|
||||
yield LLMResponse(
|
||||
scan_id=scan_id,
|
||||
step_number=step_number,
|
||||
role=StepRole.AGENT,
|
||||
content=accumulated_content,
|
||||
tool_invocations=None,
|
||||
)
|
||||
yield LLMResponse(
|
||||
scan_id=scan_id,
|
||||
step_number=step_number,
|
||||
role=StepRole.AGENT,
|
||||
content=accumulated_content,
|
||||
tool_invocations=None,
|
||||
)
|
||||
|
||||
if chunks:
|
||||
complete_response = stream_chunk_builder(chunks)
|
||||
self._update_usage_stats(complete_response)
|
||||
if chunks:
|
||||
complete_response = stream_chunk_builder(chunks)
|
||||
self._update_usage_stats(complete_response)
|
||||
|
||||
accumulated_content = _truncate_to_first_function(accumulated_content)
|
||||
if "</function>" in accumulated_content:
|
||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||
accumulated_content = accumulated_content[:function_end]
|
||||
accumulated_content = _truncate_to_first_function(accumulated_content)
|
||||
if "</function>" in accumulated_content:
|
||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||
accumulated_content = accumulated_content[:function_end]
|
||||
|
||||
tool_invocations = parse_tool_invocations(accumulated_content)
|
||||
tool_invocations = parse_tool_invocations(accumulated_content)
|
||||
|
||||
yield LLMResponse(
|
||||
scan_id=scan_id,
|
||||
step_number=step_number,
|
||||
role=StepRole.AGENT,
|
||||
content=accumulated_content,
|
||||
tool_invocations=tool_invocations if tool_invocations else None,
|
||||
)
|
||||
yield LLMResponse(
|
||||
scan_id=scan_id,
|
||||
step_number=step_number,
|
||||
role=StepRole.AGENT,
|
||||
content=accumulated_content,
|
||||
tool_invocations=tool_invocations if tool_invocations else None,
|
||||
)
|
||||
|
||||
except litellm.RateLimitError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e
|
||||
except litellm.AuthenticationError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e
|
||||
except litellm.NotFoundError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
|
||||
except litellm.ContextWindowExceededError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
|
||||
except litellm.ContentPolicyViolationError as e:
|
||||
raise LLMRequestFailedError(
|
||||
"LLM request failed: Content policy violation", str(e)
|
||||
) from e
|
||||
except litellm.ServiceUnavailableError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
|
||||
except litellm.Timeout as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
|
||||
except litellm.UnprocessableEntityError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
|
||||
except litellm.InternalServerError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
|
||||
except litellm.APIConnectionError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
|
||||
except litellm.UnsupportedParamsError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
|
||||
except litellm.BudgetExceededError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
|
||||
except litellm.APIResponseValidationError as e:
|
||||
raise LLMRequestFailedError(
|
||||
"LLM request failed: Response validation error", str(e)
|
||||
) from e
|
||||
except litellm.JSONSchemaValidationError as e:
|
||||
raise LLMRequestFailedError(
|
||||
"LLM request failed: JSON schema validation error", str(e)
|
||||
) from e
|
||||
except litellm.InvalidRequestError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
|
||||
except litellm.BadRequestError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
|
||||
except litellm.APIError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
|
||||
except litellm.OpenAIError as e:
|
||||
raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
|
||||
except Exception as e:
|
||||
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
||||
def _raise_llm_error(self, e: Exception) -> None:
    """Re-raise ``e`` as an ``LLMRequestFailedError`` with a readable summary.

    The table below is scanned in order and the first entry whose class
    matches ``e`` supplies the summary text, so the order of entries is
    significant (presumably narrower litellm classes are listed ahead of
    broader ones -- confirm against litellm's exception hierarchy before
    reordering). Errors matching no entry are labelled with their class name.

    Args:
        e: The exception that caused the request to fail.

    Raises:
        LLMRequestFailedError: Always; chained from ``e`` and carrying
            ``str(e)`` as its second argument.
    """
    known_errors: tuple[tuple[type, str], ...] = (
        (litellm.RateLimitError, "Rate limit exceeded"),
        (litellm.AuthenticationError, "Invalid API key"),
        (litellm.NotFoundError, "Model not found"),
        (litellm.ContextWindowExceededError, "Context too long"),
        (litellm.ContentPolicyViolationError, "Content policy violation"),
        (litellm.ServiceUnavailableError, "Service unavailable"),
        (litellm.Timeout, "Request timed out"),
        (litellm.UnprocessableEntityError, "Unprocessable entity"),
        (litellm.InternalServerError, "Internal server error"),
        (litellm.APIConnectionError, "Connection error"),
        (litellm.UnsupportedParamsError, "Unsupported parameters"),
        (litellm.BudgetExceededError, "Budget exceeded"),
        (litellm.APIResponseValidationError, "Response validation error"),
        (litellm.JSONSchemaValidationError, "JSON schema validation error"),
        (litellm.InvalidRequestError, "Invalid request"),
        (litellm.BadRequestError, "Bad request"),
        (litellm.APIError, "API error"),
        (litellm.OpenAIError, "OpenAI error"),
    )

    # First matching entry wins; unknown errors fall back to their type name.
    label = next(
        (text for exc_class, text in known_errors if isinstance(e, exc_class)),
        type(e).__name__,
    )
    raise LLMRequestFailedError(f"LLM request failed: {label}", str(e)) from e
|
||||
|
||||
async def generate(
    self,
    conversation_history: list[dict[str, Any]],
    scan_id: str | None = None,
    step_number: int = 1,
) -> AsyncIterator[LLMResponse]:
    """Stream a model response, retrying the whole request on failure.

    Messages are prepared once, then streamed through
    ``_stream_and_accumulate``. If an attempt raises, the full request is
    issued again after an exponential backoff, for at most ``MAX_RETRIES``
    attempts in total. Responses already yielded by a failed attempt are not
    retracted; the next attempt starts streaming from the beginning.

    Args:
        conversation_history: Prior messages, passed to ``_prepare_messages``.
        scan_id: Identifier forwarded into every yielded response.
        step_number: Step index forwarded into every yielded response.

    Yields:
        The responses produced by ``_stream_and_accumulate``.

    Raises:
        LLMRequestFailedError: Via ``_raise_llm_error`` when the error is not
            retryable or the final attempt has failed.
    """
    messages = self._prepare_messages(conversation_history)

    last_error: Exception | None = None
    for attempt in range(MAX_RETRIES):
        try:
            async for response in self._stream_and_accumulate(messages, scan_id, step_number):
                yield response
            return  # noqa: TRY300
        except Exception as e:  # noqa: BLE001
            last_error = e
            if not _should_retry(e) or attempt == MAX_RETRIES - 1:
                break

            # Exponential backoff, clamped to [RETRY_MIN, RETRY_MAX] seconds.
            wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
            wait_time = max(RETRY_MIN, wait_time)

            # Surface the swallowed failure so long silent pauses are explainable.
            logger.warning(
                "LLM stream attempt %d/%d failed (%s: %s); retrying in %ds",
                attempt + 1,
                MAX_RETRIES,
                type(e).__name__,
                e,
                wait_time,
            )
            await asyncio.sleep(wait_time)

    # Compare against None explicitly: exception truthiness is not guaranteed.
    if last_error is not None:
        self._raise_llm_error(last_error)
|
||||
|
||||
def _extract_chunk_delta(self, chunk: Any) -> str:
|
||||
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
||||
|
||||
@@ -1,31 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import completion
|
||||
from litellm import acompletion
|
||||
from litellm.types.utils import ModelResponseStream
|
||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def should_retry_exception(exception: Exception) -> bool:
    """Decide whether a failed request should be retried.

    The HTTP status code is read from ``exception.status_code`` or, when that
    attribute is missing or ``None``, from ``exception.response.status_code``.
    If a status code is found, the decision is delegated to
    ``litellm._should_retry``; otherwise the error is treated as retryable.

    Args:
        exception: The error raised while performing the request.

    Returns:
        ``True`` if the request should be attempted again, ``False`` otherwise.
    """
    status_code = getattr(exception, "status_code", None)
    if status_code is None:
        # A present-but-None ``status_code`` must not hide the code carried by
        # the attached response object, so fall back to the response here.
        response = getattr(exception, "response", None)
        status_code = getattr(response, "status_code", None)

    if status_code is not None:
        # NOTE: ``_should_retry`` is a private litellm helper and may change
        # between litellm releases.
        return bool(litellm._should_retry(status_code))

    # No status code available to classify the failure: default to retrying.
    return True
|
||||
|
||||
|
||||
class LLMRequestQueue:
|
||||
@@ -60,22 +41,17 @@ class LLMRequestQueue:
|
||||
if sleep_needed > 0:
|
||||
await asyncio.sleep(sleep_needed)
|
||||
|
||||
async for chunk in self._reliable_stream_request(completion_args):
|
||||
async for chunk in self._stream_request(completion_args):
|
||||
yield chunk
|
||||
finally:
|
||||
self._semaphore.release()
|
||||
|
||||
@retry( # type: ignore[misc]
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=8, min=8, max=64),
|
||||
retry=retry_if_exception(should_retry_exception),
|
||||
reraise=True,
|
||||
)
|
||||
async def _reliable_stream_request(
|
||||
async def _stream_request(
|
||||
self, completion_args: dict[str, Any]
|
||||
) -> AsyncIterator[ModelResponseStream]:
|
||||
response = await asyncio.to_thread(completion, **completion_args, stream=True)
|
||||
for chunk in response:
|
||||
response = await acompletion(**completion_args, stream=True)
|
||||
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user