diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py
index 6977786..0ab3902 100644
--- a/strix/agents/base_agent.py
+++ b/strix/agents/base_agent.py
@@ -181,10 +181,21 @@ class BaseAgent(metaclass=AgentMeta):
                 continue
 
             except LLMRequestFailedError as e:
-                self.state.add_error(f"LLM request failed: {e}")
+                error_msg = str(e)
+                error_details = getattr(e, "details", None)
+                self.state.add_error(error_msg)
                 self.state.enter_waiting_state(llm_failed=True)
                 if tracer:
-                    tracer.update_agent_status(self.state.agent_id, "llm_failed")
+                    tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg)
+                    if error_details:
+                        tracer.log_tool_execution_start(
+                            self.state.agent_id,
+                            "llm_error_details",
+                            {"error": error_msg, "details": error_details},
+                        )
+                        tracer.update_tool_execution(
+                            tracer._next_execution_id - 1, "failed", error_details
+                        )
                 continue
 
             except (RuntimeError, ValueError, TypeError) as e:
diff --git a/strix/cli/app.py b/strix/cli/app.py
index e7a29f2..0e059d0 100644
--- a/strix/cli/app.py
+++ b/strix/cli/app.py
@@ -546,11 +546,16 @@ class StrixCLIApp(App):  # type: ignore[misc]
             self._safe_widget_operation(keymap_indicator.update, "")
             self._safe_widget_operation(status_display.remove_class, "hidden")
         elif status == "llm_failed":
-            self._safe_widget_operation(status_text.update, "[red]LLM request failed[/red]")
+            error_msg = agent_data.get("error_message", "")
+            display_msg = (
+                f"[red]{error_msg}[/red]" if error_msg else "[red]LLM request failed[/red]"
+            )
+            self._safe_widget_operation(status_text.update, display_msg)
             self._safe_widget_operation(
                 keymap_indicator.update, "[dim]Send message to retry[/dim]"
             )
             self._safe_widget_operation(status_display.remove_class, "hidden")
+            self._stop_dot_animation()
         elif status == "waiting":
             animated_text = self._get_animated_waiting_text(self.selected_agent_id)
             self._safe_widget_operation(status_text.update, animated_text)
@@ -633,7 +638,7 @@ class StrixCLIApp(App):  # type: ignore[misc]
         for agent_id, agent_data in self.tracer.agents.items():
             status = agent_data.get("status", "running")
 
-            if status in ["running", "waiting", "llm_failed"]:
+            if status in ["running", "waiting"]:
                 has_active_agents = True
                 current_dots = self._agent_dot_states.get(agent_id, 0)
                 self._agent_dot_states[agent_id] = (current_dots + 1) % 4
@@ -644,7 +649,7 @@ class StrixCLIApp(App):  # type: ignore[misc]
             and self.selected_agent_id in self.tracer.agents
         ):
             selected_status = self.tracer.agents[self.selected_agent_id].get("status", "running")
-            if selected_status in ["running", "waiting", "llm_failed"]:
+            if selected_status in ["running", "waiting"]:
                 self._update_agent_status_display()
 
         if not has_active_agents:
@@ -652,7 +657,7 @@ class StrixCLIApp(App):  # type: ignore[misc]
             for agent_id in list(self._agent_dot_states.keys()):
                 if agent_id not in self.tracer.agents or self.tracer.agents[agent_id].get(
                     "status"
-                ) not in ["running", "waiting", "llm_failed"]:
+                ) not in ["running", "waiting"]:
                     del self._agent_dot_states[agent_id]
 
     def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]:
@@ -900,6 +905,7 @@ class StrixCLIApp(App):  # type: ignore[misc]
             "reporting_action": "#ea580c",
             "scan_start_info": "#22c55e",
             "subagent_start_info": "#22c55e",
+            "llm_error_details": "#dc2626",
         }
 
         color = tool_colors.get(tool_name, "#737373")
@@ -911,6 +917,14 @@ class StrixCLIApp(App):  # type: ignore[misc]
         if renderer:
             widget = renderer.render(tool_data)
             content = str(widget.renderable)
+        elif tool_name == "llm_error_details":
+            lines = ["[red]✗ LLM Request Failed[/red]"]
+            if args.get("details"):
+                details = args["details"]
+                if len(details) > 300:
+                    details = details[:297] + "..."
+                lines.append(f"[dim]Details:[/dim] {escape_markup(details)}")
+            content = "\n".join(lines)
         else:
             status_icons = {
                 "running": "[yellow]●[/yellow]",
diff --git a/strix/cli/tracer.py b/strix/cli/tracer.py
index 4167beb..54ab196 100644
--- a/strix/cli/tracer.py
+++ b/strix/cli/tracer.py
@@ -168,10 +168,14 @@ class Tracer:
             self.tool_executions[execution_id]["result"] = result
             self.tool_executions[execution_id]["completed_at"] = datetime.now(UTC).isoformat()
 
-    def update_agent_status(self, agent_id: str, status: str) -> None:
+    def update_agent_status(
+        self, agent_id: str, status: str, error_message: str | None = None
+    ) -> None:
         if agent_id in self.agents:
             self.agents[agent_id]["status"] = status
             self.agents[agent_id]["updated_at"] = datetime.now(UTC).isoformat()
+            if error_message:
+                self.agents[agent_id]["error_message"] = error_message
 
     def set_scan_config(self, config: dict[str, Any]) -> None:
         self.scan_config = config
diff --git a/strix/llm/llm.py b/strix/llm/llm.py
index 3749617..f24c509 100644
--- a/strix/llm/llm.py
+++ b/strix/llm/llm.py
@@ -30,7 +30,10 @@ if api_key:
 
 
 class LLMRequestFailedError(Exception):
-    """Raised when LLM request fails after all retry attempts."""
+    def __init__(self, message: str, details: str | None = None):
+        super().__init__(message)
+        self.message = message
+        self.details = details
 
 
 MODELS_WITHOUT_STOP_WORDS = [
@@ -211,7 +214,7 @@ class LLM:
 
         return cached_messages
 
-    async def generate(
+    async def generate(  # noqa: PLR0912, PLR0915
         self,
         conversation_history: list[dict[str, Any]],
         scan_id: str | None = None,
@@ -255,8 +258,50 @@ class LLM:
                 tool_invocations=tool_invocations if tool_invocations else None,
             )
 
+        except litellm.RateLimitError as e:
+            raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e
+        except litellm.AuthenticationError as e:
+            raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e
+        except litellm.NotFoundError as e:
+            raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
+        except litellm.ContextWindowExceededError as e:
+            raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
+        except litellm.ContentPolicyViolationError as e:
+            raise LLMRequestFailedError(
+                "LLM request failed: Content policy violation", str(e)
+            ) from e
+        except litellm.ServiceUnavailableError as e:
+            raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
+        except litellm.Timeout as e:
+            raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
+        except litellm.UnprocessableEntityError as e:
+            raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
+        except litellm.InternalServerError as e:
+            raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
+        except litellm.APIConnectionError as e:
+            raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
+        except litellm.UnsupportedParamsError as e:
+            raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
+        except litellm.BudgetExceededError as e:
+            raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
+        except litellm.APIResponseValidationError as e:
+            raise LLMRequestFailedError(
+                "LLM request failed: Response validation error", str(e)
+            ) from e
+        except litellm.JSONSchemaValidationError as e:
+            raise LLMRequestFailedError(
+                "LLM request failed: JSON schema validation error", str(e)
+            ) from e
+        except litellm.InvalidRequestError as e:
+            raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
+        except litellm.BadRequestError as e:
+            raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
+        except litellm.APIError as e:
+            raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
+        except litellm.OpenAIError as e:
+            raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
         except Exception as e:
-            raise LLMRequestFailedError("LLM request failed after all retry attempts") from e
+            raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
 
     @property
     def usage_stats(self) -> dict[str, dict[str, int | float]]:
diff --git a/strix/llm/request_queue.py b/strix/llm/request_queue.py
index bde0ee1..cd99bcf 100644
--- a/strix/llm/request_queue.py
+++ b/strix/llm/request_queue.py
@@ -4,13 +4,27 @@ import threading
 import time
 from typing import Any
 
+import litellm
 from litellm import ModelResponse, completion
-from tenacity import retry, stop_after_attempt, wait_exponential
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
 
 
 logger = logging.getLogger(__name__)
 
 
+def should_retry_exception(exception: Exception) -> bool:
+    status_code = None
+
+    if hasattr(exception, "status_code"):
+        status_code = exception.status_code
+    elif hasattr(exception, "response") and hasattr(exception.response, "status_code"):
+        status_code = exception.response.status_code
+
+    if status_code is not None:
+        return bool(litellm._should_retry(status_code))
+    return True
+
+
 class LLMRequestQueue:
     def __init__(self, max_concurrent: int = 6, delay_between_requests: float = 1.0):
         self.max_concurrent = max_concurrent
@@ -40,6 +54,7 @@ class LLMRequestQueue:
     @retry(  # type: ignore[misc]
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=2, min=1, max=30),
+        retry=retry_if_exception(should_retry_exception),
         reraise=True,
     )
     async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
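
Reviewer note: a minimal sketch of how the new retry predicate is expected to
behave. The _FakeError helper below is hypothetical (not part of the patch),
and the concrete outcomes assume litellm._should_retry treats 429/5xx as
retryable and other 4xx status codes as terminal:

    # Illustrative sketch only; _FakeError stands in for a provider
    # exception that carries a status_code attribute.
    from strix.llm.request_queue import should_retry_exception


    class _FakeError(Exception):
        def __init__(self, status_code: int) -> None:
            super().__init__(f"HTTP {status_code}")
            self.status_code = status_code


    # Assumed retryable: rate limits and transient server errors.
    assert should_retry_exception(_FakeError(429))
    assert should_retry_exception(_FakeError(503))

    # Assumed terminal: an auth failure should surface immediately
    # rather than burn five exponential-backoff attempts.
    assert not should_retry_exception(_FakeError(401))

    # No status code anywhere: the predicate defaults to retrying.
    assert should_retry_exception(RuntimeError("connection reset"))

Failing fast on non-retryable errors is what lets the per-exception
LLMRequestFailedError messages added in strix/llm/llm.py reach the CLI
promptly instead of after the full backoff schedule.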