"""Base agent implementation: prompt loading, LLM loop, and agent-graph wiring.

Defines ``AgentMeta`` (binds per-subclass prompt templates) and ``BaseAgent``,
the asynchronous agent loop shared by root and sub-agents.
"""

import asyncio
import contextlib
import logging
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    from strix.telemetry.tracer import Tracer

from jinja2 import (
    Environment,
    FileSystemLoader,
    select_autoescape,
)

from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations
from strix.utils.resource_paths import get_strix_resource_path

from .state import AgentState


logger = logging.getLogger(__name__)


class AgentMeta(type):
    """Metaclass that attaches a prompt-template Jinja environment to each agent subclass.

    Every concrete subclass gets ``agent_name`` set to its class name and a
    ``jinja_env`` rooted at the matching ``agents/<name>`` resource directory.
    The abstract ``BaseAgent`` itself is skipped.
    """

    agent_name: str
    jinja_env: Environment

    def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> type:
        new_cls = super().__new__(cls, name, bases, attrs)
        # BaseAgent is abstract: it has no prompt directory of its own.
        if name == "BaseAgent":
            return new_cls
        prompt_dir = get_strix_resource_path("agents", name)
        new_cls.agent_name = name
        # Autoescape disabled: templates render LLM prompts, not HTML.
        new_cls.jinja_env = Environment(
            loader=FileSystemLoader(prompt_dir),
            autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
        )
        return new_cls


