From 48fb48dba3085b81491ab4f63d314317e3eb6ecd Mon Sep 17 00:00:00 2001 From: 0xallam Date: Mon, 5 Jan 2026 16:06:48 -0800 Subject: [PATCH] feat(agent): implement user interruption handling in agent execution --- strix/agents/base_agent.py | 15 +++++++++++++-- strix/interface/tui.py | 30 ++++++++++++++++++++++++++++++ strix/telemetry/tracer.py | 15 +++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 56ae6b1..2cfd37b 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -148,7 +148,7 @@ class BaseAgent(metaclass=AgentMeta): def cancel_current_execution(self) -> None: if self._current_task and not self._current_task.done(): self._current_task.cancel() - self._current_task = None + self._current_task = None async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915 await self._initialize_sandbox_and_state(task) @@ -204,7 +204,11 @@ class BaseAgent(metaclass=AgentMeta): self.state.add_message("user", final_warning_msg) try: - should_finish = await self._process_iteration(tracer) + iteration_task = asyncio.create_task(self._process_iteration(tracer)) + self._current_task = iteration_task + should_finish = await iteration_task + self._current_task = None + if should_finish: if self.non_interactive: self.state.set_completed({"success": True}) @@ -215,6 +219,13 @@ class BaseAgent(metaclass=AgentMeta): continue except asyncio.CancelledError: + self._current_task = None + if tracer: + partial_content = tracer.finalize_streaming_as_interrupted(self.state.agent_id) + if partial_content and partial_content.strip(): + self.state.add_message( + "assistant", f"{partial_content}\n\n[ABORTED BY USER]" + ) if self.non_interactive: raise await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True) diff --git a/strix/interface/tui.py b/strix/interface/tui.py index 2a0cccc..4d98114 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -1161,6 +1161,7 @@ class StrixTUIApp(App): # type: ignore[misc] def _render_chat_content(self, msg_data: dict[str, Any]) -> Text | None: role = msg_data.get("role") content = msg_data.get("content", "") + metadata = msg_data.get("metadata", {}) if not content: return None @@ -1170,6 +1171,13 @@ class StrixTUIApp(App): # type: ignore[misc] return UserMessageRenderer.render_simple(content) + if metadata.get("interrupted"): + result = self._render_streaming_content(content) + result.append("\n") + result.append("⚠ ", style="yellow") + result.append("Interrupted by user", style="yellow dim") + return result + from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer return AgentMessageRenderer.render_simple(content) @@ -1262,6 +1270,28 @@ class StrixTUIApp(App): # type: ignore[misc] if not self.selected_agent_id: return + if self.tracer: + streaming_content = self.tracer.get_streaming_content(self.selected_agent_id) + if streaming_content and streaming_content.strip(): + self.tracer.clear_streaming_content(self.selected_agent_id) + self.tracer.interrupted_content[self.selected_agent_id] = streaming_content + self.tracer.log_chat_message( + content=streaming_content, + role="assistant", + agent_id=self.selected_agent_id, + metadata={"interrupted": True}, + ) + + try: + from strix.tools.agents_graph.agents_graph_actions import _agent_instances + + if self.selected_agent_id in _agent_instances: + agent_instance = _agent_instances[self.selected_agent_id] + if hasattr(agent_instance, "cancel_current_execution"): + agent_instance.cancel_current_execution() + except (ImportError, AttributeError, KeyError): + pass + if self.tracer: self.tracer.log_chat_message( content=message, diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index 59423f2..8e2b491 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -34,6 +34,7 @@ class Tracer: self.tool_executions: dict[int, dict[str, Any]] = {} self.chat_messages: list[dict[str, Any]] = [] self.streaming_content: dict[str, str] = {} + self.interrupted_content: dict[str, str] = {} self.vulnerability_reports: list[dict[str, Any]] = [] self.final_scan_result: str | None = None @@ -343,5 +344,19 @@ class Tracer: def get_streaming_content(self, agent_id: str) -> str | None: return self.streaming_content.get(agent_id) + def finalize_streaming_as_interrupted(self, agent_id: str) -> str | None: + content = self.streaming_content.pop(agent_id, None) + if content and content.strip(): + self.interrupted_content[agent_id] = content + self.log_chat_message( + content=content, + role="assistant", + agent_id=agent_id, + metadata={"interrupted": True}, + ) + return content + + return self.interrupted_content.pop(agent_id, None) + def cleanup(self) -> None: self.save_run_data(mark_complete=True)