Files
strix/strix/agents/state.py
2025-08-08 20:36:44 -07:00

140 lines
4.7 KiB
Python

import uuid
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, Field
def _generate_agent_id() -> str:
return f"agent_{uuid.uuid4().hex[:8]}"
class AgentState(BaseModel):
    """Mutable record of a single agent's execution state.

    Holds the agent's identity (agent / parent / sandbox ids), the current
    task, an iteration counter with a budget, lifecycle flags (completed,
    stop requested, waiting for input), the message history, and
    per-iteration logs of actions, observations and errors.

    Every mutating method refreshes ``last_updated`` with the current UTC
    time as an ISO-8601 string.
    """

    # --- identity ---------------------------------------------------------
    agent_id: str = Field(default_factory=_generate_agent_id)
    agent_name: str = "Strix Agent"
    # presumably the id of the agent that spawned this one — verify against callers
    parent_id: str | None = None
    sandbox_id: str | None = None
    # NOTE: intentionally not included in get_execution_summary()
    sandbox_token: str | None = None
    sandbox_info: dict[str, Any] | None = None

    # --- task & iteration budget ------------------------------------------
    task: str = ""
    iteration: int = 0
    max_iterations: int = 200

    # --- lifecycle flags --------------------------------------------------
    completed: bool = False
    stop_requested: bool = False
    waiting_for_input: bool = False
    final_result: dict[str, Any] | None = None

    # --- history & scratch data -------------------------------------------
    # Each entry is {"role": ..., "content": ...} (see add_message).
    messages: list[dict[str, Any]] = Field(default_factory=list)
    context: dict[str, Any] = Field(default_factory=dict)

    # --- timestamps (UTC, ISO-8601 strings) -------------------------------
    start_time: str = Field(default_factory=lambda: datetime.now(UTC).isoformat())
    last_updated: str = Field(default_factory=lambda: datetime.now(UTC).isoformat())

    # --- per-iteration logs -----------------------------------------------
    # Entries are {"iteration", "timestamp", "action"} (see add_action).
    actions_taken: list[dict[str, Any]] = Field(default_factory=list)
    # Entries are {"iteration", "timestamp", "observation"} (see add_observation).
    observations: list[dict[str, Any]] = Field(default_factory=list)
    # Entries are "Iteration <n>: <message>" strings (see add_error).
    errors: list[str] = Field(default_factory=list)

    def _touch(self) -> str:
        """Set ``last_updated`` to the current UTC time and return that timestamp."""
        now = datetime.now(UTC).isoformat()
        self.last_updated = now
        return now

    def increment_iteration(self) -> None:
        """Advance the iteration counter by one."""
        self.iteration += 1
        self._touch()

    def add_message(self, role: str, content: Any) -> None:
        """Append a ``{"role": role, "content": content}`` entry to the history."""
        self.messages.append({"role": role, "content": content})
        self._touch()

    def add_action(self, action: dict[str, Any]) -> None:
        """Record ``action`` tagged with the current iteration and a UTC timestamp."""
        # Use one timestamp for both the record and last_updated so they agree.
        timestamp = self._touch()
        self.actions_taken.append(
            {
                "iteration": self.iteration,
                "timestamp": timestamp,
                "action": action,
            }
        )

    def add_observation(self, observation: dict[str, Any]) -> None:
        """Record ``observation`` tagged with the current iteration and a UTC timestamp."""
        timestamp = self._touch()
        self.observations.append(
            {
                "iteration": self.iteration,
                "timestamp": timestamp,
                "observation": observation,
            }
        )

    def add_error(self, error: str) -> None:
        """Record ``error`` prefixed with the current iteration number."""
        self.errors.append(f"Iteration {self.iteration}: {error}")
        self._touch()

    def update_context(self, key: str, value: Any) -> None:
        """Set ``context[key]`` to ``value``, overwriting any previous value."""
        self.context[key] = value
        self._touch()

    def set_completed(self, final_result: dict[str, Any] | None = None) -> None:
        """Mark the agent as completed and store ``final_result``."""
        self.completed = True
        self.final_result = final_result
        self._touch()

    def request_stop(self) -> None:
        """Flag that the agent should stop; see should_stop()."""
        self.stop_requested = True
        self._touch()

    def should_stop(self) -> bool:
        """Return True if a stop was requested, the agent completed, or the
        iteration budget is exhausted."""
        return self.stop_requested or self.completed or self.has_reached_max_iterations()

    def is_waiting_for_input(self) -> bool:
        """Return whether the agent is currently in the waiting-for-input state."""
        return self.waiting_for_input

    def enter_waiting_state(self) -> None:
        """Enter the waiting-for-input state and clear any pending stop request."""
        self.waiting_for_input = True
        self.stop_requested = False
        self._touch()

    def resume_from_waiting(self, new_task: str | None = None) -> None:
        """Leave the waiting state, clearing the stop and completed flags.

        Args:
            new_task: If non-empty, replaces the current ``task``; ``None`` or
                ``""`` keeps the existing task.
        """
        self.waiting_for_input = False
        self.stop_requested = False
        self.completed = False
        if new_task:
            self.task = new_task
        self._touch()

    def has_reached_max_iterations(self) -> bool:
        """Return True once ``iteration`` is at or above ``max_iterations``."""
        return self.iteration >= self.max_iterations

    def has_empty_last_messages(self, count: int = 3) -> bool:
        """Return True if each of the last ``count`` messages has no text content.

        A message counts as empty unless its ``content`` is a string with
        non-whitespace characters. Returns False when fewer than ``count``
        messages exist, or when ``count`` is not positive.
        """
        # Guard count <= 0: messages[-0:] would be the *entire* history.
        if count <= 0 or len(self.messages) < count:
            return False
        for message in self.messages[-count:]:
            content = message.get("content", "")
            # NOTE(review): non-string content (e.g. a list of content parts)
            # is treated as empty here — confirm that is intended.
            if isinstance(content, str) and content.strip():
                return False
        return True

    def get_conversation_history(self) -> list[dict[str, Any]]:
        """Return the message history.

        NOTE: this is the live ``messages`` list, not a copy; callers that
        mutate it modify this state directly.
        """
        return self.messages

    def get_execution_summary(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of identity, progress and log counts.

        ``sandbox_token`` is deliberately omitted. ``max_iterations_reached``
        is True only when the budget ran out without the agent completing.
        """
        return {
            "agent_id": self.agent_id,
            "agent_name": self.agent_name,
            "parent_id": self.parent_id,
            "sandbox_id": self.sandbox_id,
            "sandbox_info": self.sandbox_info,
            "task": self.task,
            "iteration": self.iteration,
            "max_iterations": self.max_iterations,
            "completed": self.completed,
            "final_result": self.final_result,
            "start_time": self.start_time,
            "last_updated": self.last_updated,
            "total_actions": len(self.actions_taken),
            "total_observations": len(self.observations),
            "total_errors": len(self.errors),
            "has_errors": len(self.errors) > 0,
            "max_iterations_reached": self.has_reached_max_iterations() and not self.completed,
        }