Better handling of LLM request failures

This commit is contained in:
Ahmed Allam
2025-09-10 15:39:01 -07:00
parent 914b981072
commit 9a9a7268cd
10 changed files with 84 additions and 25 deletions

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "strix-agent"
version = "0.1.12"
version = "0.1.14"
description = "Open-source AI Hackers for your apps"
authors = ["Strix <hi@usestrix.com>"]
readme = "README.md"

View File

@@ -13,7 +13,7 @@ from jinja2 import (
select_autoescape,
)
from strix.llm import LLM, LLMConfig
from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.tools import process_tool_invocations
@@ -164,6 +164,10 @@ class BaseAgent(metaclass=AgentMeta):
await self._enter_waiting_state(tracer)
continue
if self.state.llm_failed:
await self._wait_for_input()
continue
self.state.increment_iteration()
try:
@@ -176,6 +180,13 @@ class BaseAgent(metaclass=AgentMeta):
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
continue
except LLMRequestFailedError as e:
self.state.add_error(f"LLM request failed: {e}")
self.state.enter_waiting_state(llm_failed=True)
if tracer:
tracer.update_agent_status(self.state.agent_id, "llm_failed")
continue
except (RuntimeError, ValueError, TypeError) as e:
if not await self._handle_iteration_error(e, tracer):
await self._enter_waiting_state(tracer, error_occurred=True)
@@ -327,7 +338,7 @@ class BaseAgent(metaclass=AgentMeta):
tracer.update_agent_status(self.state.agent_id, "error")
return True
def _check_agent_messages(self, state: AgentState) -> None:
def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
try:
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
@@ -340,13 +351,29 @@ class BaseAgent(metaclass=AgentMeta):
has_new_messages = False
for message in messages:
if not message.get("read", False):
if state.is_waiting_for_input():
state.resume_from_waiting()
has_new_messages = True
sender_name = "Unknown Agent"
sender_id = message.get("from")
if state.is_waiting_for_input():
if state.llm_failed:
if sender_id == "user":
state.resume_from_waiting()
has_new_messages = True
from strix.cli.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.update_agent_status(state.agent_id, "running")
else:
state.resume_from_waiting()
has_new_messages = True
from strix.cli.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.update_agent_status(state.agent_id, "running")
if sender_id == "user":
sender_name = "User"
state.add_message("user", message.get("content", ""))

View File

@@ -23,6 +23,7 @@ class AgentState(BaseModel):
completed: bool = False
stop_requested: bool = False
waiting_for_input: bool = False
llm_failed: bool = False
final_result: dict[str, Any] | None = None
messages: list[dict[str, Any]] = Field(default_factory=list)
@@ -85,15 +86,17 @@ class AgentState(BaseModel):
def is_waiting_for_input(self) -> bool:
    """Report whether the agent is currently paused awaiting new input."""
    waiting = self.waiting_for_input
    return waiting
def enter_waiting_state(self) -> None:
def enter_waiting_state(self, llm_failed: bool = False) -> None:
self.waiting_for_input = True
self.stop_requested = False
self.llm_failed = llm_failed
self.last_updated = datetime.now(UTC).isoformat()
def resume_from_waiting(self, new_task: str | None = None) -> None:
self.waiting_for_input = False
self.stop_requested = False
self.completed = False
self.llm_failed = False
if new_task:
self.task = new_task
self.last_updated = datetime.now(UTC).isoformat()

View File

@@ -420,6 +420,7 @@ class StrixCLIApp(App): # type: ignore[misc]
"failed": "",
"stopped": "⏹️",
"stopping": "⏸️",
"llm_failed": "🔴",
}
status_icon = status_indicators.get(status, "🔵")
@@ -544,6 +545,12 @@ class StrixCLIApp(App): # type: ignore[misc]
self._safe_widget_operation(status_text.update, "Agent completed")
self._safe_widget_operation(keymap_indicator.update, "")
self._safe_widget_operation(status_display.remove_class, "hidden")
elif status == "llm_failed":
self._safe_widget_operation(status_text.update, "[red]LLM request failed[/red]")
self._safe_widget_operation(
keymap_indicator.update, "[dim]Send message to retry[/dim]"
)
self._safe_widget_operation(status_display.remove_class, "hidden")
elif status == "waiting":
animated_text = self._get_animated_waiting_text(self.selected_agent_id)
self._safe_widget_operation(status_text.update, animated_text)
@@ -626,7 +633,7 @@ class StrixCLIApp(App): # type: ignore[misc]
for agent_id, agent_data in self.tracer.agents.items():
status = agent_data.get("status", "running")
if status in ["running", "waiting"]:
if status in ["running", "waiting", "llm_failed"]:
has_active_agents = True
current_dots = self._agent_dot_states.get(agent_id, 0)
self._agent_dot_states[agent_id] = (current_dots + 1) % 4
@@ -637,7 +644,7 @@ class StrixCLIApp(App): # type: ignore[misc]
and self.selected_agent_id in self.tracer.agents
):
selected_status = self.tracer.agents[self.selected_agent_id].get("status", "running")
if selected_status in ["running", "waiting"]:
if selected_status in ["running", "waiting", "llm_failed"]:
self._update_agent_status_display()
if not has_active_agents:
@@ -645,7 +652,7 @@ class StrixCLIApp(App): # type: ignore[misc]
for agent_id in list(self._agent_dot_states.keys()):
if agent_id not in self.tracer.agents or self.tracer.agents[agent_id].get(
"status"
) not in ["running", "waiting"]:
) not in ["running", "waiting", "llm_failed"]:
del self._agent_dot_states[agent_id]
def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]:

