diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index ed793db..72f51a5 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -173,6 +173,33 @@ class BaseAgent(metaclass=AgentMeta): self.state.increment_iteration() + if ( + self.state.is_approaching_max_iterations() + and not self.state.max_iterations_warning_sent + ): + self.state.max_iterations_warning_sent = True + remaining = self.state.max_iterations - self.state.iteration + warning_msg = ( + f"URGENT: You are approaching the maximum iteration limit. " + f"Current: {self.state.iteration}/{self.state.max_iterations} " + f"({remaining} iterations remaining). " + f"Please prioritize completing your required task(s) and calling " + f"the appropriate finish tool (finish_scan for root agent, " + f"agent_finish for sub-agents) as soon as possible." + ) + self.state.add_message("user", warning_msg) + + if self.state.iteration == self.state.max_iterations - 3: + final_warning_msg = ( + "CRITICAL: You have only 3 iterations left! " + "Your next message MUST be the tool call to the appropriate " + "finish tool: finish_scan if you are the root agent, or " + "agent_finish if you are a sub-agent. " + "No other actions should be taken except finishing your work " + "immediately." + ) + self.state.add_message("user", final_warning_msg) + try: should_finish = await self._process_iteration(tracer) if should_finish: diff --git a/strix/agents/state.py b/strix/agents/state.py index 304c2ee..6273757 100644 --- a/strix/agents/state.py +++ b/strix/agents/state.py @@ -26,6 +26,7 @@ class AgentState(BaseModel): llm_failed: bool = False waiting_start_time: datetime | None = None final_result: dict[str, Any] | None = None + max_iterations_warning_sent: bool = False messages: list[dict[str, Any]] = Field(default_factory=list) context: dict[str, Any] = Field(default_factory=dict) @@ -106,6 +107,9 @@ class AgentState(BaseModel): def has_reached_max_iterations(self) -> bool: return self.iteration >= self.max_iterations + def is_approaching_max_iterations(self, threshold: float = 0.85) -> bool: + return self.iteration >= int(self.max_iterations * threshold) + def has_waiting_timeout(self) -> bool: if not self.waiting_for_input or not self.waiting_start_time: return False