Better handling of LLM request failures

This commit is contained in:
Ahmed Allam
2025-09-10 15:39:01 -07:00
parent 914b981072
commit 9a9a7268cd
10 changed files with 84 additions and 25 deletions

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "strix-agent" name = "strix-agent"
version = "0.1.12" version = "0.1.14"
description = "Open-source AI Hackers for your apps" description = "Open-source AI Hackers for your apps"
authors = ["Strix <hi@usestrix.com>"] authors = ["Strix <hi@usestrix.com>"]
readme = "README.md" readme = "README.md"

View File

@@ -13,7 +13,7 @@ from jinja2 import (
select_autoescape, select_autoescape,
) )
from strix.llm import LLM, LLMConfig from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content from strix.llm.utils import clean_content
from strix.tools import process_tool_invocations from strix.tools import process_tool_invocations
@@ -164,6 +164,10 @@ class BaseAgent(metaclass=AgentMeta):
await self._enter_waiting_state(tracer) await self._enter_waiting_state(tracer)
continue continue
if self.state.llm_failed:
await self._wait_for_input()
continue
self.state.increment_iteration() self.state.increment_iteration()
try: try:
@@ -176,6 +180,13 @@ class BaseAgent(metaclass=AgentMeta):
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True) await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
continue continue
except LLMRequestFailedError as e:
self.state.add_error(f"LLM request failed: {e}")
self.state.enter_waiting_state(llm_failed=True)
if tracer:
tracer.update_agent_status(self.state.agent_id, "llm_failed")
continue
except (RuntimeError, ValueError, TypeError) as e: except (RuntimeError, ValueError, TypeError) as e:
if not await self._handle_iteration_error(e, tracer): if not await self._handle_iteration_error(e, tracer):
await self._enter_waiting_state(tracer, error_occurred=True) await self._enter_waiting_state(tracer, error_occurred=True)
@@ -327,7 +338,7 @@ class BaseAgent(metaclass=AgentMeta):
tracer.update_agent_status(self.state.agent_id, "error") tracer.update_agent_status(self.state.agent_id, "error")
return True return True
def _check_agent_messages(self, state: AgentState) -> None: def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
try: try:
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
@@ -340,13 +351,29 @@ class BaseAgent(metaclass=AgentMeta):
has_new_messages = False has_new_messages = False
for message in messages: for message in messages:
if not message.get("read", False): if not message.get("read", False):
if state.is_waiting_for_input():
state.resume_from_waiting()
has_new_messages = True
sender_name = "Unknown Agent"
sender_id = message.get("from") sender_id = message.get("from")
if state.is_waiting_for_input():
if state.llm_failed:
if sender_id == "user":
state.resume_from_waiting()
has_new_messages = True
from strix.cli.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.update_agent_status(state.agent_id, "running")
else:
state.resume_from_waiting()
has_new_messages = True
from strix.cli.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.update_agent_status(state.agent_id, "running")
if sender_id == "user": if sender_id == "user":
sender_name = "User" sender_name = "User"
state.add_message("user", message.get("content", "")) state.add_message("user", message.get("content", ""))

View File

@@ -23,6 +23,7 @@ class AgentState(BaseModel):
completed: bool = False completed: bool = False
stop_requested: bool = False stop_requested: bool = False
waiting_for_input: bool = False waiting_for_input: bool = False
llm_failed: bool = False
final_result: dict[str, Any] | None = None final_result: dict[str, Any] | None = None
messages: list[dict[str, Any]] = Field(default_factory=list) messages: list[dict[str, Any]] = Field(default_factory=list)
@@ -85,15 +86,17 @@ class AgentState(BaseModel):
def is_waiting_for_input(self) -> bool: def is_waiting_for_input(self) -> bool:
return self.waiting_for_input return self.waiting_for_input
def enter_waiting_state(self) -> None: def enter_waiting_state(self, llm_failed: bool = False) -> None:
self.waiting_for_input = True self.waiting_for_input = True
self.stop_requested = False self.stop_requested = False
self.llm_failed = llm_failed
self.last_updated = datetime.now(UTC).isoformat() self.last_updated = datetime.now(UTC).isoformat()
def resume_from_waiting(self, new_task: str | None = None) -> None: def resume_from_waiting(self, new_task: str | None = None) -> None:
self.waiting_for_input = False self.waiting_for_input = False
self.stop_requested = False self.stop_requested = False
self.completed = False self.completed = False
self.llm_failed = False
if new_task: if new_task:
self.task = new_task self.task = new_task
self.last_updated = datetime.now(UTC).isoformat() self.last_updated = datetime.now(UTC).isoformat()