View File

@@ -30,6 +30,8 @@ class BrowserRenderer(BaseToolRenderer):
url = args.get("url")
text = args.get("text")
js_code = args.get("js_code")
key = args.get("key")
file_path = args.get("file_path")
if action in [
"launch",
@@ -40,6 +42,8 @@ class BrowserRenderer(BaseToolRenderer):
"click",
"double_click",
"hover",
"press_key",
"save_pdf",
]:
if action == "launch":
display_url = cls._format_url(url) if url else None
@@ -60,6 +64,12 @@ class BrowserRenderer(BaseToolRenderer):
message = (
f"executing javascript\n{display_js}" if display_js else "executing javascript"
)
elif action == "press_key":
display_key = cls.escape_markup(key) if key else None
message = f"pressing key {display_key}" if display_key else "pressing key"
elif action == "save_pdf":
display_path = cls.escape_markup(file_path) if file_path else None
message = f"saving PDF to {display_path}" if display_path else "saving PDF"
else:
action_words = {
"click": "clicking",
@@ -73,11 +83,14 @@ class BrowserRenderer(BaseToolRenderer):
simple_actions = {
"back": "going back in browser history",
"forward": "going forward in browser history",
"scroll_down": "scrolling down",
"scroll_up": "scrolling up",
"refresh": "refreshing browser tab",
"close_tab": "closing browser tab",
"switch_tab": "switching browser tab",
"list_tabs": "listing browser tabs",
"view_source": "viewing page source",
"get_console_logs": "getting console logs",
"screenshot": "taking screenshot of browser tab",
"wait": "waiting...",
"close": "closing browser",

View File

@@ -25,6 +25,10 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
header = "✏️ [bold #10b981]Editing file[/]"
elif command == "create":
header = "📝 [bold #10b981]Creating file[/]"
elif command == "insert":
header = "✏️ [bold #10b981]Inserting text[/]"
elif command == "undo_edit":
header = "↩️ [bold #10b981]Undoing edit[/]"
else:
header = "📄 [bold #10b981]File operation[/]"

View File

@@ -1,12 +1,15 @@
import litellm
from .config import LLMConfig
from .llm import LLM
from .llm import LLM, LLMRequestFailedError
__all__ = [
"LLM",
"LLMConfig",
"LLMRequestFailedError",
]
litellm._logging._disable_debugging()
litellm.drop_params = True

View File

@@ -28,6 +28,11 @@ api_key = os.getenv("LLM_API_KEY")
if api_key:
litellm.api_key = api_key
class LLMRequestFailedError(Exception):
    """Signal that an LLM request could not be completed.

    Raised after the retry wrapper has exhausted all attempts; callers
    catch this to put the agent into a recoverable failed state.
    """
MODELS_WITHOUT_STOP_WORDS = [
"gpt-5",
"gpt-5-mini",
@@ -250,15 +255,8 @@ class LLM:
tool_invocations=tool_invocations if tool_invocations else None,
)
except (ValueError, TypeError, RuntimeError):
logger.exception("Error in LLM generation")
return LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content="An error occurred while generating the response",
tool_invocations=None,
)
except Exception as e:
raise LLMRequestFailedError("LLM request failed after all retry attempts") from e
@property
def usage_stats(self) -> dict[str, dict[str, int | float]]:
@@ -307,6 +305,7 @@ class LLM:
"model": self.config.model_name,
"messages": messages,
"temperature": self.config.temperature,
"timeout": 180,
}
if self._should_include_stop_param():

View File

@@ -106,10 +106,13 @@ def _summarize_messages(
completion_args = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"timeout": 180,
}
response = litellm.completion(**completion_args)
summary = response.choices[0].message.content
summary = response.choices[0].message.content or ""
if not summary.strip():
return messages[0]
summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
return {
"role": "assistant",

View File

@@ -38,8 +38,8 @@ class LLMRequestQueue:
self._semaphore.release()
@retry( # type: ignore[misc]
stop=stop_after_attempt(15),
wait=wait_exponential(multiplier=1.2, min=1, max=300),
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=2, min=1, max=30),
reraise=True,
)
async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse: