Files
strix/strix/agents/base_agent.py
2025-10-31 21:07:21 +02:00

473 lines
18 KiB
Python

import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from strix.telemetry.tracer import Tracer
from jinja2 import (
Environment,
FileSystemLoader,
select_autoescape,
)
from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.tools import process_tool_invocations
from .state import AgentState
logger = logging.getLogger(__name__)
class AgentMeta(type):
agent_name: str
jinja_env: Environment
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> type:
new_cls = super().__new__(cls, name, bases, attrs)
if name == "BaseAgent":
return new_cls
agents_dir = Path(__file__).parent
prompt_dir = agents_dir / name
new_cls.agent_name = name
new_cls.jinja_env = Environment(
loader=FileSystemLoader(prompt_dir),
autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
)
return new_cls
class BaseAgent(metaclass=AgentMeta):
max_iterations = 200
agent_name: str = ""
jinja_env: Environment
default_llm_config: LLMConfig | None = None
def __init__(self, config: dict[str, Any]):
self.config = config
self.local_source_path = config.get("local_source_path")
self.non_interactive = config.get("non_interactive", False)
if "max_iterations" in config:
self.max_iterations = config["max_iterations"]
self.llm_config_name = config.get("llm_config_name", "default")
self.llm_config = config.get("llm_config", self.default_llm_config)
if self.llm_config is None:
raise ValueError("llm_config is required but not provided")
self.llm = LLM(self.llm_config, agent_name=self.agent_name)
state_from_config = config.get("state")
if state_from_config is not None:
self.state = state_from_config
else:
self.state = AgentState(
agent_name=self.agent_name,
max_iterations=self.max_iterations,
)
self._current_task: asyncio.Task[Any] | None = None
from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.log_agent_creation(
agent_id=self.state.agent_id,
name=self.state.agent_name,
task=self.state.task,
parent_id=self.state.parent_id,
)
if self.state.parent_id is None:
scan_config = tracer.scan_config or {}
exec_id = tracer.log_tool_execution_start(
agent_id=self.state.agent_id,
tool_name="scan_start_info",
args=scan_config,
)
tracer.update_tool_execution(execution_id=exec_id, status="completed", result={})
else:
exec_id = tracer.log_tool_execution_start(
agent_id=self.state.agent_id,
tool_name="subagent_start_info",
args={
"name": self.state.agent_name,
"task": self.state.task,
"parent_id": self.state.parent_id,
},
)
tracer.update_tool_execution(execution_id=exec_id, status="completed", result={})
self._add_to_agents_graph()
def _add_to_agents_graph(self) -> None:
from strix.tools.agents_graph import agents_graph_actions
node = {
"id": self.state.agent_id,
"name": self.state.agent_name,
"task": self.state.task,
"status": "running",
"parent_id": self.state.parent_id,
"created_at": self.state.start_time,
"finished_at": None,
"result": None,
"llm_config": self.llm_config_name,
"agent_type": self.__class__.__name__,
"state": self.state.model_dump(),
}
agents_graph_actions._agent_graph["nodes"][self.state.agent_id] = node
agents_graph_actions._agent_instances[self.state.agent_id] = self
agents_graph_actions._agent_states[self.state.agent_id] = self.state
if self.state.parent_id:
agents_graph_actions._agent_graph["edges"].append(
{"from": self.state.parent_id, "to": self.state.agent_id, "type": "delegation"}
)
if self.state.agent_id not in agents_graph_actions._agent_messages:
agents_graph_actions._agent_messages[self.state.agent_id] = []
if self.state.parent_id is None and agents_graph_actions._root_agent_id is None:
agents_graph_actions._root_agent_id = self.state.agent_id
def cancel_current_execution(self) -> None:
if self._current_task and not self._current_task.done():
self._current_task.cancel()
self._current_task = None
async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
await self._initialize_sandbox_and_state(task)
from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer()
while True:
self._check_agent_messages(self.state)
if self.state.is_waiting_for_input():
await self._wait_for_input()
continue
if self.state.should_stop():
if self.non_interactive:
return self.state.final_result or {}
await self._enter_waiting_state(tracer)
continue
if self.state.llm_failed:
await self._wait_for_input()
continue
self.state.increment_iteration()
try:
should_finish = await self._process_iteration(tracer)
if should_finish:
if self.non_interactive:
self.state.set_completed({"success": True})
if tracer:
tracer.update_agent_status(self.state.agent_id, "completed")
return self.state.final_result or {}
await self._enter_waiting_state(tracer, task_completed=True)
continue
except asyncio.CancelledError:
if self.non_interactive:
raise
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
continue
except LLMRequestFailedError as e:
error_msg = str(e)
error_details = getattr(e, "details", None)
self.state.add_error(error_msg)
self.state.enter_waiting_state(llm_failed=True)
if tracer:
tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg)
if error_details:
tracer.log_tool_execution_start(
self.state.agent_id,
"llm_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(
tracer._next_execution_id - 1, "failed", error_details
)
continue
except (RuntimeError, ValueError, TypeError) as e:
if not await self._handle_iteration_error(e, tracer):
if self.non_interactive:
self.state.set_completed({"success": False, "error": str(e)})
if tracer:
tracer.update_agent_status(self.state.agent_id, "failed")
raise
await self._enter_waiting_state(tracer, error_occurred=True)
continue
async def _wait_for_input(self) -> None:
import asyncio
if self.state.has_waiting_timeout():
self.state.resume_from_waiting()
self.state.add_message("assistant", "Waiting timeout reached. Resuming execution.")
from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.update_agent_status(self.state.agent_id, "running")
try:
from strix.tools.agents_graph.agents_graph_actions import _agent_graph
if self.state.agent_id in _agent_graph["nodes"]:
_agent_graph["nodes"][self.state.agent_id]["status"] = "running"
except (ImportError, KeyError):
pass
return
await asyncio.sleep(0.5)
async def _enter_waiting_state(
self,
tracer: Optional["Tracer"],
task_completed: bool = False,
error_occurred: bool = False,
was_cancelled: bool = False,
) -> None:
self.state.enter_waiting_state()
if tracer:
if task_completed:
tracer.update_agent_status(self.state.agent_id, "completed")
elif error_occurred:
tracer.update_agent_status(self.state.agent_id, "error")
elif was_cancelled:
tracer.update_agent_status(self.state.agent_id, "stopped")
else:
tracer.update_agent_status(self.state.agent_id, "stopped")
if task_completed:
self.state.add_message(
"assistant",
"Task completed. I'm now waiting for follow-up instructions or new tasks.",
)
elif error_occurred:
self.state.add_message(
"assistant", "An error occurred. I'm now waiting for new instructions."
)
elif was_cancelled:
self.state.add_message(
"assistant", "Execution was cancelled. I'm now waiting for new instructions."
)
else:
self.state.add_message(
"assistant",
"Execution paused. I'm now waiting for new instructions or any updates.",
)
async def _initialize_sandbox_and_state(self, task: str) -> None:
import os
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not sandbox_mode and self.state.sandbox_id is None:
from strix.runtime import get_runtime
runtime = get_runtime()
sandbox_info = await runtime.create_sandbox(
self.state.agent_id, self.state.sandbox_token, self.local_source_path
)
self.state.sandbox_id = sandbox_info["workspace_id"]
self.state.sandbox_token = sandbox_info["auth_token"]
self.state.sandbox_info = sandbox_info
if "agent_id" in sandbox_info:
self.state.sandbox_info["agent_id"] = sandbox_info["agent_id"]
if not self.state.task:
self.state.task = task
self.state.add_message("user", task)
async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
response = await self.llm.generate(self.state.get_conversation_history())
content_stripped = (response.content or "").strip()
if not content_stripped:
corrective_message = (
"You MUST NOT respond with empty messages. "
"If you currently have nothing to do or say, use an appropriate tool instead:\n"
"- Use agents_graph_actions.wait_for_message to wait for messages "
"from user or other agents\n"
"- Use agents_graph_actions.agent_finish if you are a sub-agent "
"and your task is complete\n"
"- Use finish_actions.finish_scan if you are the root/main agent "
"and the scan is complete"
)
self.state.add_message("user", corrective_message)
return False
self.state.add_message("assistant", response.content)
if tracer:
tracer.log_chat_message(
content=clean_content(response.content),
role="assistant",
agent_id=self.state.agent_id,
)
actions = (
response.tool_invocations
if hasattr(response, "tool_invocations") and response.tool_invocations
else []
)
if actions:
return await self._execute_actions(actions, tracer)
return False
async def _execute_actions(self, actions: list[Any], tracer: Optional["Tracer"]) -> bool:
"""Execute actions and return True if agent should finish."""
for action in actions:
self.state.add_action(action)
conversation_history = self.state.get_conversation_history()
tool_task = asyncio.create_task(
process_tool_invocations(actions, conversation_history, self.state)
)
self._current_task = tool_task
try:
should_agent_finish = await tool_task
self._current_task = None
except asyncio.CancelledError:
self._current_task = None
self.state.add_error("Tool execution cancelled by user")
raise
self.state.messages = conversation_history
if should_agent_finish:
self.state.set_completed({"success": True})
if tracer:
tracer.update_agent_status(self.state.agent_id, "completed")
if self.non_interactive and self.state.parent_id is None:
return True
return True
return False
async def _handle_iteration_error(
self,
error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
tracer: Optional["Tracer"],
) -> bool:
error_msg = f"Error in iteration {self.state.iteration}: {error!s}"
logger.exception(error_msg)
self.state.add_error(error_msg)
if tracer:
tracer.update_agent_status(self.state.agent_id, "error")
return True
def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
try:
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
agent_id = state.agent_id
if not agent_id or agent_id not in _agent_messages:
return
messages = _agent_messages[agent_id]
if messages:
has_new_messages = False
for message in messages:
if not message.get("read", False):
sender_id = message.get("from")
if state.is_waiting_for_input():
if state.llm_failed:
if sender_id == "user":
state.resume_from_waiting()
has_new_messages = True
from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.update_agent_status(state.agent_id, "running")
else:
state.resume_from_waiting()
has_new_messages = True
from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.update_agent_status(state.agent_id, "running")
if sender_id == "user":
sender_name = "User"
state.add_message("user", message.get("content", ""))
else:
if sender_id and sender_id in _agent_graph.get("nodes", {}):
sender_name = _agent_graph["nodes"][sender_id]["name"]
message_content = f"""<inter_agent_message>
<delivery_notice>
<important>You have received a message from another agent. You should acknowledge
this message and respond appropriately based on its content. However, DO NOT echo
back or repeat the entire message structure in your response. Simply process the
content and respond naturally as/if needed.</important>
</delivery_notice>
<sender>
<agent_name>{sender_name}</agent_name>
<agent_id>{sender_id}</agent_id>
</sender>
<message_metadata>
<type>{message.get("message_type", "information")}</type>
<priority>{message.get("priority", "normal")}</priority>
<timestamp>{message.get("timestamp", "")}</timestamp>
</message_metadata>
<content>
{message.get("content", "")}
</content>
<delivery_info>
<note>This message was delivered during your task execution.
Please acknowledge and respond if needed.</note>
</delivery_info>
</inter_agent_message>"""
state.add_message("user", message_content.strip())
message["read"] = True
if has_new_messages and not state.is_waiting_for_input():
from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
tracer.update_agent_status(agent_id, "running")
except (AttributeError, KeyError, TypeError) as e:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Error checking agent messages: {e}")
return