From 9a9a7268cded0227024487f02c564e7c12f56feb Mon Sep 17 00:00:00 2001 From: Ahmed Allam Date: Wed, 10 Sep 2025 15:39:01 -0700 Subject: [PATCH] Better handling of LLM request failures --- pyproject.toml | 2 +- strix/agents/base_agent.py | 41 +++++++++++++++---- strix/agents/state.py | 5 ++- strix/cli/app.py | 13 ++++-- strix/cli/tool_components/browser_renderer.py | 13 ++++++ .../cli/tool_components/file_edit_renderer.py | 4 ++ strix/llm/__init__.py | 5 ++- strix/llm/llm.py | 17 ++++---- strix/llm/memory_compressor.py | 5 ++- strix/llm/request_queue.py | 4 +- 10 files changed, 84 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ca3d7af..55b1ff0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "strix-agent" -version = "0.1.12" +version = "0.1.14" description = "Open-source AI Hackers for your apps" authors = ["Strix "] readme = "README.md" diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index e769788..6977786 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -13,7 +13,7 @@ from jinja2 import ( select_autoescape, ) -from strix.llm import LLM, LLMConfig +from strix.llm import LLM, LLMConfig, LLMRequestFailedError from strix.llm.utils import clean_content from strix.tools import process_tool_invocations @@ -164,6 +164,10 @@ class BaseAgent(metaclass=AgentMeta): await self._enter_waiting_state(tracer) continue + if self.state.llm_failed: + await self._wait_for_input() + continue + self.state.increment_iteration() try: @@ -176,6 +180,13 @@ class BaseAgent(metaclass=AgentMeta): await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True) continue + except LLMRequestFailedError as e: + self.state.add_error(f"LLM request failed: {e}") + self.state.enter_waiting_state(llm_failed=True) + if tracer: + tracer.update_agent_status(self.state.agent_id, "llm_failed") + continue + except (RuntimeError, ValueError, TypeError) as e: if not await self._handle_iteration_error(e, tracer): await self._enter_waiting_state(tracer, error_occurred=True) @@ -327,7 +338,7 @@ class BaseAgent(metaclass=AgentMeta): tracer.update_agent_status(self.state.agent_id, "error") return True - def _check_agent_messages(self, state: AgentState) -> None: + def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912 try: from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages @@ -340,13 +351,29 @@ class BaseAgent(metaclass=AgentMeta): has_new_messages = False for message in messages: if not message.get("read", False): - if state.is_waiting_for_input(): - state.resume_from_waiting() - has_new_messages = True - - sender_name = "Unknown Agent" sender_id = message.get("from") + if state.is_waiting_for_input(): + if state.llm_failed: + if sender_id == "user": + state.resume_from_waiting() + has_new_messages = True + + from strix.cli.tracer import get_global_tracer + + tracer = get_global_tracer() + if tracer: + tracer.update_agent_status(state.agent_id, "running") + else: + state.resume_from_waiting() + has_new_messages = True + + from strix.cli.tracer import get_global_tracer + + tracer = get_global_tracer() + if tracer: + tracer.update_agent_status(state.agent_id, "running") + if sender_id == "user": sender_name = "User" state.add_message("user", message.get("content", "")) diff --git a/strix/agents/state.py b/strix/agents/state.py index 234a0e2..acfa0a5 100644 --- a/strix/agents/state.py +++ b/strix/agents/state.py @@ -23,6 +23,7 @@ class AgentState(BaseModel): completed: bool = False stop_requested: bool = False waiting_for_input: bool = False + llm_failed: bool = False final_result: dict[str, Any] | None = None messages: list[dict[str, Any]] = Field(default_factory=list) @@ -85,15 +86,17 @@ class AgentState(BaseModel): def is_waiting_for_input(self) -> bool: return self.waiting_for_input - def enter_waiting_state(self) -> None: + def enter_waiting_state(self, llm_failed: bool = False) -> None: self.waiting_for_input = True self.stop_requested = False + self.llm_failed = llm_failed self.last_updated = datetime.now(UTC).isoformat() def resume_from_waiting(self, new_task: str | None = None) -> None: self.waiting_for_input = False self.stop_requested = False self.completed = False + self.llm_failed = False if new_task: self.task = new_task self.last_updated = datetime.now(UTC).isoformat() diff --git a/strix/cli/app.py b/strix/cli/app.py index b1c2d8b..e7a29f2 100644 --- a/strix/cli/app.py +++ b/strix/cli/app.py @@ -420,6 +420,7 @@ class StrixCLIApp(App): # type: ignore[misc] "failed": "❌", "stopped": "⏹️", "stopping": "⏸️", + "llm_failed": "🔴", } status_icon = status_indicators.get(status, "🔵") @@ -544,6 +545,12 @@ class StrixCLIApp(App): # type: ignore[misc] self._safe_widget_operation(status_text.update, "Agent completed") self._safe_widget_operation(keymap_indicator.update, "") self._safe_widget_operation(status_display.remove_class, "hidden") + elif status == "llm_failed": + self._safe_widget_operation(status_text.update, "[red]LLM request failed[/red]") + self._safe_widget_operation( + keymap_indicator.update, "[dim]Send message to retry[/dim]" + ) + self._safe_widget_operation(status_display.remove_class, "hidden") elif status == "waiting": animated_text = self._get_animated_waiting_text(self.selected_agent_id) self._safe_widget_operation(status_text.update, animated_text) @@ -626,7 +633,7 @@ class StrixCLIApp(App): # type: ignore[misc] for agent_id, agent_data in self.tracer.agents.items(): status = agent_data.get("status", "running") - if status in ["running", "waiting"]: + if status in ["running", "waiting", "llm_failed"]: has_active_agents = True current_dots = self._agent_dot_states.get(agent_id, 0) self._agent_dot_states[agent_id] = (current_dots + 1) % 4 @@ -637,7 +644,7 @@ class StrixCLIApp(App): # type: ignore[misc] and self.selected_agent_id in self.tracer.agents ): selected_status = self.tracer.agents[self.selected_agent_id].get("status", "running") - if selected_status in ["running", "waiting"]: + if selected_status in ["running", "waiting", "llm_failed"]: self._update_agent_status_display() if not has_active_agents: @@ -645,7 +652,7 @@ class StrixCLIApp(App): # type: ignore[misc] for agent_id in list(self._agent_dot_states.keys()): if agent_id not in self.tracer.agents or self.tracer.agents[agent_id].get( "status" - ) not in ["running", "waiting"]: + ) not in ["running", "waiting", "llm_failed"]: del self._agent_dot_states[agent_id] def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]: diff --git a/strix/cli/tool_components/browser_renderer.py b/strix/cli/tool_components/browser_renderer.py index e1b6afb..1c4723c 100644 --- a/strix/cli/tool_components/browser_renderer.py +++ b/strix/cli/tool_components/browser_renderer.py @@ -30,6 +30,8 @@ class BrowserRenderer(BaseToolRenderer): url = args.get("url") text = args.get("text") js_code = args.get("js_code") + key = args.get("key") + file_path = args.get("file_path") if action in [ "launch", @@ -40,6 +42,8 @@ class BrowserRenderer(BaseToolRenderer): "click", "double_click", "hover", + "press_key", + "save_pdf", ]: if action == "launch": display_url = cls._format_url(url) if url else None @@ -60,6 +64,12 @@ class BrowserRenderer(BaseToolRenderer): message = ( f"executing javascript\n{display_js}" if display_js else "executing javascript" ) + elif action == "press_key": + display_key = cls.escape_markup(key) if key else None + message = f"pressing key {display_key}" if display_key else "pressing key" + elif action == "save_pdf": + display_path = cls.escape_markup(file_path) if file_path else None + message = f"saving PDF to {display_path}" if display_path else "saving PDF" else: action_words = { "click": "clicking", @@ -73,11 +83,14 @@ class BrowserRenderer(BaseToolRenderer): simple_actions = { "back": "going back in browser history", "forward": "going forward in browser history", + "scroll_down": "scrolling down", + "scroll_up": "scrolling up", "refresh": "refreshing browser tab", "close_tab": "closing browser tab", "switch_tab": "switching browser tab", "list_tabs": "listing browser tabs", "view_source": "viewing page source", + "get_console_logs": "getting console logs", "screenshot": "taking screenshot of browser tab", "wait": "waiting...", "close": "closing browser", diff --git a/strix/cli/tool_components/file_edit_renderer.py b/strix/cli/tool_components/file_edit_renderer.py index bd6d5d1..9e2fbe3 100644 --- a/strix/cli/tool_components/file_edit_renderer.py +++ b/strix/cli/tool_components/file_edit_renderer.py @@ -25,6 +25,10 @@ class StrReplaceEditorRenderer(BaseToolRenderer): header = "✏️ [bold #10b981]Editing file[/]" elif command == "create": header = "📝 [bold #10b981]Creating file[/]" + elif command == "insert": + header = "✏️ [bold #10b981]Inserting text[/]" + elif command == "undo_edit": + header = "↩️ [bold #10b981]Undoing edit[/]" else: header = "📄 [bold #10b981]File operation[/]" diff --git a/strix/llm/__init__.py b/strix/llm/__init__.py index bc23f40..6dde525 100644 --- a/strix/llm/__init__.py +++ b/strix/llm/__init__.py @@ -1,12 +1,15 @@ import litellm from .config import LLMConfig -from .llm import LLM +from .llm import LLM, LLMRequestFailedError __all__ = [ "LLM", "LLMConfig", + "LLMRequestFailedError", ] +litellm._logging._disable_debugging() + litellm.drop_params = True diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 5b1cc7a..3749617 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -28,6 +28,11 @@ api_key = os.getenv("LLM_API_KEY") if api_key: litellm.api_key = api_key + +class LLMRequestFailedError(Exception): + """Raised when LLM request fails after all retry attempts.""" + + MODELS_WITHOUT_STOP_WORDS = [ "gpt-5", "gpt-5-mini", @@ -250,15 +255,8 @@ class LLM: tool_invocations=tool_invocations if tool_invocations else None, ) - except (ValueError, TypeError, RuntimeError): - logger.exception("Error in LLM generation") - return LLMResponse( - scan_id=scan_id, - step_number=step_number, - role=StepRole.AGENT, - content="An error occurred while generating the response", - tool_invocations=None, - ) + except Exception as e: + raise LLMRequestFailedError("LLM request failed after all retry attempts") from e @property def usage_stats(self) -> dict[str, dict[str, int | float]]: @@ -307,6 +305,7 @@ class LLM: "model": self.config.model_name, "messages": messages, "temperature": self.config.temperature, + "timeout": 180, } if self._should_include_stop_param(): diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index be0c785..38dbcf6 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -106,10 +106,13 @@ def _summarize_messages( completion_args = { "model": model, "messages": [{"role": "user", "content": prompt}], + "timeout": 180, } response = litellm.completion(**completion_args) - summary = response.choices[0].message.content + summary = response.choices[0].message.content or "" + if not summary.strip(): + return messages[0] summary_msg = "{text}" return { "role": "assistant", diff --git a/strix/llm/request_queue.py b/strix/llm/request_queue.py index 3ea7761..bde0ee1 100644 --- a/strix/llm/request_queue.py +++ b/strix/llm/request_queue.py @@ -38,8 +38,8 @@ class LLMRequestQueue: self._semaphore.release() @retry( # type: ignore[misc] - stop=stop_after_attempt(15), - wait=wait_exponential(multiplier=1.2, min=1, max=300), + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=2, min=1, max=30), reraise=True, ) async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse: