From 0954ac208f51e6083b8b78d3a62bc492c89b179a Mon Sep 17 00:00:00 2001 From: 0xallam Date: Mon, 5 Jan 2026 13:40:24 -0800 Subject: [PATCH] fix(llm): add streaming retry with exponential backoff - Retry failed streams up to 5 times with exp backoff (8s min, 64s max) - Reset chunks on failure and retry full request - Use litellm._should_retry() for retryable error detection - Switch to async acompletion() for streaming - Refactor generate() into smaller focused methods --- strix/llm/llm.py | 192 +++++++++++++++++++++---------------- strix/llm/request_queue.py | 36 ++----- 2 files changed, 114 insertions(+), 114 deletions(-) diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 558a0c5..b69f1f8 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -1,3 +1,4 @@ +import asyncio import logging import os from collections.abc import AsyncIterator @@ -24,6 +25,23 @@ from strix.prompts import load_prompt_modules from strix.tools import get_tools_prompt +MAX_RETRIES = 5 +RETRY_MULTIPLIER = 8 +RETRY_MIN = 8 +RETRY_MAX = 64 + + +def _should_retry(exception: Exception) -> bool: + status_code = None + if hasattr(exception, "status_code"): + status_code = exception.status_code + elif hasattr(exception, "response") and hasattr(exception.response, "status_code"): + status_code = exception.response.status_code + if status_code is not None: + return bool(litellm._should_retry(status_code)) + return True + + logger = logging.getLogger(__name__) litellm.drop_params = True @@ -272,12 +290,7 @@ class LLM: return cached_messages - async def generate( # noqa: PLR0912, PLR0915 - self, - conversation_history: list[dict[str, Any]], - scan_id: str | None = None, - step_number: int = 1, - ) -> AsyncIterator[LLMResponse]: + def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]: messages = [{"role": "system", "content": self.system_prompt}] identity_message = self._build_identity_message() @@ -290,93 +303,104 @@ class LLM: 
conversation_history.extend(compressed_history) messages.extend(compressed_history) - cached_messages = self._prepare_cached_messages(messages) + return self._prepare_cached_messages(messages) - try: - accumulated_content = "" - chunks: list[Any] = [] + async def _stream_and_accumulate( + self, + messages: list[dict[str, Any]], + scan_id: str | None, + step_number: int, + ) -> AsyncIterator[LLMResponse]: + accumulated_content = "" + chunks: list[Any] = [] - async for chunk in self._stream_request(cached_messages): - chunks.append(chunk) - delta = self._extract_chunk_delta(chunk) - if delta: - accumulated_content += delta + async for chunk in self._stream_request(messages): + chunks.append(chunk) + delta = self._extract_chunk_delta(chunk) + if delta: + accumulated_content += delta - if "" in accumulated_content: - function_end = accumulated_content.find("") + len("") - accumulated_content = accumulated_content[:function_end] + if "" in accumulated_content: + function_end = accumulated_content.find("") + len("") + accumulated_content = accumulated_content[:function_end] - yield LLMResponse( - scan_id=scan_id, - step_number=step_number, - role=StepRole.AGENT, - content=accumulated_content, - tool_invocations=None, - ) + yield LLMResponse( + scan_id=scan_id, + step_number=step_number, + role=StepRole.AGENT, + content=accumulated_content, + tool_invocations=None, + ) - if chunks: - complete_response = stream_chunk_builder(chunks) - self._update_usage_stats(complete_response) + if chunks: + complete_response = stream_chunk_builder(chunks) + self._update_usage_stats(complete_response) - accumulated_content = _truncate_to_first_function(accumulated_content) - if "" in accumulated_content: - function_end = accumulated_content.find("") + len("") - accumulated_content = accumulated_content[:function_end] + accumulated_content = _truncate_to_first_function(accumulated_content) + if "" in accumulated_content: + function_end = accumulated_content.find("") + len("") + 
accumulated_content = accumulated_content[:function_end] - tool_invocations = parse_tool_invocations(accumulated_content) + tool_invocations = parse_tool_invocations(accumulated_content) - yield LLMResponse( - scan_id=scan_id, - step_number=step_number, - role=StepRole.AGENT, - content=accumulated_content, - tool_invocations=tool_invocations if tool_invocations else None, - ) + yield LLMResponse( + scan_id=scan_id, + step_number=step_number, + role=StepRole.AGENT, + content=accumulated_content, + tool_invocations=tool_invocations if tool_invocations else None, + ) - except litellm.RateLimitError as e: - raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e - except litellm.AuthenticationError as e: - raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e - except litellm.NotFoundError as e: - raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e - except litellm.ContextWindowExceededError as e: - raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e - except litellm.ContentPolicyViolationError as e: - raise LLMRequestFailedError( - "LLM request failed: Content policy violation", str(e) - ) from e - except litellm.ServiceUnavailableError as e: - raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e - except litellm.Timeout as e: - raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e - except litellm.UnprocessableEntityError as e: - raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e - except litellm.InternalServerError as e: - raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e - except litellm.APIConnectionError as e: - raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e - except litellm.UnsupportedParamsError as e: - raise LLMRequestFailedError("LLM request failed: Unsupported 
parameters", str(e)) from e - except litellm.BudgetExceededError as e: - raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e - except litellm.APIResponseValidationError as e: - raise LLMRequestFailedError( - "LLM request failed: Response validation error", str(e) - ) from e - except litellm.JSONSchemaValidationError as e: - raise LLMRequestFailedError( - "LLM request failed: JSON schema validation error", str(e) - ) from e - except litellm.InvalidRequestError as e: - raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e - except litellm.BadRequestError as e: - raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e - except litellm.APIError as e: - raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e - except litellm.OpenAIError as e: - raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e - except Exception as e: - raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e + def _raise_llm_error(self, e: Exception) -> None: + error_map: list[tuple[type, str]] = [ + (litellm.RateLimitError, "Rate limit exceeded"), + (litellm.AuthenticationError, "Invalid API key"), + (litellm.NotFoundError, "Model not found"), + (litellm.ContextWindowExceededError, "Context too long"), + (litellm.ContentPolicyViolationError, "Content policy violation"), + (litellm.ServiceUnavailableError, "Service unavailable"), + (litellm.Timeout, "Request timed out"), + (litellm.UnprocessableEntityError, "Unprocessable entity"), + (litellm.InternalServerError, "Internal server error"), + (litellm.APIConnectionError, "Connection error"), + (litellm.UnsupportedParamsError, "Unsupported parameters"), + (litellm.BudgetExceededError, "Budget exceeded"), + (litellm.APIResponseValidationError, "Response validation error"), + (litellm.JSONSchemaValidationError, "JSON schema validation error"), + (litellm.InvalidRequestError, "Invalid request"), + 
(litellm.BadRequestError, "Bad request"), + (litellm.APIError, "API error"), + (litellm.OpenAIError, "OpenAI error"), + ] + for error_type, message in error_map: + if isinstance(e, error_type): + raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e + raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e + + async def generate( + self, + conversation_history: list[dict[str, Any]], + scan_id: str | None = None, + step_number: int = 1, + ) -> AsyncIterator[LLMResponse]: + messages = self._prepare_messages(conversation_history) + + last_error: Exception | None = None + for attempt in range(MAX_RETRIES): + try: + async for response in self._stream_and_accumulate(messages, scan_id, step_number): + yield response + return # noqa: TRY300 + except Exception as e: # noqa: BLE001 + last_error = e + if not _should_retry(e) or attempt == MAX_RETRIES - 1: + break + wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt)) + wait_time = max(RETRY_MIN, wait_time) + await asyncio.sleep(wait_time) + + if last_error: + self._raise_llm_error(last_error) def _extract_chunk_delta(self, chunk: Any) -> str: if chunk.choices and hasattr(chunk.choices[0], "delta"): diff --git a/strix/llm/request_queue.py b/strix/llm/request_queue.py index 4760196..0b68737 100644 --- a/strix/llm/request_queue.py +++ b/strix/llm/request_queue.py @@ -1,31 +1,12 @@ import asyncio -import logging import os import threading import time from collections.abc import AsyncIterator from typing import Any -import litellm -from litellm import completion +from litellm import acompletion from litellm.types.utils import ModelResponseStream -from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential - - -logger = logging.getLogger(__name__) - - -def should_retry_exception(exception: Exception) -> bool: - status_code = None - - if hasattr(exception, "status_code"): - status_code = exception.status_code - elif hasattr(exception, "response") 
and hasattr(exception.response, "status_code"): - status_code = exception.response.status_code - - if status_code is not None: - return bool(litellm._should_retry(status_code)) - return True class LLMRequestQueue: @@ -60,22 +41,17 @@ class LLMRequestQueue: if sleep_needed > 0: await asyncio.sleep(sleep_needed) - async for chunk in self._reliable_stream_request(completion_args): + async for chunk in self._stream_request(completion_args): yield chunk finally: self._semaphore.release() - @retry( # type: ignore[misc] - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=8, min=8, max=64), - retry=retry_if_exception(should_retry_exception), - reraise=True, - ) - async def _reliable_stream_request( + async def _stream_request( self, completion_args: dict[str, Any] ) -> AsyncIterator[ModelResponseStream]: - response = await asyncio.to_thread(completion, **completion_args, stream=True) - for chunk in response: + response = await acompletion(**completion_args, stream=True) + + async for chunk in response: yield chunk