fix(llm): add streaming retry with exponential backoff

- Retry failed streams up to 3 times with exponential backoff (8s min, 64s max)
- Reset chunks on failure and retry full request
- Use litellm._should_retry() for retryable error detection
- Switch to async acompletion() for streaming
- Refactor generate() into smaller focused methods
This commit is contained in:
0xallam
2026-01-05 13:40:24 -08:00
committed by Ahmed Allam
parent a6dcb7756e
commit 0954ac208f
2 changed files with 114 additions and 114 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import logging import logging
import os import os
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@@ -24,6 +25,23 @@ from strix.prompts import load_prompt_modules
from strix.tools import get_tools_prompt from strix.tools import get_tools_prompt
MAX_RETRIES = 5
RETRY_MULTIPLIER = 8
RETRY_MIN = 8
RETRY_MAX = 64
def _should_retry(exception: Exception) -> bool:
status_code = None
if hasattr(exception, "status_code"):
status_code = exception.status_code
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
status_code = exception.response.status_code
if status_code is not None:
return bool(litellm._should_retry(status_code))
return True
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
litellm.drop_params = True litellm.drop_params = True
@@ -272,12 +290,7 @@ class LLM:
return cached_messages return cached_messages
async def generate( # noqa: PLR0912, PLR0915 def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
self,
conversation_history: list[dict[str, Any]],
scan_id: str | None = None,
step_number: int = 1,
) -> AsyncIterator[LLMResponse]:
messages = [{"role": "system", "content": self.system_prompt}] messages = [{"role": "system", "content": self.system_prompt}]
identity_message = self._build_identity_message() identity_message = self._build_identity_message()
@@ -290,93 +303,104 @@ class LLM:
conversation_history.extend(compressed_history) conversation_history.extend(compressed_history)
messages.extend(compressed_history) messages.extend(compressed_history)
cached_messages = self._prepare_cached_messages(messages) return self._prepare_cached_messages(messages)
try: async def _stream_and_accumulate(
accumulated_content = "" self,
chunks: list[Any] = [] messages: list[dict[str, Any]],
scan_id: str | None,
step_number: int,
) -> AsyncIterator[LLMResponse]:
accumulated_content = ""
chunks: list[Any] = []
async for chunk in self._stream_request(cached_messages): async for chunk in self._stream_request(messages):
chunks.append(chunk) chunks.append(chunk)
delta = self._extract_chunk_delta(chunk) delta = self._extract_chunk_delta(chunk)
if delta: if delta:
accumulated_content += delta accumulated_content += delta
if "</function>" in accumulated_content: if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>") function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end] accumulated_content = accumulated_content[:function_end]
yield LLMResponse( yield LLMResponse(
scan_id=scan_id, scan_id=scan_id,
step_number=step_number, step_number=step_number,
role=StepRole.AGENT, role=StepRole.AGENT,
content=accumulated_content, content=accumulated_content,
tool_invocations=None, tool_invocations=None,
) )
if chunks: if chunks:
complete_response = stream_chunk_builder(chunks) complete_response = stream_chunk_builder(chunks)
self._update_usage_stats(complete_response) self._update_usage_stats(complete_response)
accumulated_content = _truncate_to_first_function(accumulated_content) accumulated_content = _truncate_to_first_function(accumulated_content)
if "</function>" in accumulated_content: if "</function>" in accumulated_content:
function_end = accumulated_content.find("</function>") + len("</function>") function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end] accumulated_content = accumulated_content[:function_end]
tool_invocations = parse_tool_invocations(accumulated_content) tool_invocations = parse_tool_invocations(accumulated_content)
yield LLMResponse( yield LLMResponse(
scan_id=scan_id, scan_id=scan_id,
step_number=step_number, step_number=step_number,
role=StepRole.AGENT, role=StepRole.AGENT,
content=accumulated_content, content=accumulated_content,
tool_invocations=tool_invocations if tool_invocations else None, tool_invocations=tool_invocations if tool_invocations else None,
) )
except litellm.RateLimitError as e: def _raise_llm_error(self, e: Exception) -> None:
raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e error_map: list[tuple[type, str]] = [
except litellm.AuthenticationError as e: (litellm.RateLimitError, "Rate limit exceeded"),
raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e (litellm.AuthenticationError, "Invalid API key"),
except litellm.NotFoundError as e: (litellm.NotFoundError, "Model not found"),
raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e (litellm.ContextWindowExceededError, "Context too long"),
except litellm.ContextWindowExceededError as e: (litellm.ContentPolicyViolationError, "Content policy violation"),
raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e (litellm.ServiceUnavailableError, "Service unavailable"),
except litellm.ContentPolicyViolationError as e: (litellm.Timeout, "Request timed out"),
raise LLMRequestFailedError( (litellm.UnprocessableEntityError, "Unprocessable entity"),
"LLM request failed: Content policy violation", str(e) (litellm.InternalServerError, "Internal server error"),
) from e (litellm.APIConnectionError, "Connection error"),
except litellm.ServiceUnavailableError as e: (litellm.UnsupportedParamsError, "Unsupported parameters"),
raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e (litellm.BudgetExceededError, "Budget exceeded"),
except litellm.Timeout as e: (litellm.APIResponseValidationError, "Response validation error"),
raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e (litellm.JSONSchemaValidationError, "JSON schema validation error"),
except litellm.UnprocessableEntityError as e: (litellm.InvalidRequestError, "Invalid request"),
raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e (litellm.BadRequestError, "Bad request"),
except litellm.InternalServerError as e: (litellm.APIError, "API error"),
raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e (litellm.OpenAIError, "OpenAI error"),
except litellm.APIConnectionError as e: ]
raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e for error_type, message in error_map:
except litellm.UnsupportedParamsError as e: if isinstance(e, error_type):
raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e
except litellm.BudgetExceededError as e: raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
except litellm.APIResponseValidationError as e: async def generate(
raise LLMRequestFailedError( self,
"LLM request failed: Response validation error", str(e) conversation_history: list[dict[str, Any]],
) from e scan_id: str | None = None,
except litellm.JSONSchemaValidationError as e: step_number: int = 1,
raise LLMRequestFailedError( ) -> AsyncIterator[LLMResponse]:
"LLM request failed: JSON schema validation error", str(e) messages = self._prepare_messages(conversation_history)
) from e
except litellm.InvalidRequestError as e: last_error: Exception | None = None
raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e for attempt in range(MAX_RETRIES):
except litellm.BadRequestError as e: try:
raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e async for response in self._stream_and_accumulate(messages, scan_id, step_number):
except litellm.APIError as e: yield response
raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e return # noqa: TRY300
except litellm.OpenAIError as e: except Exception as e: # noqa: BLE001
raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e last_error = e
except Exception as e: if not _should_retry(e) or attempt == MAX_RETRIES - 1:
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e break
wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
wait_time = max(RETRY_MIN, wait_time)
await asyncio.sleep(wait_time)
if last_error:
self._raise_llm_error(last_error)
def _extract_chunk_delta(self, chunk: Any) -> str: def _extract_chunk_delta(self, chunk: Any) -> str:
if chunk.choices and hasattr(chunk.choices[0], "delta"): if chunk.choices and hasattr(chunk.choices[0], "delta"):