View File

@@ -420,6 +420,7 @@ class StrixCLIApp(App): # type: ignore[misc]
"failed": "", "failed": "",
"stopped": "⏹️", "stopped": "⏹️",
"stopping": "⏸️", "stopping": "⏸️",
"llm_failed": "🔴",
} }
status_icon = status_indicators.get(status, "🔵") status_icon = status_indicators.get(status, "🔵")
@@ -544,6 +545,12 @@ class StrixCLIApp(App): # type: ignore[misc]
self._safe_widget_operation(status_text.update, "Agent completed") self._safe_widget_operation(status_text.update, "Agent completed")
self._safe_widget_operation(keymap_indicator.update, "") self._safe_widget_operation(keymap_indicator.update, "")
self._safe_widget_operation(status_display.remove_class, "hidden") self._safe_widget_operation(status_display.remove_class, "hidden")
elif status == "llm_failed":
self._safe_widget_operation(status_text.update, "[red]LLM request failed[/red]")
self._safe_widget_operation(
keymap_indicator.update, "[dim]Send message to retry[/dim]"
)
self._safe_widget_operation(status_display.remove_class, "hidden")
elif status == "waiting": elif status == "waiting":
animated_text = self._get_animated_waiting_text(self.selected_agent_id) animated_text = self._get_animated_waiting_text(self.selected_agent_id)
self._safe_widget_operation(status_text.update, animated_text) self._safe_widget_operation(status_text.update, animated_text)
@@ -626,7 +633,7 @@ class StrixCLIApp(App): # type: ignore[misc]
for agent_id, agent_data in self.tracer.agents.items(): for agent_id, agent_data in self.tracer.agents.items():
status = agent_data.get("status", "running") status = agent_data.get("status", "running")
if status in ["running", "waiting"]: if status in ["running", "waiting", "llm_failed"]:
has_active_agents = True has_active_agents = True
current_dots = self._agent_dot_states.get(agent_id, 0) current_dots = self._agent_dot_states.get(agent_id, 0)
self._agent_dot_states[agent_id] = (current_dots + 1) % 4 self._agent_dot_states[agent_id] = (current_dots + 1) % 4
@@ -637,7 +644,7 @@ class StrixCLIApp(App): # type: ignore[misc]
and self.selected_agent_id in self.tracer.agents and self.selected_agent_id in self.tracer.agents
): ):
selected_status = self.tracer.agents[self.selected_agent_id].get("status", "running") selected_status = self.tracer.agents[self.selected_agent_id].get("status", "running")
if selected_status in ["running", "waiting"]: if selected_status in ["running", "waiting", "llm_failed"]:
self._update_agent_status_display() self._update_agent_status_display()
if not has_active_agents: if not has_active_agents:
@@ -645,7 +652,7 @@ class StrixCLIApp(App): # type: ignore[misc]
for agent_id in list(self._agent_dot_states.keys()): for agent_id in list(self._agent_dot_states.keys()):
if agent_id not in self.tracer.agents or self.tracer.agents[agent_id].get( if agent_id not in self.tracer.agents or self.tracer.agents[agent_id].get(
"status" "status"
) not in ["running", "waiting"]: ) not in ["running", "waiting", "llm_failed"]:
del self._agent_dot_states[agent_id] del self._agent_dot_states[agent_id]
def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]: def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]:

View File

