Files
strix/strix/agents/base_agent.py
0xallam 740fb3ed40 fix: add timeout handling for Docker operations and improve error messages
- Add SandboxInitializationError exception for sandbox/Docker failures
- Add 60-second timeout to Docker client initialization
- Add _exec_run_with_timeout() method using ThreadPoolExecutor for exec_run calls
- Catch ConnectionError and Timeout exceptions from requests library
- Add _handle_sandbox_error() and _handle_llm_error() methods in base_agent.py
- Handle sandbox_error_details tool in TUI for displaying errors
- Increase TUI truncation limits for better error visibility
- Update all Docker error messages with helpful hint:
  'Please ensure Docker Desktop is installed and running, and try running strix again.'
2026-01-08 17:41:44 -08:00

584 lines
22 KiB
Python

import asyncio
import contextlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

from jinja2 import (
    Environment,
    FileSystemLoader,
    select_autoescape,
)

from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations

from .state import AgentState


if TYPE_CHECKING:
    from strix.telemetry.tracer import Tracer
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class AgentMeta(type):
    """Metaclass that binds each concrete agent class to its prompt templates.

    For every class it creates other than ``BaseAgent`` itself, it sets
    ``agent_name`` to the class name and attaches a Jinja2 ``Environment``
    whose loader reads templates from the directory named after the class,
    next to this module. ``select_autoescape`` is given no extensions and
    ``default_for_string=False``, so autoescaping is not enabled.
    """

    agent_name: str
    jinja_env: Environment

    def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> type:
        created = super().__new__(cls, name, bases, attrs)

        # The abstract base has no prompt directory of its own.
        if name != "BaseAgent":
            template_root = Path(__file__).parent / name
            created.agent_name = name
            created.jinja_env = Environment(
                loader=FileSystemLoader(template_root),
                autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
            )

        return created
