Better handling of LLM request failures
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "strix-agent"
|
||||
version = "0.1.12"
|
||||
version = "0.1.14"
|
||||
description = "Open-source AI Hackers for your apps"
|
||||
authors = ["Strix <hi@usestrix.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
@@ -13,7 +13,7 @@ from jinja2 import (
|
||||
select_autoescape,
|
||||
)
|
||||
|
||||
from strix.llm import LLM, LLMConfig
|
||||
from strix.llm import LLM, LLMConfig, LLMRequestFailedError
|
||||
from strix.llm.utils import clean_content
|
||||
from strix.tools import process_tool_invocations
|
||||
|
||||
@@ -164,6 +164,10 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
await self._enter_waiting_state(tracer)
|
||||
continue
|
||||
|
||||
if self.state.llm_failed:
|
||||
await self._wait_for_input()
|
||||
continue
|
||||
|
||||
self.state.increment_iteration()
|
||||
|
||||
try:
|
||||
@@ -176,6 +180,13 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
|
||||
continue
|
||||
|
||||
except LLMRequestFailedError as e:
|
||||
self.state.add_error(f"LLM request failed: {e}")
|
||||
self.state.enter_waiting_state(llm_failed=True)
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "llm_failed")
|
||||
continue
|
||||
|
||||
except (RuntimeError, ValueError, TypeError) as e:
|
||||
if not await self._handle_iteration_error(e, tracer):
|
||||
await self._enter_waiting_state(tracer, error_occurred=True)
|
||||
@@ -327,7 +338,7 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
tracer.update_agent_status(self.state.agent_id, "error")
|
||||
return True
|
||||
|
||||
def _check_agent_messages(self, state: AgentState) -> None:
|
||||
def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
|
||||
try:
|
||||
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
|
||||
|
||||
@@ -340,13 +351,29 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
has_new_messages = False
|
||||
for message in messages:
|
||||
if not message.get("read", False):
|
||||
if state.is_waiting_for_input():
|
||||
state.resume_from_waiting()
|
||||
has_new_messages = True
|
||||
|
||||
sender_name = "Unknown Agent"
|
||||
sender_id = message.get("from")
|
||||
|
||||
if state.is_waiting_for_input():
|
||||
if state.llm_failed:
|
||||
if sender_id == "user":
|
||||
state.resume_from_waiting()
|
||||
has_new_messages = True
|
||||
|
||||
from strix.cli.tracer import get_global_tracer
|
||||
|
||||
tracer = get_global_tracer()
|
||||
if tracer:
|
||||
tracer.update_agent_status(state.agent_id, "running")
|
||||
else:
|
||||
state.resume_from_waiting()
|
||||
has_new_messages = True
|
||||
|
||||
from strix.cli.tracer import get_global_tracer
|
||||
|
||||
tracer = get_global_tracer()
|
||||
if tracer:
|
||||
tracer.update_agent_status(state.agent_id, "running")
|
||||
|
||||
if sender_id == "user":
|
||||
sender_name = "User"
|
||||
state.add_message("user", message.get("content", ""))
|
||||
|
||||
@@ -23,6 +23,7 @@ class AgentState(BaseModel):
|
||||
completed: bool = False
|
||||
stop_requested: bool = False
|
||||
waiting_for_input: bool = False
|
||||
llm_failed: bool = False
|
||||
final_result: dict[str, Any] | None = None
|
||||
|
||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||
@@ -85,15 +86,17 @@ class AgentState(BaseModel):
|
||||
def is_waiting_for_input(self) -> bool:
|
||||
return self.waiting_for_input
|
||||
|
||||
def enter_waiting_state(self) -> None:
|
||||
def enter_waiting_state(self, llm_failed: bool = False) -> None:
|
||||
self.waiting_for_input = True
|
||||
self.stop_requested = False
|
||||
self.llm_failed = llm_failed
|
||||
self.last_updated = datetime.now(UTC).isoformat()
|
||||
|
||||
def resume_from_waiting(self, new_task: str | None = None) -> None:
|
||||
self.waiting_for_input = False
|
||||
self.stop_requested = False
|
||||
self.completed = False
|
||||
self.llm_failed = False
|
||||
if new_task:
|
||||
self.task = new_task
|
||||
self.last_updated = datetime.now(UTC).isoformat()
|
||||
|
||||
@@ -420,6 +420,7 @@ class StrixCLIApp(App): # type: ignore[misc]
|
||||
"failed": "❌",
|
||||
"stopped": "⏹️",
|
||||
"stopping": "⏸️",
|
||||
"llm_failed": "🔴",
|
||||
}
|
||||
|
||||
status_icon = status_indicators.get(status, "🔵")
|
||||
@@ -544,6 +545,12 @@ class StrixCLIApp(App): # type: ignore[misc]
|
||||
self._safe_widget_operation(status_text.update, "Agent completed")
|
||||
self._safe_widget_operation(keymap_indicator.update, "")
|
||||
self._safe_widget_operation(status_display.remove_class, "hidden")
|
||||
elif status == "llm_failed":
|
||||
self._safe_widget_operation(status_text.update, "[red]LLM request failed[/red]")
|
||||
self._safe_widget_operation(
|
||||
keymap_indicator.update, "[dim]Send message to retry[/dim]"
|
||||
)
|
||||
self._safe_widget_operation(status_display.remove_class, "hidden")
|
||||
elif status == "waiting":
|
||||
animated_text = self._get_animated_waiting_text(self.selected_agent_id)
|
||||
self._safe_widget_operation(status_text.update, animated_text)
|
||||
@@ -626,7 +633,7 @@ class StrixCLIApp(App): # type: ignore[misc]
|
||||
|
||||
for agent_id, agent_data in self.tracer.agents.items():
|
||||
status = agent_data.get("status", "running")
|
||||
if status in ["running", "waiting"]:
|
||||
if status in ["running", "waiting", "llm_failed"]:
|
||||
has_active_agents = True
|
||||
current_dots = self._agent_dot_states.get(agent_id, 0)
|
||||
self._agent_dot_states[agent_id] = (current_dots + 1) % 4
|
||||
@@ -637,7 +644,7 @@ class StrixCLIApp(App): # type: ignore[misc]
|
||||
and self.selected_agent_id in self.tracer.agents
|
||||
):
|
||||
selected_status = self.tracer.agents[self.selected_agent_id].get("status", "running")
|
||||
if selected_status in ["running", "waiting"]:
|
||||
if selected_status in ["running", "waiting", "llm_failed"]:
|
||||
self._update_agent_status_display()
|
||||
|
||||
if not has_active_agents:
|
||||
@@ -645,7 +652,7 @@ class StrixCLIApp(App): # type: ignore[misc]
|
||||
for agent_id in list(self._agent_dot_states.keys()):
|
||||
if agent_id not in self.tracer.agents or self.tracer.agents[agent_id].get(
|
||||
"status"
|
||||
) not in ["running", "waiting"]:
|
||||
) not in ["running", "waiting", "llm_failed"]:
|
||||
del self._agent_dot_states[agent_id]
|
||||
|
||||
def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]:
|
||||
|
||||
@@ -30,6 +30,8 @@ class BrowserRenderer(BaseToolRenderer):
|
||||
url = args.get("url")
|
||||
text = args.get("text")
|
||||
js_code = args.get("js_code")
|
||||
key = args.get("key")
|
||||
file_path = args.get("file_path")
|
||||
|
||||
if action in [
|
||||
"launch",
|
||||
@@ -40,6 +42,8 @@ class BrowserRenderer(BaseToolRenderer):
|
||||
"click",
|
||||
"double_click",
|
||||
"hover",
|
||||
"press_key",
|
||||
"save_pdf",
|
||||
]:
|
||||
if action == "launch":
|
||||
display_url = cls._format_url(url) if url else None
|
||||
@@ -60,6 +64,12 @@ class BrowserRenderer(BaseToolRenderer):
|
||||
message = (
|
||||
f"executing javascript\n{display_js}" if display_js else "executing javascript"
|
||||
)
|
||||
elif action == "press_key":
|
||||
display_key = cls.escape_markup(key) if key else None
|
||||
message = f"pressing key {display_key}" if display_key else "pressing key"
|
||||
elif action == "save_pdf":
|
||||
display_path = cls.escape_markup(file_path) if file_path else None
|
||||
message = f"saving PDF to {display_path}" if display_path else "saving PDF"
|
||||
else:
|
||||
action_words = {
|
||||
"click": "clicking",
|
||||
@@ -73,11 +83,14 @@ class BrowserRenderer(BaseToolRenderer):
|
||||
simple_actions = {
|
||||
"back": "going back in browser history",
|
||||
"forward": "going forward in browser history",
|
||||
"scroll_down": "scrolling down",
|
||||
"scroll_up": "scrolling up",
|
||||
"refresh": "refreshing browser tab",
|
||||
"close_tab": "closing browser tab",
|
||||
"switch_tab": "switching browser tab",
|
||||
"list_tabs": "listing browser tabs",
|
||||
"view_source": "viewing page source",
|
||||
"get_console_logs": "getting console logs",
|
||||
"screenshot": "taking screenshot of browser tab",
|
||||
"wait": "waiting...",
|
||||
"close": "closing browser",
|
||||
|
||||
@@ -25,6 +25,10 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
|
||||
header = "✏️ [bold #10b981]Editing file[/]"
|
||||
elif command == "create":
|
||||
header = "📝 [bold #10b981]Creating file[/]"
|
||||
elif command == "insert":
|
||||
header = "✏️ [bold #10b981]Inserting text[/]"
|
||||
elif command == "undo_edit":
|
||||
header = "↩️ [bold #10b981]Undoing edit[/]"
|
||||
else:
|
||||
header = "📄 [bold #10b981]File operation[/]"
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import litellm
|
||||
|
||||
from .config import LLMConfig
|
||||
from .llm import LLM
|
||||
from .llm import LLM, LLMRequestFailedError
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
"LLMConfig",
|
||||
"LLMRequestFailedError",
|
||||
]
|
||||
|
||||
litellm._logging._disable_debugging()
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
@@ -28,6 +28,11 @@ api_key = os.getenv("LLM_API_KEY")
|
||||
if api_key:
|
||||
litellm.api_key = api_key
|
||||
|
||||
|
||||
class LLMRequestFailedError(Exception):
|
||||
"""Raised when LLM request fails after all retry attempts."""
|
||||
|
||||
|
||||
MODELS_WITHOUT_STOP_WORDS = [
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
@@ -250,15 +255,8 @@ class LLM:
|
||||
tool_invocations=tool_invocations if tool_invocations else None,
|
||||
)
|
||||
|
||||
except (ValueError, TypeError, RuntimeError):
|
||||
logger.exception("Error in LLM generation")
|
||||
return LLMResponse(
|
||||
scan_id=scan_id,
|
||||
step_number=step_number,
|
||||
role=StepRole.AGENT,
|
||||
content="An error occurred while generating the response",
|
||||
tool_invocations=None,
|
||||
)
|
||||
except Exception as e:
|
||||
raise LLMRequestFailedError("LLM request failed after all retry attempts") from e
|
||||
|
||||
@property
|
||||
def usage_stats(self) -> dict[str, dict[str, int | float]]:
|
||||
@@ -307,6 +305,7 @@ class LLM:
|
||||
"model": self.config.model_name,
|
||||
"messages": messages,
|
||||
"temperature": self.config.temperature,
|
||||
"timeout": 180,
|
||||
}
|
||||
|
||||
if self._should_include_stop_param():
|
||||
|
||||
@@ -106,10 +106,13 @@ def _summarize_messages(
|
||||
completion_args = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"timeout": 180,
|
||||
}
|
||||
|
||||
response = litellm.completion(**completion_args)
|
||||
summary = response.choices[0].message.content
|
||||
summary = response.choices[0].message.content or ""
|
||||
if not summary.strip():
|
||||
return messages[0]
|
||||
summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
|
||||
return {
|
||||
"role": "assistant",
|
||||
|
||||
@@ -38,8 +38,8 @@ class LLMRequestQueue:
|
||||
self._semaphore.release()
|
||||
|
||||
@retry( # type: ignore[misc]
|
||||
stop=stop_after_attempt(15),
|
||||
wait=wait_exponential(multiplier=1.2, min=1, max=300),
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=2, min=1, max=30),
|
||||
reraise=True,
|
||||
)
|
||||
async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
|
||||
|
||||
Reference in New Issue
Block a user