Better handling of LLM request failures
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "strix-agent"
|
name = "strix-agent"
|
||||||
version = "0.1.12"
|
version = "0.1.14"
|
||||||
description = "Open-source AI Hackers for your apps"
|
description = "Open-source AI Hackers for your apps"
|
||||||
authors = ["Strix <hi@usestrix.com>"]
|
authors = ["Strix <hi@usestrix.com>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from jinja2 import (
|
|||||||
select_autoescape,
|
select_autoescape,
|
||||||
)
|
)
|
||||||
|
|
||||||
from strix.llm import LLM, LLMConfig
|
from strix.llm import LLM, LLMConfig, LLMRequestFailedError
|
||||||
from strix.llm.utils import clean_content
|
from strix.llm.utils import clean_content
|
||||||
from strix.tools import process_tool_invocations
|
from strix.tools import process_tool_invocations
|
||||||
|
|
||||||
@@ -164,6 +164,10 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
await self._enter_waiting_state(tracer)
|
await self._enter_waiting_state(tracer)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if self.state.llm_failed:
|
||||||
|
await self._wait_for_input()
|
||||||
|
continue
|
||||||
|
|
||||||
self.state.increment_iteration()
|
self.state.increment_iteration()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -176,6 +180,13 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
|
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
except LLMRequestFailedError as e:
|
||||||
|
self.state.add_error(f"LLM request failed: {e}")
|
||||||
|
self.state.enter_waiting_state(llm_failed=True)
|
||||||
|
if tracer:
|
||||||
|
tracer.update_agent_status(self.state.agent_id, "llm_failed")
|
||||||
|
continue
|
||||||
|
|
||||||
except (RuntimeError, ValueError, TypeError) as e:
|
except (RuntimeError, ValueError, TypeError) as e:
|
||||||
if not await self._handle_iteration_error(e, tracer):
|
if not await self._handle_iteration_error(e, tracer):
|
||||||
await self._enter_waiting_state(tracer, error_occurred=True)
|
await self._enter_waiting_state(tracer, error_occurred=True)
|
||||||
@@ -327,7 +338,7 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
tracer.update_agent_status(self.state.agent_id, "error")
|
tracer.update_agent_status(self.state.agent_id, "error")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _check_agent_messages(self, state: AgentState) -> None:
|
def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
|
||||||
try:
|
try:
|
||||||
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
|
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
|
||||||
|
|
||||||
@@ -340,13 +351,29 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
has_new_messages = False
|
has_new_messages = False
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if not message.get("read", False):
|
if not message.get("read", False):
|
||||||
if state.is_waiting_for_input():
|
|
||||||
state.resume_from_waiting()
|
|
||||||
has_new_messages = True
|
|
||||||
|
|
||||||
sender_name = "Unknown Agent"
|
|
||||||
sender_id = message.get("from")
|
sender_id = message.get("from")
|
||||||
|
|
||||||
|
if state.is_waiting_for_input():
|
||||||
|
if state.llm_failed:
|
||||||
|
if sender_id == "user":
|
||||||
|
state.resume_from_waiting()
|
||||||
|
has_new_messages = True
|
||||||
|
|
||||||
|
from strix.cli.tracer import get_global_tracer
|
||||||
|
|
||||||
|
tracer = get_global_tracer()
|
||||||
|
if tracer:
|
||||||
|
tracer.update_agent_status(state.agent_id, "running")
|
||||||
|
else:
|
||||||
|
state.resume_from_waiting()
|
||||||
|
has_new_messages = True
|
||||||
|
|
||||||
|
from strix.cli.tracer import get_global_tracer
|
||||||
|
|
||||||
|
tracer = get_global_tracer()
|
||||||
|
if tracer:
|
||||||
|
tracer.update_agent_status(state.agent_id, "running")
|
||||||
|
|
||||||
if sender_id == "user":
|
if sender_id == "user":
|
||||||
sender_name = "User"
|
sender_name = "User"
|
||||||
state.add_message("user", message.get("content", ""))
|
state.add_message("user", message.get("content", ""))
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class AgentState(BaseModel):
|
|||||||
completed: bool = False
|
completed: bool = False
|
||||||
stop_requested: bool = False
|
stop_requested: bool = False
|
||||||
waiting_for_input: bool = False
|
waiting_for_input: bool = False
|
||||||
|
llm_failed: bool = False
|
||||||
final_result: dict[str, Any] | None = None
|
final_result: dict[str, Any] | None = None
|
||||||
|
|
||||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
@@ -85,15 +86,17 @@ class AgentState(BaseModel):
|
|||||||
def is_waiting_for_input(self) -> bool:
|
def is_waiting_for_input(self) -> bool:
|
||||||
return self.waiting_for_input
|
return self.waiting_for_input
|
||||||
|
|
||||||
def enter_waiting_state(self) -> None:
|
def enter_waiting_state(self, llm_failed: bool = False) -> None:
|
||||||
self.waiting_for_input = True
|
self.waiting_for_input = True
|
||||||
self.stop_requested = False
|
self.stop_requested = False
|
||||||
|
self.llm_failed = llm_failed
|
||||||
self.last_updated = datetime.now(UTC).isoformat()
|
self.last_updated = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
def resume_from_waiting(self, new_task: str | None = None) -> None:
|
def resume_from_waiting(self, new_task: str | None = None) -> None:
|
||||||
self.waiting_for_input = False
|
self.waiting_for_input = False
|
||||||
self.stop_requested = False
|
self.stop_requested = False
|
||||||
self.completed = False
|
self.completed = False
|
||||||
|
self.llm_failed = False
|
||||||
if new_task:
|
if new_task:
|
||||||
self.task = new_task
|
self.task = new_task
|
||||||
self.last_updated = datetime.now(UTC).isoformat()
|
self.last_updated = datetime.now(UTC).isoformat()
|
||||||
|
|||||||
@@ -420,6 +420,7 @@ class StrixCLIApp(App): # type: ignore[misc]
|
|||||||
"failed": "❌",
|
"failed": "❌",
|
||||||
"stopped": "⏹️",
|
"stopped": "⏹️",
|
||||||
"stopping": "⏸️",
|
"stopping": "⏸️",
|
||||||
|
"llm_failed": "🔴",
|
||||||
}
|
}
|
||||||
|
|
||||||
status_icon = status_indicators.get(status, "🔵")
|
status_icon = status_indicators.get(status, "🔵")
|
||||||
@@ -544,6 +545,12 @@ class StrixCLIApp(App): # type: ignore[misc]
|
|||||||
self._safe_widget_operation(status_text.update, "Agent completed")
|
self._safe_widget_operation(status_text.update, "Agent completed")
|
||||||
self._safe_widget_operation(keymap_indicator.update, "")
|
self._safe_widget_operation(keymap_indicator.update, "")
|
||||||
self._safe_widget_operation(status_display.remove_class, "hidden")
|
self._safe_widget_operation(status_display.remove_class, "hidden")
|
||||||
|
elif status == "llm_failed":
|
||||||
|
self._safe_widget_operation(status_text.update, "[red]LLM request failed[/red]")
|
||||||
|
self._safe_widget_operation(
|
||||||
|
keymap_indicator.update, "[dim]Send message to retry[/dim]"
|
||||||
|
)
|
||||||
|
self._safe_widget_operation(status_display.remove_class, "hidden")
|
||||||
elif status == "waiting":
|
elif status == "waiting":
|
||||||
animated_text = self._get_animated_waiting_text(self.selected_agent_id)
|
animated_text = self._get_animated_waiting_text(self.selected_agent_id)
|
||||||
self._safe_widget_operation(status_text.update, animated_text)
|
self._safe_widget_operation(status_text.update, animated_text)
|
||||||
@@ -626,7 +633,7 @@ class StrixCLIApp(App): # type: ignore[misc]
|
|||||||
|
|
||||||
for agent_id, agent_data in self.tracer.agents.items():
|
for agent_id, agent_data in self.tracer.agents.items():
|
||||||
status = agent_data.get("status", "running")
|
status = agent_data.get("status", "running")
|
||||||
if status in ["running", "waiting"]:
|
if status in ["running", "waiting", "llm_failed"]:
|
||||||
has_active_agents = True
|
has_active_agents = True
|
||||||
current_dots = self._agent_dot_states.get(agent_id, 0)
|
current_dots = self._agent_dot_states.get(agent_id, 0)
|
||||||
self._agent_dot_states[agent_id] = (current_dots + 1) % 4
|
self._agent_dot_states[agent_id] = (current_dots + 1) % 4
|
||||||
@@ -637,7 +644,7 @@ class StrixCLIApp(App): # type: ignore[misc]
|
|||||||
and self.selected_agent_id in self.tracer.agents
|
and self.selected_agent_id in self.tracer.agents
|
||||||
):
|
):
|
||||||
selected_status = self.tracer.agents[self.selected_agent_id].get("status", "running")
|
selected_status = self.tracer.agents[self.selected_agent_id].get("status", "running")
|
||||||
if selected_status in ["running", "waiting"]:
|
if selected_status in ["running", "waiting", "llm_failed"]:
|
||||||
self._update_agent_status_display()
|
self._update_agent_status_display()
|
||||||
|
|
||||||
if not has_active_agents:
|
if not has_active_agents:
|
||||||
@@ -645,7 +652,7 @@ class StrixCLIApp(App): # type: ignore[misc]
|
|||||||
for agent_id in list(self._agent_dot_states.keys()):
|
for agent_id in list(self._agent_dot_states.keys()):
|
||||||
if agent_id not in self.tracer.agents or self.tracer.agents[agent_id].get(
|
if agent_id not in self.tracer.agents or self.tracer.agents[agent_id].get(
|
||||||
"status"
|
"status"
|
||||||
) not in ["running", "waiting"]:
|
) not in ["running", "waiting", "llm_failed"]:
|
||||||
del self._agent_dot_states[agent_id]
|
del self._agent_dot_states[agent_id]
|
||||||
|
|
||||||
def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]:
|
def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]:
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ class BrowserRenderer(BaseToolRenderer):
|
|||||||
url = args.get("url")
|
url = args.get("url")
|
||||||
text = args.get("text")
|
text = args.get("text")
|
||||||
js_code = args.get("js_code")
|
js_code = args.get("js_code")
|
||||||
|
key = args.get("key")
|
||||||
|
file_path = args.get("file_path")
|
||||||
|
|
||||||
if action in [
|
if action in [
|
||||||
"launch",
|
"launch",
|
||||||
@@ -40,6 +42,8 @@ class BrowserRenderer(BaseToolRenderer):
|
|||||||
"click",
|
"click",
|
||||||
"double_click",
|
"double_click",
|
||||||
"hover",
|
"hover",
|
||||||
|
"press_key",
|
||||||
|
"save_pdf",
|
||||||
]:
|
]:
|
||||||
if action == "launch":
|
if action == "launch":
|
||||||
display_url = cls._format_url(url) if url else None
|
display_url = cls._format_url(url) if url else None
|
||||||
@@ -60,6 +64,12 @@ class BrowserRenderer(BaseToolRenderer):
|
|||||||
message = (
|
message = (
|
||||||
f"executing javascript\n{display_js}" if display_js else "executing javascript"
|
f"executing javascript\n{display_js}" if display_js else "executing javascript"
|
||||||
)
|
)
|
||||||
|
elif action == "press_key":
|
||||||
|
display_key = cls.escape_markup(key) if key else None
|
||||||
|
message = f"pressing key {display_key}" if display_key else "pressing key"
|
||||||
|
elif action == "save_pdf":
|
||||||
|
display_path = cls.escape_markup(file_path) if file_path else None
|
||||||
|
message = f"saving PDF to {display_path}" if display_path else "saving PDF"
|
||||||
else:
|
else:
|
||||||
action_words = {
|
action_words = {
|
||||||
"click": "clicking",
|
"click": "clicking",
|
||||||
@@ -73,11 +83,14 @@ class BrowserRenderer(BaseToolRenderer):
|
|||||||
simple_actions = {
|
simple_actions = {
|
||||||
"back": "going back in browser history",
|
"back": "going back in browser history",
|
||||||
"forward": "going forward in browser history",
|
"forward": "going forward in browser history",
|
||||||
|
"scroll_down": "scrolling down",
|
||||||
|
"scroll_up": "scrolling up",
|
||||||
"refresh": "refreshing browser tab",
|
"refresh": "refreshing browser tab",
|
||||||
"close_tab": "closing browser tab",
|
"close_tab": "closing browser tab",
|
||||||
"switch_tab": "switching browser tab",
|
"switch_tab": "switching browser tab",
|
||||||
"list_tabs": "listing browser tabs",
|
"list_tabs": "listing browser tabs",
|
||||||
"view_source": "viewing page source",
|
"view_source": "viewing page source",
|
||||||
|
"get_console_logs": "getting console logs",
|
||||||
"screenshot": "taking screenshot of browser tab",
|
"screenshot": "taking screenshot of browser tab",
|
||||||
"wait": "waiting...",
|
"wait": "waiting...",
|
||||||
"close": "closing browser",
|
"close": "closing browser",
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ class StrReplaceEditorRenderer(BaseToolRenderer):
|
|||||||
header = "✏️ [bold #10b981]Editing file[/]"
|
header = "✏️ [bold #10b981]Editing file[/]"
|
||||||
elif command == "create":
|
elif command == "create":
|
||||||
header = "📝 [bold #10b981]Creating file[/]"
|
header = "📝 [bold #10b981]Creating file[/]"
|
||||||
|
elif command == "insert":
|
||||||
|
header = "✏️ [bold #10b981]Inserting text[/]"
|
||||||
|
elif command == "undo_edit":
|
||||||
|
header = "↩️ [bold #10b981]Undoing edit[/]"
|
||||||
else:
|
else:
|
||||||
header = "📄 [bold #10b981]File operation[/]"
|
header = "📄 [bold #10b981]File operation[/]"
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
from .llm import LLM
|
from .llm import LLM, LLMRequestFailedError
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LLM",
|
"LLM",
|
||||||
"LLMConfig",
|
"LLMConfig",
|
||||||
|
"LLMRequestFailedError",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
litellm._logging._disable_debugging()
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|||||||
@@ -28,6 +28,11 @@ api_key = os.getenv("LLM_API_KEY")
|
|||||||
if api_key:
|
if api_key:
|
||||||
litellm.api_key = api_key
|
litellm.api_key = api_key
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRequestFailedError(Exception):
|
||||||
|
"""Raised when LLM request fails after all retry attempts."""
|
||||||
|
|
||||||
|
|
||||||
MODELS_WITHOUT_STOP_WORDS = [
|
MODELS_WITHOUT_STOP_WORDS = [
|
||||||
"gpt-5",
|
"gpt-5",
|
||||||
"gpt-5-mini",
|
"gpt-5-mini",
|
||||||
@@ -250,15 +255,8 @@ class LLM:
|
|||||||
tool_invocations=tool_invocations if tool_invocations else None,
|
tool_invocations=tool_invocations if tool_invocations else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
except (ValueError, TypeError, RuntimeError):
|
except Exception as e:
|
||||||
logger.exception("Error in LLM generation")
|
raise LLMRequestFailedError("LLM request failed after all retry attempts") from e
|
||||||
return LLMResponse(
|
|
||||||
scan_id=scan_id,
|
|
||||||
step_number=step_number,
|
|
||||||
role=StepRole.AGENT,
|
|
||||||
content="An error occurred while generating the response",
|
|
||||||
tool_invocations=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def usage_stats(self) -> dict[str, dict[str, int | float]]:
|
def usage_stats(self) -> dict[str, dict[str, int | float]]:
|
||||||
@@ -307,6 +305,7 @@ class LLM:
|
|||||||
"model": self.config.model_name,
|
"model": self.config.model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": self.config.temperature,
|
"temperature": self.config.temperature,
|
||||||
|
"timeout": 180,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._should_include_stop_param():
|
if self._should_include_stop_param():
|
||||||
|
|||||||
@@ -106,10 +106,13 @@ def _summarize_messages(
|
|||||||
completion_args = {
|
completion_args = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"timeout": 180,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = litellm.completion(**completion_args)
|
response = litellm.completion(**completion_args)
|
||||||
summary = response.choices[0].message.content
|
summary = response.choices[0].message.content or ""
|
||||||
|
if not summary.strip():
|
||||||
|
return messages[0]
|
||||||
summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
|
summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
|
||||||
return {
|
return {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
|||||||
@@ -38,8 +38,8 @@ class LLMRequestQueue:
|
|||||||
self._semaphore.release()
|
self._semaphore.release()
|
||||||
|
|
||||||
@retry( # type: ignore[misc]
|
@retry( # type: ignore[misc]
|
||||||
stop=stop_after_attempt(15),
|
stop=stop_after_attempt(5),
|
||||||
wait=wait_exponential(multiplier=1.2, min=1, max=300),
|
wait=wait_exponential(multiplier=2, min=1, max=30),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
|
async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
|
||||||
|
|||||||
Reference in New Issue
Block a user