fix(llm): add streaming retry with exponential backoff
- Retry failed streams up to 3 times with exponential backoff (8s min, 64s max)
- Reset chunks on failure and retry the full request
- Use litellm._should_retry() for retryable error detection
- Switch to async acompletion() for streaming
- Refactor generate() into smaller focused methods
This commit is contained in:
192
strix/llm/llm.py
192
strix/llm/llm.py
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
@@ -24,6 +25,23 @@ from strix.prompts import load_prompt_modules
|
|||||||
from strix.tools import get_tools_prompt
|
from strix.tools import get_tools_prompt
|
||||||
|
|
||||||
|
|
||||||
|
# Retry policy used by the streaming retry loop in LLM.generate().
# Upper bound on attempts per request (the loop runs range(MAX_RETRIES)).
MAX_RETRIES = 5
# Base delay in seconds; multiplied by 2**attempt after each failed attempt.
RETRY_MULTIPLIER = 8
# Lower clamp, in seconds, applied to the computed delay.
RETRY_MIN = 8
# Upper clamp, in seconds, applied to the computed delay.
RETRY_MAX = 64
|
||||||
|
|
||||||
|
|
||||||
|
def _should_retry(exception: Exception) -> bool:
|
||||||
|
status_code = None
|
||||||
|
if hasattr(exception, "status_code"):
|
||||||
|
status_code = exception.status_code
|
||||||
|
elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
|
||||||
|
status_code = exception.response.status_code
|
||||||
|
if status_code is not None:
|
||||||
|
return bool(litellm._should_retry(status_code))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
@@ -272,12 +290,7 @@ class LLM:
|
|||||||
|
|
||||||
return cached_messages
|
return cached_messages
|
||||||
|
|
||||||
async def generate( # noqa: PLR0912, PLR0915
|
def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
self,
|
|
||||||
conversation_history: list[dict[str, Any]],
|
|
||||||
scan_id: str | None = None,
|
|
||||||
step_number: int = 1,
|
|
||||||
) -> AsyncIterator[LLMResponse]:
|
|
||||||
messages = [{"role": "system", "content": self.system_prompt}]
|
messages = [{"role": "system", "content": self.system_prompt}]
|
||||||
|
|
||||||
identity_message = self._build_identity_message()
|
identity_message = self._build_identity_message()
|
||||||
@@ -290,93 +303,104 @@ class LLM:
|
|||||||
conversation_history.extend(compressed_history)
|
conversation_history.extend(compressed_history)
|
||||||
messages.extend(compressed_history)
|
messages.extend(compressed_history)
|
||||||
|
|
||||||
cached_messages = self._prepare_cached_messages(messages)
|
return self._prepare_cached_messages(messages)
|
||||||
|
|
||||||
try:
|
async def _stream_and_accumulate(
|
||||||
accumulated_content = ""
|
self,
|
||||||
chunks: list[Any] = []
|
messages: list[dict[str, Any]],
|
||||||
|
scan_id: str | None,
|
||||||
|
step_number: int,
|
||||||
|
) -> AsyncIterator[LLMResponse]:
|
||||||
|
accumulated_content = ""
|
||||||
|
chunks: list[Any] = []
|
||||||
|
|
||||||
async for chunk in self._stream_request(cached_messages):
|
async for chunk in self._stream_request(messages):
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
delta = self._extract_chunk_delta(chunk)
|
delta = self._extract_chunk_delta(chunk)
|
||||||
if delta:
|
if delta:
|
||||||
accumulated_content += delta
|
accumulated_content += delta
|
||||||
|
|
||||||
if "</function>" in accumulated_content:
|
if "</function>" in accumulated_content:
|
||||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||||
accumulated_content = accumulated_content[:function_end]
|
accumulated_content = accumulated_content[:function_end]
|
||||||
|
|
||||||
yield LLMResponse(
|
yield LLMResponse(
|
||||||
scan_id=scan_id,
|
scan_id=scan_id,
|
||||||
step_number=step_number,
|
step_number=step_number,
|
||||||
role=StepRole.AGENT,
|
role=StepRole.AGENT,
|
||||||
content=accumulated_content,
|
content=accumulated_content,
|
||||||
tool_invocations=None,
|
tool_invocations=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if chunks:
|
if chunks:
|
||||||
complete_response = stream_chunk_builder(chunks)
|
complete_response = stream_chunk_builder(chunks)
|
||||||
self._update_usage_stats(complete_response)
|
self._update_usage_stats(complete_response)
|
||||||
|
|
||||||
accumulated_content = _truncate_to_first_function(accumulated_content)
|
accumulated_content = _truncate_to_first_function(accumulated_content)
|
||||||
if "</function>" in accumulated_content:
|
if "</function>" in accumulated_content:
|
||||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||||
accumulated_content = accumulated_content[:function_end]
|
accumulated_content = accumulated_content[:function_end]
|
||||||
|
|
||||||
tool_invocations = parse_tool_invocations(accumulated_content)
|
tool_invocations = parse_tool_invocations(accumulated_content)
|
||||||
|
|
||||||
yield LLMResponse(
|
yield LLMResponse(
|
||||||
scan_id=scan_id,
|
scan_id=scan_id,
|
||||||
step_number=step_number,
|
step_number=step_number,
|
||||||
role=StepRole.AGENT,
|
role=StepRole.AGENT,
|
||||||
content=accumulated_content,
|
content=accumulated_content,
|
||||||
tool_invocations=tool_invocations if tool_invocations else None,
|
tool_invocations=tool_invocations if tool_invocations else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
except litellm.RateLimitError as e:
|
def _raise_llm_error(self, e: Exception) -> None:
|
||||||
raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e
|
error_map: list[tuple[type, str]] = [
|
||||||
except litellm.AuthenticationError as e:
|
(litellm.RateLimitError, "Rate limit exceeded"),
|
||||||
raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e
|
(litellm.AuthenticationError, "Invalid API key"),
|
||||||
except litellm.NotFoundError as e:
|
(litellm.NotFoundError, "Model not found"),
|
||||||
raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
|
(litellm.ContextWindowExceededError, "Context too long"),
|
||||||
except litellm.ContextWindowExceededError as e:
|
(litellm.ContentPolicyViolationError, "Content policy violation"),
|
||||||
raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
|
(litellm.ServiceUnavailableError, "Service unavailable"),
|
||||||
except litellm.ContentPolicyViolationError as e:
|
(litellm.Timeout, "Request timed out"),
|
||||||
raise LLMRequestFailedError(
|
(litellm.UnprocessableEntityError, "Unprocessable entity"),
|
||||||
"LLM request failed: Content policy violation", str(e)
|
(litellm.InternalServerError, "Internal server error"),
|
||||||
) from e
|
(litellm.APIConnectionError, "Connection error"),
|
||||||
except litellm.ServiceUnavailableError as e:
|
(litellm.UnsupportedParamsError, "Unsupported parameters"),
|
||||||
raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
|
(litellm.BudgetExceededError, "Budget exceeded"),
|
||||||
except litellm.Timeout as e:
|
(litellm.APIResponseValidationError, "Response validation error"),
|
||||||
raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
|
(litellm.JSONSchemaValidationError, "JSON schema validation error"),
|
||||||
except litellm.UnprocessableEntityError as e:
|
(litellm.InvalidRequestError, "Invalid request"),
|
||||||
raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
|
(litellm.BadRequestError, "Bad request"),
|
||||||
except litellm.InternalServerError as e:
|
(litellm.APIError, "API error"),
|
||||||
raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
|
(litellm.OpenAIError, "OpenAI error"),
|
||||||
except litellm.APIConnectionError as e:
|
]
|
||||||
raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
|
for error_type, message in error_map:
|
||||||
except litellm.UnsupportedParamsError as e:
|
if isinstance(e, error_type):
|
||||||
raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
|
raise LLMRequestFailedError(f"LLM request failed: {message}", str(e)) from e
|
||||||
except litellm.BudgetExceededError as e:
|
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
||||||
raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
|
|
||||||
except litellm.APIResponseValidationError as e:
|
async def generate(
|
||||||
raise LLMRequestFailedError(
|
self,
|
||||||
"LLM request failed: Response validation error", str(e)
|
conversation_history: list[dict[str, Any]],
|
||||||
) from e
|
scan_id: str | None = None,
|
||||||
except litellm.JSONSchemaValidationError as e:
|
step_number: int = 1,
|
||||||
raise LLMRequestFailedError(
|
) -> AsyncIterator[LLMResponse]:
|
||||||
"LLM request failed: JSON schema validation error", str(e)
|
messages = self._prepare_messages(conversation_history)
|
||||||
) from e
|
|
||||||
except litellm.InvalidRequestError as e:
|
last_error: Exception | None = None
|
||||||
raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
|
for attempt in range(MAX_RETRIES):
|
||||||
except litellm.BadRequestError as e:
|
try:
|
||||||
raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
|
async for response in self._stream_and_accumulate(messages, scan_id, step_number):
|
||||||
except litellm.APIError as e:
|
yield response
|
||||||
raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
|
return # noqa: TRY300
|
||||||
except litellm.OpenAIError as e:
|
except Exception as e: # noqa: BLE001
|
||||||
raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
|
last_error = e
|
||||||
except Exception as e:
|
if not _should_retry(e) or attempt == MAX_RETRIES - 1:
|
||||||
raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
|
break
|
||||||
|
wait_time = min(RETRY_MAX, RETRY_MULTIPLIER * (2**attempt))
|
||||||
|
wait_time = max(RETRY_MIN, wait_time)
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
|
if last_error:
|
||||||
|
self._raise_llm_error(last_error)
|
||||||
|
|
||||||
def _extract_chunk_delta(self, chunk: Any) -> str:
|
def _extract_chunk_delta(self, chunk: Any) -> str:
|
||||||
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
if chunk.choices and hasattr(chunk.choices[0], "delta"):
|
||||||
|
|||||||
@@ -1,31 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
from litellm import acompletion
|
||||||
from litellm import completion
|
|
||||||
from litellm.types.utils import ModelResponseStream
|
from litellm.types.utils import ModelResponseStream
|
||||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def should_retry_exception(exception: Exception) -> bool:
    """Report whether *exception* warrants another attempt.

    Looks for an HTTP status code on the exception itself, then on its
    ``response`` attribute. A status code, when present, is judged by
    ``litellm._should_retry``; without one the exception is considered
    retryable.

    Args:
        exception: The error raised by the request.

    Returns:
        ``True`` to retry, ``False`` to give up.
    """
    if hasattr(exception, "status_code"):
        code = exception.status_code
    elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
        code = exception.response.status_code
    else:
        code = None

    if code is None:
        # Nothing to judge by: retry.
        return True
    return bool(litellm._should_retry(code))
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestQueue:
|
class LLMRequestQueue:
|
||||||
@@ -60,22 +41,17 @@ class LLMRequestQueue:
|
|||||||
if sleep_needed > 0:
|
if sleep_needed > 0:
|
||||||
await asyncio.sleep(sleep_needed)
|
await asyncio.sleep(sleep_needed)
|
||||||
|
|
||||||
async for chunk in self._reliable_stream_request(completion_args):
|
async for chunk in self._stream_request(completion_args):
|
||||||
yield chunk
|
yield chunk
|
||||||
finally:
|
finally:
|
||||||
self._semaphore.release()
|
self._semaphore.release()
|
||||||
|
|
||||||
@retry( # type: ignore[misc]
|
async def _stream_request(
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=8, min=8, max=64),
|
|
||||||
retry=retry_if_exception(should_retry_exception),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
async def _reliable_stream_request(
|
|
||||||
self, completion_args: dict[str, Any]
|
self, completion_args: dict[str, Any]
|
||||||
) -> AsyncIterator[ModelResponseStream]:
|
) -> AsyncIterator[ModelResponseStream]:
|
||||||
response = await asyncio.to_thread(completion, **completion_args, stream=True)
|
response = await acompletion(**completion_args, stream=True)
|
||||||
for chunk in response:
|
|
||||||
|
async for chunk in response:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user