@@ -30,6 +30,8 @@ class BrowserRenderer(BaseToolRenderer):
url = args.get("url") url = args.get("url")
text = args.get("text") text = args.get("text")
js_code = args.get("js_code") js_code = args.get("js_code")
key = args.get("key")
file_path = args.get("file_path")
if action in [ if action in [
"launch", "launch",
@@ -40,6 +42,8 @@ class BrowserRenderer(BaseToolRenderer):
"click", "click",
"double_click", "double_click",
"hover", "hover",
"press_key",
"save_pdf",
]: ]:
if action == "launch": if action == "launch":
display_url = cls._format_url(url) if url else None display_url = cls._format_url(url) if url else None
@@ -60,6 +64,12 @@ class BrowserRenderer(BaseToolRenderer):
message = ( message = (
f"executing javascript\n{display_js}" if display_js else "executing javascript" f"executing javascript\n{display_js}" if display_js else "executing javascript"
) )
elif action == "press_key":
display_key = cls.escape_markup(key) if key else None
message = f"pressing key {display_key}" if display_key else "pressing key"
elif action == "save_pdf":
display_path = cls.escape_markup(file_path) if file_path else None
message = f"saving PDF to {display_path}" if display_path else "saving PDF"
else: else:
action_words = { action_words = {
"click": "clicking", "click": "clicking",
@@ -73,11 +83,14 @@ class BrowserRenderer(BaseToolRenderer):
simple_actions = { simple_actions = {
"back": "going back in browser history", "back": "going back in browser history",
"forward": "going forward in browser history", "forward": "going forward in browser history",
"scroll_down": "scrolling down",
"scroll_up": "scrolling up",
"refresh": "refreshing browser tab", "refresh": "refreshing browser tab",
"close_tab": "closing browser tab", "close_tab": "closing browser tab",
"switch_tab": "switching browser tab", "switch_tab": "switching browser tab",
"list_tabs": "listing browser tabs", "list_tabs": "listing browser tabs",
"view_source": "viewing page source", "view_source": "viewing page source",
"get_console_logs": "getting console logs",
"screenshot": "taking screenshot of browser tab", "screenshot": "taking screenshot of browser tab",
"wait": "waiting...", "wait": "waiting...",
"close": "closing browser", "close": "closing browser",

View File

@@ -25,6 +25,10 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
header = "✏️ [bold #10b981]Editing file[/]" header = "✏️ [bold #10b981]Editing file[/]"
elif command == "create": elif command == "create":
header = "📝 [bold #10b981]Creating file[/]" header = "📝 [bold #10b981]Creating file[/]"
elif command == "insert":
header = "✏️ [bold #10b981]Inserting text[/]"
elif command == "undo_edit":
header = "↩️ [bold #10b981]Undoing edit[/]"
else: else:
header = "📄 [bold #10b981]File operation[/]" header = "📄 [bold #10b981]File operation[/]"

View File

@@ -1,12 +1,15 @@
import litellm import litellm
from .config import LLMConfig from .config import LLMConfig
from .llm import LLM from .llm import LLM, LLMRequestFailedError
__all__ = [ __all__ = [
"LLM", "LLM",
"LLMConfig", "LLMConfig",
"LLMRequestFailedError",
] ]
litellm._logging._disable_debugging()
litellm.drop_params = True litellm.drop_params = True

View File

@@ -28,6 +28,11 @@ api_key = os.getenv("LLM_API_KEY")
if api_key: if api_key:
litellm.api_key = api_key litellm.api_key = api_key
class LLMRequestFailedError(Exception):
"""Raised when LLM request fails after all retry attempts."""
MODELS_WITHOUT_STOP_WORDS = [ MODELS_WITHOUT_STOP_WORDS = [
"gpt-5", "gpt-5",
"gpt-5-mini", "gpt-5-mini",
@@ -250,15 +255,8 @@ class LLM:
tool_invocations=tool_invocations if tool_invocations else None, tool_invocations=tool_invocations if tool_invocations else None,
) )
except (ValueError, TypeError, RuntimeError): except Exception as e:
logger.exception("Error in LLM generation") raise LLMRequestFailedError("LLM request failed after all retry attempts") from e
return LLMResponse(
scan_id=scan_id,
step_number=step_number,
role=StepRole.AGENT,
content="An error occurred while generating the response",
tool_invocations=None,
)
@property @property
def usage_stats(self) -> dict[str, dict[str, int | float]]: def usage_stats(self) -> dict[str, dict[str, int | float]]:
@@ -307,6 +305,7 @@ class LLM:
"model": self.config.model_name, "model": self.config.model_name,
"messages": messages, "messages": messages,
"temperature": self.config.temperature, "temperature": self.config.temperature,
"timeout": 180,
} }
if self._should_include_stop_param(): if self._should_include_stop_param():

View File

@@ -106,10 +106,13 @@ def _summarize_messages(
completion_args = { completion_args = {
"model": model, "model": model,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"timeout": 180,
} }
response = litellm.completion(**completion_args) response = litellm.completion(**completion_args)
summary = response.choices[0].message.content summary = response.choices[0].message.content or ""
if not summary.strip():
return messages[0]
summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>" summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
return { return {
"role": "assistant", "role": "assistant",

View File

@@ -38,8 +38,8 @@ class LLMRequestQueue:
self._semaphore.release() self._semaphore.release()
@retry( # type: ignore[misc] @retry( # type: ignore[misc]
stop=stop_after_attempt(15), stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1.2, min=1, max=300), wait=wait_exponential(multiplier=2, min=1, max=30),
reraise=True, reraise=True,
) )
async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse: async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse: