feat(agent): implement user interruption handling in agent execution

This commit is contained in:
0xallam
2026-01-05 16:06:48 -08:00
committed by Ahmed Allam
parent 0954ac208f
commit 48fb48dba3
3 changed files with 58 additions and 2 deletions

View File

@@ -204,7 +204,11 @@ class BaseAgent(metaclass=AgentMeta):
self.state.add_message("user", final_warning_msg)
try:
should_finish = await self._process_iteration(tracer)
iteration_task = asyncio.create_task(self._process_iteration(tracer))
self._current_task = iteration_task
should_finish = await iteration_task
self._current_task = None
if should_finish:
if self.non_interactive:
self.state.set_completed({"success": True})
@@ -215,6 +219,13 @@ class BaseAgent(metaclass=AgentMeta):
continue
except asyncio.CancelledError:
self._current_task = None
if tracer:
partial_content = tracer.finalize_streaming_as_interrupted(self.state.agent_id)
if partial_content and partial_content.strip():
self.state.add_message(
"assistant", f"{partial_content}\n\n[ABORTED BY USER]"
)
if self.non_interactive:
raise
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)

View File

@@ -1161,6 +1161,7 @@ class StrixTUIApp(App): # type: ignore[misc]
def _render_chat_content(self, msg_data: dict[str, Any]) -> Text | None:
role = msg_data.get("role")
content = msg_data.get("content", "")
metadata = msg_data.get("metadata", {})
if not content:
return None
@@ -1170,6 +1171,13 @@ class StrixTUIApp(App): # type: ignore[misc]
return UserMessageRenderer.render_simple(content)
if metadata.get("interrupted"):
result = self._render_streaming_content(content)
result.append("\n")
result.append("", style="yellow")
result.append("Interrupted by user", style="yellow dim")
return result
from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer
return AgentMessageRenderer.render_simple(content)
@@ -1262,6 +1270,28 @@ class StrixTUIApp(App): # type: ignore[misc]
if not self.selected_agent_id:
return
if self.tracer:
streaming_content = self.tracer.get_streaming_content(self.selected_agent_id)
if streaming_content and streaming_content.strip():
self.tracer.clear_streaming_content(self.selected_agent_id)
self.tracer.interrupted_content[self.selected_agent_id] = streaming_content
self.tracer.log_chat_message(
content=streaming_content,
role="assistant",
agent_id=self.selected_agent_id,
metadata={"interrupted": True},
)
try:
from strix.tools.agents_graph.agents_graph_actions import _agent_instances
if self.selected_agent_id in _agent_instances:
agent_instance = _agent_instances[self.selected_agent_id]
if hasattr(agent_instance, "cancel_current_execution"):
agent_instance.cancel_current_execution()
except (ImportError, AttributeError, KeyError):
pass
if self.tracer:
self.tracer.log_chat_message(
content=message,

View File

@@ -34,6 +34,7 @@ class Tracer:
self.tool_executions: dict[int, dict[str, Any]] = {}
self.chat_messages: list[dict[str, Any]] = []
self.streaming_content: dict[str, str] = {}
self.interrupted_content: dict[str, str] = {}
self.vulnerability_reports: list[dict[str, Any]] = []
self.final_scan_result: str | None = None
@@ -343,5 +344,19 @@ class Tracer:
def get_streaming_content(self, agent_id: str) -> str | None:
return self.streaming_content.get(agent_id)
def finalize_streaming_as_interrupted(self, agent_id: str) -> str | None:
content = self.streaming_content.pop(agent_id, None)
if content and content.strip():
self.interrupted_content[agent_id] = content
self.log_chat_message(
content=content,
role="assistant",
agent_id=agent_id,
metadata={"interrupted": True},
)
return content
return self.interrupted_content.pop(agent_id, None)
def cleanup(self) -> None:
self.save_run_data(mark_complete=True)