class BaseAgent(metaclass=AgentMeta):
    """Common driver for Strix agents.

    A concrete subclass gets ``agent_name`` and its prompt ``jinja_env`` from
    :class:`AgentMeta`. Each instance owns an :class:`AgentState` and an
    :class:`LLM` client, registers itself in the shared agents graph, and runs
    :meth:`agent_loop`, which alternates between asking the LLM for the next
    step, executing the tool invocations it returns, and waiting for input.
    """

    max_iterations = 300
    agent_name: str = ""
    jinja_env: Environment
    default_llm_config: LLMConfig | None = None

    def __init__(self, config: dict[str, Any]):
        """Create the agent from a configuration mapping.

        Keys read from ``config``:
            local_sources: passed to the runtime when the sandbox is created
                (default ``[]``).
            non_interactive: when true, :meth:`agent_loop` returns instead of
                entering the waiting state (default ``False``).
            max_iterations: overrides the class-level iteration budget.
            llm_config_name: label stored in the agents graph (default ``"default"``).
            llm_config: LLM configuration; falls back to ``default_llm_config``.
            state: an existing :class:`AgentState` to reuse instead of a new one.

        Raises:
            ValueError: if neither ``llm_config`` nor ``default_llm_config`` is set.
        """
        self.config = config
        self.local_sources = config.get("local_sources", [])
        self.non_interactive = config.get("non_interactive", False)
        if "max_iterations" in config:
            self.max_iterations = config["max_iterations"]
        self.llm_config_name = config.get("llm_config_name", "default")
        self.llm_config = config.get("llm_config", self.default_llm_config)
        if self.llm_config is None:
            raise ValueError("llm_config is required but not provided")
        self.llm = LLM(self.llm_config, agent_name=self.agent_name)

        state_from_config = config.get("state")
        if state_from_config is not None:
            self.state = state_from_config
        else:
            self.state = AgentState(
                agent_name=self.agent_name,
                max_iterations=self.max_iterations,
            )

        # Best effort: failing to tag the LLM client with the agent identity must
        # not prevent the agent from being created.
        with contextlib.suppress(Exception):
            self.llm.set_agent_identity(self.agent_name, self.state.agent_id)
        self._current_task: asyncio.Task[Any] | None = None

        from strix.telemetry.tracer import get_global_tracer

        tracer = get_global_tracer()
        if tracer:
            self._log_agent_start(tracer)

        self._add_to_agents_graph()

    def _log_agent_start(self, tracer: "Tracer") -> None:
        """Record agent creation and a completed start-info tool entry in the tracer."""
        tracer.log_agent_creation(
            agent_id=self.state.agent_id,
            name=self.state.agent_name,
            task=self.state.task,
            parent_id=self.state.parent_id,
        )
        if self.state.parent_id is None:
            # Root agent: surface the scan configuration.
            tool_name = "scan_start_info"
            args = tracer.scan_config or {}
        else:
            tool_name = "subagent_start_info"
            args = {
                "name": self.state.agent_name,
                "task": self.state.task,
                "parent_id": self.state.parent_id,
            }
        exec_id = tracer.log_tool_execution_start(
            agent_id=self.state.agent_id,
            tool_name=tool_name,
            args=args,
        )
        tracer.update_tool_execution(execution_id=exec_id, status="completed", result={})

    def _add_to_agents_graph(self) -> None:
        """Register this agent in the module-level registries of ``agents_graph_actions``.

        Adds a graph node, the instance and state lookups, a delegation edge from
        the parent (if any), and an empty inbox if none exists. The first
        parentless agent becomes the root agent.
        """
        from strix.tools.agents_graph import agents_graph_actions

        node = {
            "id": self.state.agent_id,
            "name": self.state.agent_name,
            "task": self.state.task,
            "status": "running",
            "parent_id": self.state.parent_id,
            "created_at": self.state.start_time,
            "finished_at": None,
            "result": None,
            "llm_config": self.llm_config_name,
            "agent_type": self.__class__.__name__,
            "state": self.state.model_dump(),
        }
        agents_graph_actions._agent_graph["nodes"][self.state.agent_id] = node
        agents_graph_actions._agent_instances[self.state.agent_id] = self
        agents_graph_actions._agent_states[self.state.agent_id] = self.state

        if self.state.parent_id:
            agents_graph_actions._agent_graph["edges"].append(
                {"from": self.state.parent_id, "to": self.state.agent_id, "type": "delegation"}
            )

        if self.state.agent_id not in agents_graph_actions._agent_messages:
            agents_graph_actions._agent_messages[self.state.agent_id] = []

        if self.state.parent_id is None and agents_graph_actions._root_agent_id is None:
            agents_graph_actions._root_agent_id = self.state.agent_id

    async def agent_loop(self, task: str) -> dict[str, Any]:  # noqa: PLR0912, PLR0915
        """Drive the agent until it finishes.

        Creates the sandbox (unless running inside one) and seeds the task, then
        loops: deliver pending messages, wait while in the waiting state, and
        otherwise run one LLM/tool iteration.

        In non-interactive mode the loop returns ``state.final_result`` (or
        ``{}``) once the agent stops or finishes, returns an error dict on an LLM
        failure, and re-raises cancellation and unhandled iteration errors. In
        interactive mode those events put the agent into the waiting state
        instead, so the loop only returns early when sandbox setup fails.
        """
        from strix.telemetry.tracer import get_global_tracer

        tracer = get_global_tracer()

        try:
            await self._initialize_sandbox_and_state(task)
        except SandboxInitializationError as e:
            return self._handle_sandbox_error(e, tracer)

        while True:
            self._check_agent_messages(self.state)

            if self.state.is_waiting_for_input():
                await self._wait_for_input()
                continue

            if self.state.should_stop():
                if self.non_interactive:
                    return self.state.final_result or {}
                await self._enter_waiting_state(tracer)
                continue

            if self.state.llm_failed:
                await self._wait_for_input()
                continue

            self.state.increment_iteration()
            self._add_max_iterations_warnings()

            try:
                iteration_task = asyncio.create_task(self._process_iteration(tracer))
                self._current_task = iteration_task
                try:
                    should_finish = await iteration_task
                finally:
                    # Clear the handle on every exit path (success, cancellation, LLM or
                    # runtime error) so no stale task reference is kept around.
                    self._current_task = None

                if should_finish:
                    if self.non_interactive:
                        self.state.set_completed({"success": True})
                        if tracer:
                            tracer.update_agent_status(self.state.agent_id, "completed")
                        return self.state.final_result or {}
                    await self._enter_waiting_state(tracer, task_completed=True)
                    continue

            except asyncio.CancelledError:
                if tracer:
                    partial_content = tracer.finalize_streaming_as_interrupted(self.state.agent_id)
                    if partial_content and partial_content.strip():
                        # Keep whatever the LLM had streamed so far in the transcript.
                        self.state.add_message(
                            "assistant", f"{partial_content}\n\n[ABORTED BY USER]"
                        )

                if self.non_interactive:
                    raise
                await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
                continue

            except LLMRequestFailedError as e:
                result = self._handle_llm_error(e, tracer)
                if result is not None:
                    return result
                continue

            except (RuntimeError, ValueError, TypeError) as e:
                # A handled error is only recorded; the loop then retries with the
                # next iteration.
                if not await self._handle_iteration_error(e, tracer):
                    if self.non_interactive:
                        self.state.set_completed({"success": False, "error": str(e)})
                        if tracer:
                            tracer.update_agent_status(self.state.agent_id, "failed")
                        raise
                    await self._enter_waiting_state(tracer, error_occurred=True)
                    continue

    def _add_max_iterations_warnings(self) -> None:
        """Nudge the LLM toward a finish tool as the iteration budget runs out.

        Adds a one-time "approaching the limit" message (guarded by
        ``max_iterations_warning_sent``) and a final warning exactly three
        iterations before ``max_iterations``.
        """
        if (
            self.state.is_approaching_max_iterations()
            and not self.state.max_iterations_warning_sent
        ):
            self.state.max_iterations_warning_sent = True
            remaining = self.state.max_iterations - self.state.iteration
            warning_msg = (
                f"URGENT: You are approaching the maximum iteration limit. "
                f"Current: {self.state.iteration}/{self.state.max_iterations} "
                f"({remaining} iterations remaining). "
                f"Please prioritize completing your required task(s) and calling "
                f"the appropriate finish tool (finish_scan for root agent, "
                f"agent_finish for sub-agents) as soon as possible."
            )
            self.state.add_message("user", warning_msg)

        if self.state.iteration == self.state.max_iterations - 3:
            final_warning_msg = (
                "CRITICAL: You have only 3 iterations left! "
                "Your next message MUST be the tool call to the appropriate "
                "finish tool: finish_scan if you are the root agent, or "
                "agent_finish if you are a sub-agent. "
                "No other actions should be taken except finishing your work "
                "immediately."
            )
            self.state.add_message("user", final_warning_msg)

    async def _wait_for_input(self) -> None:
        """Idle briefly while waiting, or resume if the waiting timeout has expired.

        On timeout the agent is resumed, an assistant note is added, and both the
        tracer and the agents-graph node are marked ``"running"``. Otherwise this
        sleeps for 0.5 seconds so the caller's loop can poll again.
        """
        if self.state.has_waiting_timeout():
            self.state.resume_from_waiting()
            self.state.add_message("assistant", "Waiting timeout reached. Resuming execution.")

            from strix.telemetry.tracer import get_global_tracer

            tracer = get_global_tracer()
            if tracer:
                tracer.update_agent_status(self.state.agent_id, "running")

            # The graph status update is best effort.
            with contextlib.suppress(ImportError, KeyError):
                from strix.tools.agents_graph.agents_graph_actions import _agent_graph

                if self.state.agent_id in _agent_graph["nodes"]:
                    _agent_graph["nodes"][self.state.agent_id]["status"] = "running"
            return

        await asyncio.sleep(0.5)

    async def _enter_waiting_state(
        self,
        tracer: Optional["Tracer"],
        task_completed: bool = False,
        error_occurred: bool = False,
        was_cancelled: bool = False,
    ) -> None:
        """Put the agent into the waiting state and explain why.

        The first true flag (``task_completed``, then ``error_occurred``, then
        ``was_cancelled``) selects the tracer status and the assistant message;
        with no flag set the agent is reported as paused.
        """
        self.state.enter_waiting_state()

        if task_completed:
            status = "completed"
            message = "Task completed. I'm now waiting for follow-up instructions or new tasks."
        elif error_occurred:
            status = "error"
            message = "An error occurred. I'm now waiting for new instructions."
        elif was_cancelled:
            status = "stopped"
            message = "Execution was cancelled. I'm now waiting for new instructions."
        else:
            status = "stopped"
            message = "Execution paused. I'm now waiting for new instructions or any updates."

        if tracer:
            tracer.update_agent_status(self.state.agent_id, status)
        self.state.add_message("assistant", message)

    async def _initialize_sandbox_and_state(self, task: str) -> None:
        """Create the runtime sandbox if needed and seed the conversation with ``task``.

        No sandbox is created when ``STRIX_SANDBOX_MODE`` is ``"true"`` (the
        process is already inside one) or when the state already has a
        ``sandbox_id``.
        """
        import os

        sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
        if not sandbox_mode and self.state.sandbox_id is None:
            from strix.runtime import get_runtime

            runtime = get_runtime()
            sandbox_info = await runtime.create_sandbox(
                self.state.agent_id, self.state.sandbox_token, self.local_sources
            )
            self.state.sandbox_id = sandbox_info["workspace_id"]
            self.state.sandbox_token = sandbox_info["auth_token"]
            self.state.sandbox_info = sandbox_info
            # NOTE(review): if the state stores sandbox_info by reference this is a
            # no-op; presumably it guards against the state copying/filtering the
            # dict on assignment — confirm against AgentState.
            if "agent_id" in sandbox_info:
                self.state.sandbox_info["agent_id"] = sandbox_info["agent_id"]

        if not self.state.task:
            self.state.task = task

        self.state.add_message("user", task)

    async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
        """Run one LLM turn and execute any tool calls it produced.

        Streams the LLM response (mirroring non-empty content to the tracer),
        records the final assistant message, and hands its tool invocations to
        :meth:`_execute_actions`.

        Returns:
            True if the executed tools signalled that the agent should finish;
            False otherwise, including when the LLM produced no or empty output.
        """
        final_response = None
        async for response in self.llm.generate(self.state.get_conversation_history()):
            final_response = response
            if tracer and response.content:
                tracer.update_streaming_content(self.state.agent_id, response.content)

        if final_response is None:
            return False

        content_stripped = (final_response.content or "").strip()
        if not content_stripped:
            # Steer the model away from empty replies toward an explicit tool call.
            corrective_message = (
                "You MUST NOT respond with empty messages. "
                "If you currently have nothing to do or say, use an appropriate tool instead:\n"
                "- Use agents_graph_actions.wait_for_message to wait for messages "
                "from user or other agents\n"
                "- Use agents_graph_actions.agent_finish if you are a sub-agent "
                "and your task is complete\n"
                "- Use finish_actions.finish_scan if you are the root/main agent "
                "and the scan is complete"
            )
            self.state.add_message("user", corrective_message)
            return False

        self.state.add_message("assistant", final_response.content)
        if tracer:
            tracer.clear_streaming_content(self.state.agent_id)
            tracer.log_chat_message(
                content=clean_content(final_response.content),
                role="assistant",
                agent_id=self.state.agent_id,
            )

        actions = getattr(final_response, "tool_invocations", None) or []
        if actions:
            return await self._execute_actions(actions, tracer)
        return False

    async def _execute_actions(self, actions: list[Any], tracer: Optional["Tracer"]) -> bool:
        """Execute actions and return True if agent should finish.

        The tool run is tracked in ``_current_task`` so that
        :meth:`cancel_current_execution` can interrupt it; a cancellation is
        recorded as an error on the state and re-raised.
        """
        for action in actions:
            self.state.add_action(action)

        conversation_history = self.state.get_conversation_history()

        tool_task = asyncio.create_task(
            process_tool_invocations(actions, conversation_history, self.state)
        )
        self._current_task = tool_task
        try:
            should_agent_finish = await tool_task
        except asyncio.CancelledError:
            self.state.add_error("Tool execution cancelled by user")
            raise
        finally:
            self._current_task = None

        # The tool processor received this list; adopt it as the new transcript
        # (presumably tool results were appended to it — confirm in strix.tools).
        self.state.messages = conversation_history

        if not should_agent_finish:
            return False

        self.state.set_completed({"success": True})
        if tracer:
            tracer.update_agent_status(self.state.agent_id, "completed")
        return True

    def _check_agent_messages(self, state: AgentState) -> None:
        """Deliver unread inbox messages for ``state.agent_id`` into ``state``.

        Each unread message is appended to the conversation as a ``"user"``
        message (wrapped in an inter-agent envelope when it did not come from the
        user) and marked as read. A waiting agent is resumed by any new message,
        except when it is waiting because the LLM failed: then only a message
        from the user resumes it.

        AttributeError, KeyError and TypeError raised while doing so are logged
        as warnings instead of propagating.
        """
        try:
            from strix.telemetry.tracer import get_global_tracer
            from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages

            agent_id = state.agent_id
            if not agent_id or agent_id not in _agent_messages:
                return

            has_new_messages = False
            for message in _agent_messages[agent_id]:
                if message.get("read", False):
                    continue

                sender_id = message.get("from")

                resumes_agent = state.is_waiting_for_input() and (
                    not state.llm_failed or sender_id == "user"
                )
                if resumes_agent:
                    state.resume_from_waiting()
                    has_new_messages = True
                    tracer = get_global_tracer()
                    if tracer:
                        tracer.update_agent_status(state.agent_id, "running")

                if sender_id == "user":
                    state.add_message("user", message.get("content", ""))
                else:
                    nodes = _agent_graph.get("nodes", {})
                    # Senders missing from the graph get a placeholder, so the envelope
                    # never uses an unbound name or one left over from another message.
                    sender_name = (
                        nodes[sender_id]["name"]
                        if sender_id and sender_id in nodes
                        else "Unknown Agent"
                    )
                    state.add_message(
                        "user",
                        self._format_inter_agent_message(message, sender_id, sender_name),
                    )

                message["read"] = True

            if has_new_messages and not state.is_waiting_for_input():
                tracer = get_global_tracer()
                if tracer:
                    tracer.update_agent_status(agent_id, "running")
        except (AttributeError, KeyError, TypeError) as e:
            logger.warning("Error checking agent messages: %s", e)

    @staticmethod
    def _format_inter_agent_message(
        message: dict[str, Any], sender_id: Any, sender_name: str
    ) -> str:
        """Wrap a message from another agent in the XML envelope shown to the LLM."""
        message_content = f"""<inter_agent_message>
<delivery_notice>
<important>You have received a message from another agent. You should acknowledge
this message and respond appropriately based on its content. However, DO NOT echo
back or repeat the entire message structure in your response. Simply process the
content and respond naturally as/if needed.</important>
</delivery_notice>
<sender>
<agent_name>{sender_name}</agent_name>
<agent_id>{sender_id}</agent_id>
</sender>
<message_metadata>
<type>{message.get("message_type", "information")}</type>
<priority>{message.get("priority", "normal")}</priority>
<timestamp>{message.get("timestamp", "")}</timestamp>
</message_metadata>
<content>
{message.get("content", "")}
</content>
<delivery_info>
<note>This message was delivered during your task execution.
Please acknowledge and respond if needed.</note>
</delivery_info>
</inter_agent_message>"""
        return message_content.strip()

    def _log_error_details(
        self,
        tracer: "Tracer",
        tool_name: str,
        error_msg: str,
        error_details: Any,
    ) -> None:
        """Log ``error_details`` as a failed pseudo-tool execution; no-op if falsy."""
        if not error_details:
            return
        exec_id = tracer.log_tool_execution_start(
            self.state.agent_id,
            tool_name,
            {"error": error_msg, "details": error_details},
        )
        tracer.update_tool_execution(exec_id, "failed", {"details": error_details})

    def _handle_sandbox_error(
        self,
        error: SandboxInitializationError,
        tracer: Optional["Tracer"],
    ) -> dict[str, Any]:
        """Record a sandbox setup failure and build the result returned by :meth:`agent_loop`.

        Non-interactive agents are marked completed with ``success=False`` and
        status ``"failed"``; interactive agents enter the waiting state with
        status ``"sandbox_failed"``. The same error dict is returned either way.
        """
        error_msg = str(error.message)
        error_details = error.details
        self.state.add_error(error_msg)

        if self.non_interactive:
            self.state.set_completed({"success": False, "error": error_msg})
            status = "failed"
        else:
            self.state.enter_waiting_state()
            status = "sandbox_failed"

        if tracer:
            tracer.update_agent_status(self.state.agent_id, status, error_msg)
            self._log_error_details(tracer, "sandbox_error_details", error_msg, error_details)

        return {"success": False, "error": error_msg, "details": error_details}

    def _handle_llm_error(
        self,
        error: LLMRequestFailedError,
        tracer: Optional["Tracer"],
    ) -> dict[str, Any] | None:
        """Record a failed LLM request.

        Returns:
            An error dict for non-interactive agents (which are marked completed
            with ``success=False``), or None for interactive agents, which instead
            enter the waiting state with ``llm_failed=True``.
        """
        error_msg = str(error)
        error_details = getattr(error, "details", None)
        self.state.add_error(error_msg)

        if self.non_interactive:
            self.state.set_completed({"success": False, "error": error_msg})
            status = "failed"
        else:
            self.state.enter_waiting_state(llm_failed=True)
            status = "llm_failed"

        if tracer:
            tracer.update_agent_status(self.state.agent_id, status, error_msg)
            self._log_error_details(tracer, "llm_error_details", error_msg, error_details)

        if self.non_interactive:
            return {"success": False, "error": error_msg}
        return None

    async def _handle_iteration_error(
        self,
        error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
        tracer: Optional["Tracer"],
    ) -> bool:
        """Log and record an iteration error; return True if the loop may continue.

        The base implementation always returns True after marking the tracer
        status as ``"error"``.
        """
        error_msg = f"Error in iteration {self.state.iteration}: {error!s}"
        logger.exception(error_msg)
        self.state.add_error(error_msg)
        if tracer:
            tracer.update_agent_status(self.state.agent_id, "error")
        return True

    def cancel_current_execution(self) -> None:
        """Cancel the in-flight LLM iteration or tool run, if any, and drop the handle."""
        if self._current_task and not self._current_task.done():
            self._current_task.cancel()
        self._current_task = None