class BaseAgent(metaclass=AgentMeta):
    """Core agent: owns the LLM, the agent state machine, and the iteration loop."""

    max_iterations = 300
    agent_name: str = ""
    jinja_env: Environment
    default_llm_config: LLMConfig | None = None

    def __init__(self, config: dict[str, Any]) -> None:
        """Build the agent from a config dict.

        Recognized keys: ``local_sources``, ``max_iterations``, ``llm_config_name``,
        ``llm_config`` (required unless ``default_llm_config`` is set on the class),
        and ``state`` (an existing :class:`AgentState` to resume from).

        Raises:
            ValueError: if no LLM config is available.
        """
        self.config = config
        self.local_sources = config.get("local_sources", [])
        if "max_iterations" in config:
            self.max_iterations = config["max_iterations"]
        self.llm_config_name = config.get("llm_config_name", "default")
        self.llm_config = config.get("llm_config", self.default_llm_config)
        if self.llm_config is None:
            raise ValueError("llm_config is required but not provided")

        state_from_config = config.get("state")
        if state_from_config is not None:
            self.state = state_from_config
        else:
            self.state = AgentState(
                agent_name="Root Agent",
                max_iterations=self.max_iterations,
            )

        self.interactive = getattr(self.llm_config, "interactive", False)
        # Interactive root agents wait indefinitely for user input.
        if self.interactive and self.state.parent_id is None:
            self.state.waiting_timeout = 0

        self.llm = LLM(self.llm_config, agent_name=self.agent_name)
        # Identity tagging is best-effort telemetry; never fail construction over it.
        with contextlib.suppress(Exception):
            self.llm.set_agent_identity(self.state.agent_name, self.state.agent_id)

        self._current_task: asyncio.Task[Any] | None = None
        self._force_stop = False

        from strix.telemetry.tracer import get_global_tracer

        tracer = get_global_tracer()
        if tracer:
            tracer.log_agent_creation(
                agent_id=self.state.agent_id,
                name=self.state.agent_name,
                task=self.state.task,
                parent_id=self.state.parent_id,
            )
            # Root agents log the scan config; sub-agents log their delegation info.
            if self.state.parent_id is None:
                scan_config = tracer.scan_config or {}
                exec_id = tracer.log_tool_execution_start(
                    agent_id=self.state.agent_id,
                    tool_name="scan_start_info",
                    args=scan_config,
                )
                tracer.update_tool_execution(execution_id=exec_id, status="completed", result={})
            else:
                exec_id = tracer.log_tool_execution_start(
                    agent_id=self.state.agent_id,
                    tool_name="subagent_start_info",
                    args={
                        "name": self.state.agent_name,
                        "task": self.state.task,
                        "parent_id": self.state.parent_id,
                    },
                )
                tracer.update_tool_execution(execution_id=exec_id, status="completed", result={})

        self._add_to_agents_graph()

    def _add_to_agents_graph(self) -> None:
        """Register this agent in the shared in-process agents graph.

        Adds a node (plus a delegation edge from the parent, if any), registers
        the live instance and state, and claims the root slot when unset.
        """
        from strix.tools.agents_graph import agents_graph_actions

        node = {
            "id": self.state.agent_id,
            "name": self.state.agent_name,
            "task": self.state.task,
            "status": "running",
            "parent_id": self.state.parent_id,
            "created_at": self.state.start_time,
            "finished_at": None,
            "result": None,
            "llm_config": self.llm_config_name,
            "agent_type": self.__class__.__name__,
            "state": self.state.model_dump(),
        }
        agents_graph_actions._agent_graph["nodes"][self.state.agent_id] = node
        agents_graph_actions._agent_instances[self.state.agent_id] = self
        agents_graph_actions._agent_states[self.state.agent_id] = self.state
        if self.state.parent_id:
            agents_graph_actions._agent_graph["edges"].append(
                {"from": self.state.parent_id, "to": self.state.agent_id, "type": "delegation"}
            )
        if self.state.agent_id not in agents_graph_actions._agent_messages:
            agents_graph_actions._agent_messages[self.state.agent_id] = []
        # First parentless agent becomes the root.
        if self.state.parent_id is None and agents_graph_actions._root_agent_id is None:
            agents_graph_actions._root_agent_id = self.state.agent_id

    async def agent_loop(self, task: str) -> dict[str, Any]:  # noqa: PLR0912, PLR0915
        """Run the main agent loop until completion (or forever when interactive).

        Initializes the sandbox/state, then repeatedly: drains inter-agent
        messages, honors waiting/stop/llm-failure states, injects iteration-limit
        warnings, and executes one LLM iteration, translating cancellations and
        LLM failures into either a return value or a waiting state.

        Returns:
            The agent's final result dict (empty dict if none was recorded).
        """
        from strix.telemetry.tracer import get_global_tracer

        tracer = get_global_tracer()
        try:
            await self._initialize_sandbox_and_state(task)
        except SandboxInitializationError as e:
            return self._handle_sandbox_error(e, tracer)

        while True:
            if self._force_stop:
                self._force_stop = False
                await self._enter_waiting_state(tracer, was_cancelled=True)
                continue

            self._check_agent_messages(self.state)

            if self.state.is_waiting_for_input():
                await self._wait_for_input()
                continue

            if self.state.should_stop():
                if not self.interactive:
                    return self.state.final_result or {}
                # Interactive agents never truly exit; they wait for follow-ups.
                await self._enter_waiting_state(tracer)
                continue

            if self.state.llm_failed:
                await self._wait_for_input()
                continue

            self.state.increment_iteration()

            # One-time soft warning as the iteration budget runs low.
            if (
                self.state.is_approaching_max_iterations()
                and not self.state.max_iterations_warning_sent
            ):
                self.state.max_iterations_warning_sent = True
                remaining = self.state.max_iterations - self.state.iteration
                warning_msg = (
                    f"URGENT: You are approaching the maximum iteration limit. "
                    f"Current: {self.state.iteration}/{self.state.max_iterations} "
                    f"({remaining} iterations remaining). "
                    f"Please prioritize completing your required task(s) and calling "
                    f"the appropriate finish tool (finish_scan for root agent, "
                    f"agent_finish for sub-agents) as soon as possible."
                )
                self.state.add_message("user", warning_msg)

            # Hard warning exactly 3 iterations before the cap.
            if self.state.iteration == self.state.max_iterations - 3:
                final_warning_msg = (
                    "CRITICAL: You have only 3 iterations left! "
                    "Your next message MUST be the tool call to the appropriate "
                    "finish tool: finish_scan if you are the root agent, or "
                    "agent_finish if you are a sub-agent. "
                    "No other actions should be taken except finishing your work "
                    "immediately."
                )
                self.state.add_message("user", final_warning_msg)

            try:
                # Run the iteration as a task so cancel_current_execution can cancel it.
                iteration_task = asyncio.create_task(self._process_iteration(tracer))
                self._current_task = iteration_task
                should_finish = await iteration_task
                self._current_task = None

                # None => plain text response with no tool call.
                if should_finish is None and self.interactive:
                    await self._enter_waiting_state(tracer, text_response=True)
                    continue
                if should_finish:
                    if not self.interactive:
                        self.state.set_completed({"success": True})
                        if tracer:
                            tracer.update_agent_status(self.state.agent_id, "completed")
                        return self.state.final_result or {}
                    await self._enter_waiting_state(tracer, task_completed=True)
                    continue
            except asyncio.CancelledError:
                self._current_task = None
                if tracer:
                    # Preserve whatever the LLM had streamed before the abort.
                    partial_content = tracer.finalize_streaming_as_interrupted(self.state.agent_id)
                    if partial_content and partial_content.strip():
                        self.state.add_message(
                            "assistant", f"{partial_content}\n\n[ABORTED BY USER]"
                        )
                if not self.interactive:
                    raise
                await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
                continue
            except LLMRequestFailedError as e:
                result = self._handle_llm_error(e, tracer)
                if result is not None:
                    return result
                continue
            except (RuntimeError, ValueError, TypeError) as e:
                if not await self._handle_iteration_error(e, tracer):
                    if not self.interactive:
                        self.state.set_completed({"success": False, "error": str(e)})
                        if tracer:
                            tracer.update_agent_status(self.state.agent_id, "failed")
                        raise
                await self._enter_waiting_state(tracer, error_occurred=True)
                continue

    async def _wait_for_input(self) -> None:
        """Poll while the agent is waiting; auto-resume once the waiting timeout elapses."""
        if self._force_stop:
            return
        if self.state.has_waiting_timeout():
            self.state.resume_from_waiting()
            self.state.add_message("user", "Waiting timeout reached. Resuming execution.")

            from strix.telemetry.tracer import get_global_tracer

            tracer = get_global_tracer()
            if tracer:
                tracer.update_agent_status(self.state.agent_id, "running")
            # Best-effort sync of the shared graph node's status.
            try:
                from strix.tools.agents_graph.agents_graph_actions import _agent_graph

                if self.state.agent_id in _agent_graph["nodes"]:
                    _agent_graph["nodes"][self.state.agent_id]["status"] = "running"
            except (ImportError, KeyError):
                pass
            return
        await asyncio.sleep(0.5)

    async def _enter_waiting_state(
        self,
        tracer: Optional["Tracer"],
        task_completed: bool = False,
        error_occurred: bool = False,
        was_cancelled: bool = False,
        text_response: bool = False,
    ) -> None:
        """Put the agent into the waiting state and record why.

        Exactly one of the flags describes the cause; it selects both the tracer
        status and the synthetic assistant message (text responses add none).
        """
        self.state.enter_waiting_state()

        if tracer:
            if text_response:
                tracer.update_agent_status(self.state.agent_id, "waiting_for_input")
            elif task_completed:
                tracer.update_agent_status(self.state.agent_id, "completed")
            elif error_occurred:
                tracer.update_agent_status(self.state.agent_id, "error")
            elif was_cancelled:
                tracer.update_agent_status(self.state.agent_id, "stopped")
            else:
                tracer.update_agent_status(self.state.agent_id, "stopped")

        # A bare text response already speaks for itself; no synthetic message.
        if text_response:
            return
        if task_completed:
            self.state.add_message(
                "assistant",
                "Task completed. I'm now waiting for follow-up instructions or new tasks.",
            )
        elif error_occurred:
            self.state.add_message(
                "assistant", "An error occurred. I'm now waiting for new instructions."
            )
        elif was_cancelled:
            self.state.add_message(
                "assistant", "Execution was cancelled. I'm now waiting for new instructions."
            )
        else:
            self.state.add_message(
                "assistant",
                "Execution paused. I'm now waiting for new instructions or any updates.",
            )

    async def _initialize_sandbox_and_state(self, task: str) -> None:
        """Create the sandbox (unless already inside one) and seed the task message.

        Raises:
            Exception: whatever the runtime raises on sandbox creation, after
                reporting it to telemetry.
        """
        import os

        # STRIX_SANDBOX_MODE=true means we are already running inside a sandbox.
        sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
        if not sandbox_mode and self.state.sandbox_id is None:
            from strix.runtime import get_runtime

            try:
                runtime = get_runtime()
                sandbox_info = await runtime.create_sandbox(
                    self.state.agent_id, self.state.sandbox_token, self.local_sources
                )
                self.state.sandbox_id = sandbox_info["workspace_id"]
                self.state.sandbox_token = sandbox_info["auth_token"]
                self.state.sandbox_info = sandbox_info
                if "agent_id" in sandbox_info:
                    self.state.sandbox_info["agent_id"] = sandbox_info["agent_id"]
                caido_port = sandbox_info.get("caido_port")
                if caido_port:
                    from strix.telemetry.tracer import get_global_tracer

                    tracer = get_global_tracer()
                    if tracer:
                        tracer.caido_url = f"localhost:{caido_port}"
            except Exception as e:
                from strix.telemetry import posthog

                posthog.error("sandbox_init_error", str(e))
                raise

        if not self.state.task:
            self.state.task = task
            self.state.add_message("user", task)

    async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool | None:
        """Run one LLM generation and act on its output.

        Returns:
            ``True``/``False`` from tool execution (True = agent should finish),
            ``False`` for empty/missing responses (after a corrective nudge), or
            ``None`` for a plain text response with no tool invocations.
        """
        final_response = None
        async for response in self.llm.generate(self.state.get_conversation_history()):
            final_response = response
            if tracer and response.content:
                tracer.update_streaming_content(self.state.agent_id, response.content)

        if final_response is None:
            return False

        content_stripped = (final_response.content or "").strip()
        if not content_stripped:
            # Push back on empty output: tell the model which tool to call instead.
            corrective_message = (
                "You MUST NOT respond with empty messages. "
                "If you currently have nothing to do or say, use an appropriate tool instead:\n"
                "- Use agents_graph_actions.wait_for_message to wait for messages "
                "from user or other agents\n"
                "- Use agents_graph_actions.agent_finish if you are a sub-agent "
                "and your task is complete\n"
                "- Use finish_actions.finish_scan if you are the root/main agent "
                "and the scan is complete"
            )
            self.state.add_message("user", corrective_message)
            return False

        thinking_blocks = getattr(final_response, "thinking_blocks", None)
        self.state.add_message("assistant", final_response.content, thinking_blocks=thinking_blocks)

        if tracer:
            tracer.clear_streaming_content(self.state.agent_id)
            tracer.log_chat_message(
                content=clean_content(final_response.content),
                role="assistant",
                agent_id=self.state.agent_id,
            )

        actions = (
            final_response.tool_invocations
            if hasattr(final_response, "tool_invocations") and final_response.tool_invocations
            else []
        )
        if actions:
            return await self._execute_actions(actions, tracer)
        return None

    async def _execute_actions(self, actions: list[Any], tracer: Optional["Tracer"]) -> bool:
        """Execute actions and return True if agent should finish."""
        for action in actions:
            self.state.add_action(action)

        conversation_history = self.state.get_conversation_history()
        # Run tools as a task so cancel_current_execution can cancel mid-execution.
        tool_task = asyncio.create_task(
            process_tool_invocations(actions, conversation_history, self.state)
        )
        self._current_task = tool_task
        try:
            should_agent_finish = await tool_task
            self._current_task = None
        except asyncio.CancelledError:
            self._current_task = None
            self.state.add_error("Tool execution cancelled by user")
            raise

        # Tools may have appended results to the history; adopt it wholesale.
        self.state.messages = conversation_history

        if should_agent_finish:
            self.state.set_completed({"success": True})
            if tracer:
                tracer.update_agent_status(self.state.agent_id, "completed")
            return True
        return False

    def _check_agent_messages(self, state: AgentState) -> None:  # noqa: PLR0912
        """Deliver unread inter-agent/user messages into the conversation.

        A waiting agent is resumed by any message — except when the LLM has
        failed, in which case only a user message resumes it. Non-user messages
        are wrapped in an explanatory envelope before delivery.
        """
        try:
            from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages

            agent_id = state.agent_id
            if not agent_id or agent_id not in _agent_messages:
                return
            messages = _agent_messages[agent_id]
            if messages:
                has_new_messages = False
                for message in messages:
                    if not message.get("read", False):
                        sender_id = message.get("from")
                        if state.is_waiting_for_input():
                            if state.llm_failed:
                                # After an LLM failure only the user may resume us.
                                if sender_id == "user":
                                    state.resume_from_waiting()
                                    has_new_messages = True
                                    from strix.telemetry.tracer import get_global_tracer

                                    tracer = get_global_tracer()
                                    if tracer:
                                        tracer.update_agent_status(state.agent_id, "running")
                            else:
                                state.resume_from_waiting()
                                has_new_messages = True
                                from strix.telemetry.tracer import get_global_tracer

                                tracer = get_global_tracer()
                                if tracer:
                                    tracer.update_agent_status(state.agent_id, "running")
                        if sender_id == "user":
                            sender_name = "User"
                            state.add_message("user", message.get("content", ""))
                        else:
                            # Fallback keeps the f-string below from raising
                            # UnboundLocalError for senders missing from the graph.
                            sender_name = sender_id or "Unknown"
                            if sender_id and sender_id in _agent_graph.get("nodes", {}):
                                sender_name = _agent_graph["nodes"][sender_id]["name"]
                            message_content = f"""
You have received a message from another agent. You should acknowledge this message and respond appropriately based on its content. However, DO NOT echo back or repeat the entire message structure in your response. Simply process the content and respond naturally as/if needed.
{sender_name}
{sender_id}
{message.get("message_type", "information")}
{message.get("priority", "normal")}
{message.get("timestamp", "")}
{message.get("content", "")}
This message was delivered during your task execution. Please acknowledge and respond if needed.
"""
                            state.add_message("user", message_content.strip())
                        message["read"] = True

                if has_new_messages and not state.is_waiting_for_input():
                    from strix.telemetry.tracer import get_global_tracer

                    tracer = get_global_tracer()
                    if tracer:
                        tracer.update_agent_status(agent_id, "running")
        except (AttributeError, KeyError, TypeError) as e:
            # Message delivery is best-effort; never let it kill the loop.
            logger.warning("Error checking agent messages: %s", e)
            return

    def _handle_sandbox_error(
        self,
        error: SandboxInitializationError,
        tracer: Optional["Tracer"],
    ) -> dict[str, Any]:
        """Report a sandbox init failure; fail hard or enter waiting per mode.

        Returns:
            A ``{"success": False, ...}`` result dict for the caller.
        """
        error_msg = str(error.message)
        error_details = error.details
        self.state.add_error(error_msg)

        if not self.interactive:
            self.state.set_completed({"success": False, "error": error_msg})
            if tracer:
                tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
                if error_details:
                    exec_id = tracer.log_tool_execution_start(
                        self.state.agent_id,
                        "sandbox_error_details",
                        {"error": error_msg, "details": error_details},
                    )
                    tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
            return {"success": False, "error": error_msg, "details": error_details}

        # Interactive: stay alive in a waiting state so the user can intervene.
        self.state.enter_waiting_state()
        if tracer:
            tracer.update_agent_status(self.state.agent_id, "sandbox_failed", error_msg)
            if error_details:
                exec_id = tracer.log_tool_execution_start(
                    self.state.agent_id,
                    "sandbox_error_details",
                    {"error": error_msg, "details": error_details},
                )
                tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
        return {"success": False, "error": error_msg, "details": error_details}

    def _handle_llm_error(
        self,
        error: LLMRequestFailedError,
        tracer: Optional["Tracer"],
    ) -> dict[str, Any] | None:
        """Report an LLM request failure.

        Returns:
            A final result dict when the agent should terminate (non-interactive),
            or ``None`` when it entered the llm-failed waiting state instead.
        """
        error_msg = str(error)
        error_details = getattr(error, "details", None)
        self.state.add_error(error_msg)

        if not self.interactive:
            self.state.set_completed({"success": False, "error": error_msg})
            if tracer:
                tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
                if error_details:
                    exec_id = tracer.log_tool_execution_start(
                        self.state.agent_id,
                        "llm_error_details",
                        {"error": error_msg, "details": error_details},
                    )
                    tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
            return {"success": False, "error": error_msg}

        # Interactive: wait for the user; only a user message resumes us
        # (see _check_agent_messages).
        self.state.enter_waiting_state(llm_failed=True)
        if tracer:
            tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg)
            if error_details:
                exec_id = tracer.log_tool_execution_start(
                    self.state.agent_id,
                    "llm_error_details",
                    {"error": error_msg, "details": error_details},
                )
                tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
        return None

    async def _handle_iteration_error(
        self,
        error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
        tracer: Optional["Tracer"],
    ) -> bool:
        """Log an iteration error and mark the agent errored.

        Returns:
            ``True``, meaning the error was handled and the loop may continue.
        """
        error_msg = f"Error in iteration {self.state.iteration}: {error!s}"
        logger.exception(error_msg)
        self.state.add_error(error_msg)
        if tracer:
            tracer.update_agent_status(self.state.agent_id, "error")
        return True

    def cancel_current_execution(self) -> None:
        """Request cancellation of the in-flight iteration/tool task (thread-safe).

        Sets the force-stop flag and cancels the current task via its event
        loop; falls back to a direct cancel if the loop is unavailable.
        """
        self._force_stop = True
        if self._current_task and not self._current_task.done():
            try:
                # Schedule the cancel on the task's own loop: safe from other threads.
                loop = self._current_task.get_loop()
                loop.call_soon_threadsafe(self._current_task.cancel)
            except RuntimeError:
                self._current_task.cancel()
            self._current_task = None