View File

@@ -1,31 +1,12 @@
import asyncio import asyncio
import logging
import os import os
import threading import threading
import time import time
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any from typing import Any
import litellm from litellm import acompletion
from litellm import completion
from litellm.types.utils import ModelResponseStream from litellm.types.utils import ModelResponseStream
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
logger = logging.getLogger(__name__)
def should_retry_exception(exception: Exception) -> bool:
status_code = None
if hasattr(exception, "status_code"):
status_code = exception.status_code
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
status_code = exception.response.status_code
if status_code is not None:
return bool(litellm._should_retry(status_code))
return True
class LLMRequestQueue: class LLMRequestQueue:
@@ -60,22 +41,17 @@ class LLMRequestQueue:
if sleep_needed > 0: if sleep_needed > 0:
await asyncio.sleep(sleep_needed) await asyncio.sleep(sleep_needed)
async for chunk in self._reliable_stream_request(completion_args): async for chunk in self._stream_request(completion_args):
yield chunk yield chunk
finally: finally:
self._semaphore.release() self._semaphore.release()
@retry( # type: ignore[misc] async def _stream_request(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=8, min=8, max=64),
retry=retry_if_exception(should_retry_exception),
reraise=True,
)
async def _reliable_stream_request(
self, completion_args: dict[str, Any] self, completion_args: dict[str, Any]
) -> AsyncIterator[ModelResponseStream]: ) -> AsyncIterator[ModelResponseStream]:
response = await asyncio.to_thread(completion, **completion_args, stream=True) response = await acompletion(**completion_args, stream=True)
for chunk in response:
async for chunk in response:
yield chunk yield chunk