feat(agent): implement user interruption handling in agent execution
This commit is contained in:
@@ -148,7 +148,7 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
def cancel_current_execution(self) -> None:
|
||||
if self._current_task and not self._current_task.done():
|
||||
self._current_task.cancel()
|
||||
self._current_task = None
|
||||
self._current_task = None
|
||||
|
||||
async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
|
||||
await self._initialize_sandbox_and_state(task)
|
||||
@@ -204,7 +204,11 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
self.state.add_message("user", final_warning_msg)
|
||||
|
||||
try:
|
||||
should_finish = await self._process_iteration(tracer)
|
||||
iteration_task = asyncio.create_task(self._process_iteration(tracer))
|
||||
self._current_task = iteration_task
|
||||
should_finish = await iteration_task
|
||||
self._current_task = None
|
||||
|
||||
if should_finish:
|
||||
if self.non_interactive:
|
||||
self.state.set_completed({"success": True})
|
||||
@@ -215,6 +219,13 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
continue
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self._current_task = None
|
||||
if tracer:
|
||||
partial_content = tracer.finalize_streaming_as_interrupted(self.state.agent_id)
|
||||
if partial_content and partial_content.strip():
|
||||
self.state.add_message(
|
||||
"assistant", f"{partial_content}\n\n[ABORTED BY USER]"
|
||||
)
|
||||
if self.non_interactive:
|
||||
raise
|
||||
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
|
||||
|
||||
@@ -1161,6 +1161,7 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
def _render_chat_content(self, msg_data: dict[str, Any]) -> Text | None:
|
||||
role = msg_data.get("role")
|
||||
content = msg_data.get("content", "")
|
||||
metadata = msg_data.get("metadata", {})
|
||||
|
||||
if not content:
|
||||
return None
|
||||
@@ -1170,6 +1171,13 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
|
||||
return UserMessageRenderer.render_simple(content)
|
||||
|
||||
if metadata.get("interrupted"):
|
||||
result = self._render_streaming_content(content)
|
||||
result.append("\n")
|
||||
result.append("⚠ ", style="yellow")
|
||||
result.append("Interrupted by user", style="yellow dim")
|
||||
return result
|
||||
|
||||
from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer
|
||||
|
||||
return AgentMessageRenderer.render_simple(content)
|
||||
@@ -1262,6 +1270,28 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
if not self.selected_agent_id:
|
||||
return
|
||||
|
||||
if self.tracer:
|
||||
streaming_content = self.tracer.get_streaming_content(self.selected_agent_id)
|
||||
if streaming_content and streaming_content.strip():
|
||||
self.tracer.clear_streaming_content(self.selected_agent_id)
|
||||
self.tracer.interrupted_content[self.selected_agent_id] = streaming_content
|
||||
self.tracer.log_chat_message(
|
||||
content=streaming_content,
|
||||
role="assistant",
|
||||
agent_id=self.selected_agent_id,
|
||||
metadata={"interrupted": True},
|
||||
)
|
||||
|
||||
try:
|
||||
from strix.tools.agents_graph.agents_graph_actions import _agent_instances
|
||||
|
||||
if self.selected_agent_id in _agent_instances:
|
||||
agent_instance = _agent_instances[self.selected_agent_id]
|
||||
if hasattr(agent_instance, "cancel_current_execution"):
|
||||
agent_instance.cancel_current_execution()
|
||||
except (ImportError, AttributeError, KeyError):
|
||||
pass
|
||||
|
||||
if self.tracer:
|
||||
self.tracer.log_chat_message(
|
||||
content=message,
|
||||
|
||||
@@ -34,6 +34,7 @@ class Tracer:
|
||||
self.tool_executions: dict[int, dict[str, Any]] = {}
|
||||
self.chat_messages: list[dict[str, Any]] = []
|
||||
self.streaming_content: dict[str, str] = {}
|
||||
self.interrupted_content: dict[str, str] = {}
|
||||
|
||||
self.vulnerability_reports: list[dict[str, Any]] = []
|
||||
self.final_scan_result: str | None = None
|
||||
@@ -343,5 +344,19 @@ class Tracer:
|
||||
def get_streaming_content(self, agent_id: str) -> str | None:
|
||||
return self.streaming_content.get(agent_id)
|
||||
|
||||
def finalize_streaming_as_interrupted(self, agent_id: str) -> str | None:
|
||||
content = self.streaming_content.pop(agent_id, None)
|
||||
if content and content.strip():
|
||||
self.interrupted_content[agent_id] = content
|
||||
self.log_chat_message(
|
||||
content=content,
|
||||
role="assistant",
|
||||
agent_id=agent_id,
|
||||
metadata={"interrupted": True},
|
||||
)
|
||||
return content
|
||||
|
||||
return self.interrupted_content.pop(agent_id, None)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
self.save_run_data(mark_complete=True)
|
||||
|
||||
Reference in New Issue
Block a user