feat(agent): implement user interruption handling in agent execution
This commit is contained in:
@@ -148,7 +148,7 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
def cancel_current_execution(self) -> None:
|
def cancel_current_execution(self) -> None:
|
||||||
if self._current_task and not self._current_task.done():
|
if self._current_task and not self._current_task.done():
|
||||||
self._current_task.cancel()
|
self._current_task.cancel()
|
||||||
self._current_task = None
|
self._current_task = None
|
||||||
|
|
||||||
async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
|
async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
|
||||||
await self._initialize_sandbox_and_state(task)
|
await self._initialize_sandbox_and_state(task)
|
||||||
@@ -204,7 +204,11 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
self.state.add_message("user", final_warning_msg)
|
self.state.add_message("user", final_warning_msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
should_finish = await self._process_iteration(tracer)
|
iteration_task = asyncio.create_task(self._process_iteration(tracer))
|
||||||
|
self._current_task = iteration_task
|
||||||
|
should_finish = await iteration_task
|
||||||
|
self._current_task = None
|
||||||
|
|
||||||
if should_finish:
|
if should_finish:
|
||||||
if self.non_interactive:
|
if self.non_interactive:
|
||||||
self.state.set_completed({"success": True})
|
self.state.set_completed({"success": True})
|
||||||
@@ -215,6 +219,13 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
self._current_task = None
|
||||||
|
if tracer:
|
||||||
|
partial_content = tracer.finalize_streaming_as_interrupted(self.state.agent_id)
|
||||||
|
if partial_content and partial_content.strip():
|
||||||
|
self.state.add_message(
|
||||||
|
"assistant", f"{partial_content}\n\n[ABORTED BY USER]"
|
||||||
|
)
|
||||||
if self.non_interactive:
|
if self.non_interactive:
|
||||||
raise
|
raise
|
||||||
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
|
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
|
||||||
|
|||||||
@@ -1161,6 +1161,7 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
def _render_chat_content(self, msg_data: dict[str, Any]) -> Text | None:
|
def _render_chat_content(self, msg_data: dict[str, Any]) -> Text | None:
|
||||||
role = msg_data.get("role")
|
role = msg_data.get("role")
|
||||||
content = msg_data.get("content", "")
|
content = msg_data.get("content", "")
|
||||||
|
metadata = msg_data.get("metadata", {})
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
return None
|
return None
|
||||||
@@ -1170,6 +1171,13 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
|
|
||||||
return UserMessageRenderer.render_simple(content)
|
return UserMessageRenderer.render_simple(content)
|
||||||
|
|
||||||
|
if metadata.get("interrupted"):
|
||||||
|
result = self._render_streaming_content(content)
|
||||||
|
result.append("\n")
|
||||||
|
result.append("⚠ ", style="yellow")
|
||||||
|
result.append("Interrupted by user", style="yellow dim")
|
||||||
|
return result
|
||||||
|
|
||||||
from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer
|
from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer
|
||||||
|
|
||||||
return AgentMessageRenderer.render_simple(content)
|
return AgentMessageRenderer.render_simple(content)
|
||||||
@@ -1262,6 +1270,28 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
if not self.selected_agent_id:
|
if not self.selected_agent_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self.tracer:
|
||||||
|
streaming_content = self.tracer.get_streaming_content(self.selected_agent_id)
|
||||||
|
if streaming_content and streaming_content.strip():
|
||||||
|
self.tracer.clear_streaming_content(self.selected_agent_id)
|
||||||
|
self.tracer.interrupted_content[self.selected_agent_id] = streaming_content
|
||||||
|
self.tracer.log_chat_message(
|
||||||
|
content=streaming_content,
|
||||||
|
role="assistant",
|
||||||
|
agent_id=self.selected_agent_id,
|
||||||
|
metadata={"interrupted": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from strix.tools.agents_graph.agents_graph_actions import _agent_instances
|
||||||
|
|
||||||
|
if self.selected_agent_id in _agent_instances:
|
||||||
|
agent_instance = _agent_instances[self.selected_agent_id]
|
||||||
|
if hasattr(agent_instance, "cancel_current_execution"):
|
||||||
|
agent_instance.cancel_current_execution()
|
||||||
|
except (ImportError, AttributeError, KeyError):
|
||||||
|
pass
|
||||||
|
|
||||||
if self.tracer:
|
if self.tracer:
|
||||||
self.tracer.log_chat_message(
|
self.tracer.log_chat_message(
|
||||||
content=message,
|
content=message,
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class Tracer:
|
|||||||
self.tool_executions: dict[int, dict[str, Any]] = {}
|
self.tool_executions: dict[int, dict[str, Any]] = {}
|
||||||
self.chat_messages: list[dict[str, Any]] = []
|
self.chat_messages: list[dict[str, Any]] = []
|
||||||
self.streaming_content: dict[str, str] = {}
|
self.streaming_content: dict[str, str] = {}
|
||||||
|
self.interrupted_content: dict[str, str] = {}
|
||||||
|
|
||||||
self.vulnerability_reports: list[dict[str, Any]] = []
|
self.vulnerability_reports: list[dict[str, Any]] = []
|
||||||
self.final_scan_result: str | None = None
|
self.final_scan_result: str | None = None
|
||||||
@@ -343,5 +344,19 @@ class Tracer:
|
|||||||
def get_streaming_content(self, agent_id: str) -> str | None:
|
def get_streaming_content(self, agent_id: str) -> str | None:
|
||||||
return self.streaming_content.get(agent_id)
|
return self.streaming_content.get(agent_id)
|
||||||
|
|
||||||
|
def finalize_streaming_as_interrupted(self, agent_id: str) -> str | None:
|
||||||
|
content = self.streaming_content.pop(agent_id, None)
|
||||||
|
if content and content.strip():
|
||||||
|
self.interrupted_content[agent_id] = content
|
||||||
|
self.log_chat_message(
|
||||||
|
content=content,
|
||||||
|
role="assistant",
|
||||||
|
agent_id=agent_id,
|
||||||
|
metadata={"interrupted": True},
|
||||||
|
)
|
||||||
|
return content
|
||||||
|
|
||||||
|
return self.interrupted_content.pop(agent_id, None)
|
||||||
|
|
||||||
def cleanup(self) -> None:
|
def cleanup(self) -> None:
|
||||||
self.save_run_data(mark_complete=True)
|
self.save_run_data(mark_complete=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user