import argparse import asyncio import atexit import logging import signal import sys import threading from collections.abc import Callable from importlib.metadata import PackageNotFoundError from importlib.metadata import version as pkg_version from typing import TYPE_CHECKING, Any, ClassVar if TYPE_CHECKING: from textual.timer import Timer from rich.align import Align from rich.console import Group from rich.panel import Panel from rich.style import Style from rich.text import Span, Text from textual import events, on from textual.app import App, ComposeResult from textual.binding import Binding from textual.containers import Grid, Horizontal, Vertical, VerticalScroll from textual.reactive import reactive from textual.screen import ModalScreen from textual.widgets import Button, Label, Static, TextArea, Tree from textual.widgets.tree import TreeNode from strix.agents.StrixAgent import StrixAgent from strix.interface.streaming_parser import parse_streaming_content from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer from strix.interface.tool_components.registry import get_tool_renderer from strix.interface.tool_components.user_message_renderer import UserMessageRenderer from strix.interface.utils import build_tui_stats_text from strix.llm.config import LLMConfig from strix.telemetry.tracer import Tracer, set_global_tracer logger = logging.getLogger(__name__) def get_package_version() -> str: try: return pkg_version("strix-agent") except PackageNotFoundError: return "dev" class ChatTextArea(TextArea): # type: ignore[misc] def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._app_reference: StrixTUIApp | None = None def set_app_reference(self, app: "StrixTUIApp") -> None: self._app_reference = app def on_mount(self) -> None: self._update_height() def _on_key(self, event: events.Key) -> None: if event.key == "shift+enter": self.insert("\n") event.prevent_default() return if event.key == "enter" and self._app_reference: text_content = str(self.text) # type: ignore[has-type] message = text_content.strip() if message: self.text = "" self._app_reference._send_user_message(message) event.prevent_default() return super()._on_key(event) @on(TextArea.Changed) # type: ignore[misc] def _update_height(self, _event: TextArea.Changed | None = None) -> None: if not self.parent: return line_count = self.document.line_count target_lines = min(max(1, line_count), 8) new_height = target_lines + 2 if self.parent.styles.height != new_height: self.parent.styles.height = new_height self.scroll_cursor_visible() class SplashScreen(Static): # type: ignore[misc] ALLOW_SELECT = False PRIMARY_GREEN = "#22c55e" BANNER = ( " ███████╗████████╗██████╗ ██╗██╗ ██╗\n" " ██╔════╝╚══██╔══╝██╔══██╗██║╚██╗██╔╝\n" " ███████╗ ██║ ██████╔╝██║ ╚███╔╝\n" " ╚════██║ ██║ ██╔══██╗██║ ██╔██╗\n" " ███████║ ██║ ██║ ██║██║██╔╝ ██╗\n" " ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝" ) def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._animation_step = 0 self._animation_timer: Timer | None = None self._panel_static: Static | None = None self._version = "dev" def compose(self) -> ComposeResult: self._version = get_package_version() self._animation_step = 0 start_line = self._build_start_line_text(self._animation_step) panel = self._build_panel(start_line) panel_static = Static(panel, id="splash_content") self._panel_static = panel_static yield panel_static def on_mount(self) -> None: self._animation_timer = self.set_interval(0.05, self._animate_start_line) def on_unmount(self) -> None: if self._animation_timer is not None: self._animation_timer.stop() self._animation_timer = None def _animate_start_line(self) -> None: if not self._panel_static: return self._animation_step += 1 start_line = self._build_start_line_text(self._animation_step) panel = self._build_panel(start_line) self._panel_static.update(panel) def _build_panel(self, start_line: Text) -> Panel: content = Group( Align.center(Text(self.BANNER.strip("\n"), style=self.PRIMARY_GREEN, justify="center")), Align.center(Text(" ")), Align.center(self._build_welcome_text()), Align.center(self._build_version_text()), Align.center(self._build_tagline_text()), Align.center(Text(" ")), Align.center(start_line.copy()), Align.center(Text(" ")), Align.center(self._build_url_text()), ) return Panel.fit(content, border_style=self.PRIMARY_GREEN, padding=(1, 6)) def _build_url_text(self) -> Text: return Text("strix.ai", style=Style(color=self.PRIMARY_GREEN, bold=True)) def _build_welcome_text(self) -> Text: text = Text("Welcome to ", style=Style(color="white", bold=True)) text.append("Strix", style=Style(color=self.PRIMARY_GREEN, bold=True)) text.append("!", style=Style(color="white", bold=True)) return text def _build_version_text(self) -> Text: return Text(f"v{self._version}", style=Style(color="white", dim=True)) def _build_tagline_text(self) -> Text: return Text("Open-source AI hackers for your apps", style=Style(color="white", dim=True)) def _build_start_line_text(self, phase: int) -> Text: full_text = "Starting Strix Agent" text_len = len(full_text) shine_pos = phase % (text_len + 8) text = Text() for i, char in enumerate(full_text): dist = abs(i - shine_pos) if dist <= 1: style = Style(color="bright_white", bold=True) elif dist <= 3: style = Style(color="white", bold=True) elif dist <= 5: style = Style(color="#a3a3a3") else: style = Style(color="#525252") text.append(char, style=style) return text class HelpScreen(ModalScreen): # type: ignore[misc] def compose(self) -> ComposeResult: yield Grid( Label("Strix Help", id="help_title"), Label( "F1 Help\nCtrl+Q/C Quit\nESC Stop Agent\n" "Enter Send message to agent\nTab Switch panels\n↑/↓ Navigate tree", id="help_content", ), id="dialog", ) def on_key(self, _event: events.Key) -> None: self.app.pop_screen() class StopAgentScreen(ModalScreen): # type: ignore[misc] def __init__(self, agent_name: str, agent_id: str): super().__init__() self.agent_name = agent_name self.agent_id = agent_id def compose(self) -> ComposeResult: yield Grid( Label(f"🛑 Stop '{self.agent_name}'?", id="stop_agent_title"), Grid( Button("Yes", variant="error", id="stop_agent"), Button("No", variant="default", id="cancel_stop"), id="stop_agent_buttons", ), id="stop_agent_dialog", ) def on_mount(self) -> None: cancel_button = self.query_one("#cancel_stop", Button) cancel_button.focus() def on_key(self, event: events.Key) -> None: if event.key in ("left", "right", "up", "down"): focused = self.focused if focused and focused.id == "stop_agent": cancel_button = self.query_one("#cancel_stop", Button) cancel_button.focus() else: stop_button = self.query_one("#stop_agent", Button) stop_button.focus() event.prevent_default() elif event.key == "enter": focused = self.focused if focused and isinstance(focused, Button): focused.press() event.prevent_default() elif event.key == "escape": self.app.pop_screen() event.prevent_default() def on_button_pressed(self, event: Button.Pressed) -> None: self.app.pop_screen() if event.button.id == "stop_agent": self.app.action_confirm_stop_agent(self.agent_id) class VulnerabilityDetailScreen(ModalScreen): # type: ignore[misc] """Modal screen to display vulnerability details.""" SEVERITY_COLORS: ClassVar[dict[str, str]] = { "critical": "#dc2626", # Red "high": "#ea580c", # Orange "medium": "#d97706", # Amber "low": "#22c55e", # Green "info": "#3b82f6", # Blue } FIELD_STYLE: ClassVar[str] = "bold #4ade80" def __init__(self, vulnerability: dict[str, Any]) -> None: super().__init__() self.vulnerability = vulnerability def compose(self) -> ComposeResult: content = self._render_vulnerability() yield Grid( VerticalScroll(Static(content, id="vuln_detail_content"), id="vuln_detail_scroll"), Horizontal( Button("Copy", variant="default", id="copy_vuln_detail"), Button("Done", variant="default", id="close_vuln_detail"), id="vuln_detail_buttons", ), id="vuln_detail_dialog", ) def on_mount(self) -> None: close_button = self.query_one("#close_vuln_detail", Button) close_button.focus() def _get_cvss_color(self, cvss_score: float) -> str: if cvss_score >= 9.0: return "#dc2626" if cvss_score >= 7.0: return "#ea580c" if cvss_score >= 4.0: return "#d97706" if cvss_score >= 0.1: return "#65a30d" return "#6b7280" def _highlight_python(self, code: str) -> Text: try: from pygments.lexers import PythonLexer from pygments.styles import get_style_by_name lexer = PythonLexer() style = get_style_by_name("native") colors = { token: f"#{style_def['color']}" for token, style_def in style if style_def["color"] } text = Text() for token_type, token_value in lexer.get_tokens(code): if not token_value: continue color = None tt = token_type while tt: if tt in colors: color = colors[tt] break tt = tt.parent text.append(token_value, style=color) except (ImportError, KeyError, AttributeError): return Text(code) else: return text def _render_vulnerability(self) -> Text: # noqa: PLR0912, PLR0915 vuln = self.vulnerability text = Text() text.append("🐞 ") text.append("Vulnerability Report", style="bold #ea580c") agent_name = vuln.get("agent_name", "") if agent_name: text.append("\n\n") text.append("Agent: ", style=self.FIELD_STYLE) text.append(agent_name) title = vuln.get("title", "") if title: text.append("\n\n") text.append("Title: ", style=self.FIELD_STYLE) text.append(title) severity = vuln.get("severity", "") if severity: text.append("\n\n") text.append("Severity: ", style=self.FIELD_STYLE) severity_color = self.SEVERITY_COLORS.get(severity.lower(), "#6b7280") text.append(severity.upper(), style=f"bold {severity_color}") cvss_score = vuln.get("cvss") if cvss_score is not None: text.append("\n\n") text.append("CVSS Score: ", style=self.FIELD_STYLE) cvss_color = self._get_cvss_color(float(cvss_score)) text.append(str(cvss_score), style=f"bold {cvss_color}") target = vuln.get("target", "") if target: text.append("\n\n") text.append("Target: ", style=self.FIELD_STYLE) text.append(target) endpoint = vuln.get("endpoint", "") if endpoint: text.append("\n\n") text.append("Endpoint: ", style=self.FIELD_STYLE) text.append(endpoint) method = vuln.get("method", "") if method: text.append("\n\n") text.append("Method: ", style=self.FIELD_STYLE) text.append(method) cve = vuln.get("cve", "") if cve: text.append("\n\n") text.append("CVE: ", style=self.FIELD_STYLE) text.append(cve) # CVSS breakdown cvss_breakdown = vuln.get("cvss_breakdown", {}) if cvss_breakdown: cvss_parts = [] if cvss_breakdown.get("attack_vector"): cvss_parts.append(f"AV:{cvss_breakdown['attack_vector']}") if cvss_breakdown.get("attack_complexity"): cvss_parts.append(f"AC:{cvss_breakdown['attack_complexity']}") if cvss_breakdown.get("privileges_required"): cvss_parts.append(f"PR:{cvss_breakdown['privileges_required']}") if cvss_breakdown.get("user_interaction"): cvss_parts.append(f"UI:{cvss_breakdown['user_interaction']}") if cvss_breakdown.get("scope"): cvss_parts.append(f"S:{cvss_breakdown['scope']}") if cvss_breakdown.get("confidentiality"): cvss_parts.append(f"C:{cvss_breakdown['confidentiality']}") if cvss_breakdown.get("integrity"): cvss_parts.append(f"I:{cvss_breakdown['integrity']}") if cvss_breakdown.get("availability"): cvss_parts.append(f"A:{cvss_breakdown['availability']}") if cvss_parts: text.append("\n\n") text.append("CVSS Vector: ", style=self.FIELD_STYLE) text.append("/".join(cvss_parts), style="dim") description = vuln.get("description", "") if description: text.append("\n\n") text.append("Description", style=self.FIELD_STYLE) text.append("\n") text.append(description) impact = vuln.get("impact", "") if impact: text.append("\n\n") text.append("Impact", style=self.FIELD_STYLE) text.append("\n") text.append(impact) technical_analysis = vuln.get("technical_analysis", "") if technical_analysis: text.append("\n\n") text.append("Technical Analysis", style=self.FIELD_STYLE) text.append("\n") text.append(technical_analysis) poc_description = vuln.get("poc_description", "") if poc_description: text.append("\n\n") text.append("PoC Description", style=self.FIELD_STYLE) text.append("\n") text.append(poc_description) poc_script_code = vuln.get("poc_script_code", "") if poc_script_code: text.append("\n\n") text.append("PoC Code", style=self.FIELD_STYLE) text.append("\n") text.append_text(self._highlight_python(poc_script_code)) remediation_steps = vuln.get("remediation_steps", "") if remediation_steps: text.append("\n\n") text.append("Remediation", style=self.FIELD_STYLE) text.append("\n") text.append(remediation_steps) return text def _get_markdown_report(self) -> str: # noqa: PLR0912, PLR0915 """Get Markdown version of vulnerability report for clipboard.""" vuln = self.vulnerability lines: list[str] = [] # Title title = vuln.get("title", "Untitled Vulnerability") lines.append(f"# {title}") lines.append("") # Metadata if vuln.get("id"): lines.append(f"**ID:** {vuln['id']}") if vuln.get("severity"): lines.append(f"**Severity:** {vuln['severity'].upper()}") if vuln.get("timestamp"): lines.append(f"**Found:** {vuln['timestamp']}") if vuln.get("agent_name"): lines.append(f"**Agent:** {vuln['agent_name']}") if vuln.get("target"): lines.append(f"**Target:** {vuln['target']}") if vuln.get("endpoint"): lines.append(f"**Endpoint:** {vuln['endpoint']}") if vuln.get("method"): lines.append(f"**Method:** {vuln['method']}") if vuln.get("cve"): lines.append(f"**CVE:** {vuln['cve']}") if vuln.get("cvss") is not None: lines.append(f"**CVSS:** {vuln['cvss']}") # CVSS Vector cvss_breakdown = vuln.get("cvss_breakdown", {}) if cvss_breakdown: abbrevs = { "attack_vector": "AV", "attack_complexity": "AC", "privileges_required": "PR", "user_interaction": "UI", "scope": "S", "confidentiality": "C", "integrity": "I", "availability": "A", } parts = [ f"{abbrevs.get(k, k)}:{v}" for k, v in cvss_breakdown.items() if v and k in abbrevs ] if parts: lines.append(f"**CVSS Vector:** {'/'.join(parts)}") # Description lines.append("") lines.append("## Description") lines.append("") lines.append(vuln.get("description") or "No description provided.") # Impact if vuln.get("impact"): lines.extend(["", "## Impact", "", vuln["impact"]]) # Technical Analysis if vuln.get("technical_analysis"): lines.extend(["", "## Technical Analysis", "", vuln["technical_analysis"]]) # Proof of Concept if vuln.get("poc_description") or vuln.get("poc_script_code"): lines.extend(["", "## Proof of Concept", ""]) if vuln.get("poc_description"): lines.append(vuln["poc_description"]) lines.append("") if vuln.get("poc_script_code"): lines.append("```python") lines.append(vuln["poc_script_code"]) lines.append("```") # Code Analysis if vuln.get("code_locations"): lines.extend(["", "## Code Analysis", ""]) for i, loc in enumerate(vuln["code_locations"]): file_ref = loc.get("file", "unknown") line_ref = "" if loc.get("start_line") is not None: if loc.get("end_line") and loc["end_line"] != loc["start_line"]: line_ref = f" (lines {loc['start_line']}-{loc['end_line']})" else: line_ref = f" (line {loc['start_line']})" lines.append(f"**Location {i + 1}:** `{file_ref}`{line_ref}") if loc.get("label"): lines.append(f" {loc['label']}") if loc.get("snippet"): lines.append(f"```\n{loc['snippet']}\n```") if loc.get("fix_before") or loc.get("fix_after"): lines.append("**Suggested Fix:**") lines.append("```diff") if loc.get("fix_before"): lines.extend(f"- {line}" for line in loc["fix_before"].splitlines()) if loc.get("fix_after"): lines.extend(f"+ {line}" for line in loc["fix_after"].splitlines()) lines.append("```") lines.append("") # Remediation if vuln.get("remediation_steps"): lines.extend(["", "## Remediation", "", vuln["remediation_steps"]]) lines.append("") return "\n".join(lines) def on_key(self, event: events.Key) -> None: if event.key == "escape": self.app.pop_screen() event.prevent_default() def on_button_pressed(self, event: Button.Pressed) -> None: if event.button.id == "copy_vuln_detail": markdown_text = self._get_markdown_report() self.app.copy_to_clipboard(markdown_text) copy_button = self.query_one("#copy_vuln_detail", Button) copy_button.label = "Copied!" self.set_timer(1.5, lambda: setattr(copy_button, "label", "Copy")) elif event.button.id == "close_vuln_detail": self.app.pop_screen() class VulnerabilityItem(Static): # type: ignore[misc] """A clickable vulnerability item.""" def __init__(self, label: Text, vuln_data: dict[str, Any], **kwargs: Any) -> None: super().__init__(label, **kwargs) self.vuln_data = vuln_data def on_click(self, _event: events.Click) -> None: """Handle click to open vulnerability detail.""" self.app.push_screen(VulnerabilityDetailScreen(self.vuln_data)) class VulnerabilitiesPanel(VerticalScroll): # type: ignore[misc] """A scrollable panel showing found vulnerabilities with severity-colored dots.""" SEVERITY_COLORS: ClassVar[dict[str, str]] = { "critical": "#dc2626", # Red "high": "#ea580c", # Orange "medium": "#d97706", # Amber "low": "#22c55e", # Green "info": "#3b82f6", # Blue } def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._vulnerabilities: list[dict[str, Any]] = [] def compose(self) -> ComposeResult: return [] def update_vulnerabilities(self, vulnerabilities: list[dict[str, Any]]) -> None: """Update the list of vulnerabilities and re-render.""" if self._vulnerabilities == vulnerabilities: return self._vulnerabilities = list(vulnerabilities) self._render_panel() def _render_panel(self) -> None: """Render the vulnerabilities panel content.""" for child in list(self.children): if isinstance(child, VulnerabilityItem): child.remove() if not self._vulnerabilities: return for vuln in self._vulnerabilities: severity = vuln.get("severity", "info").lower() title = vuln.get("title", "Unknown Vulnerability") color = self.SEVERITY_COLORS.get(severity, "#3b82f6") label = Text() label.append("● ", style=Style(color=color)) label.append(title, style=Style(color="#d4d4d4")) item = VulnerabilityItem(label, vuln, classes="vuln-item") self.mount(item) class QuitScreen(ModalScreen): # type: ignore[misc] def compose(self) -> ComposeResult: yield Grid( Label("Quit Strix?", id="quit_title"), Grid( Button("Yes", variant="error", id="quit"), Button("No", variant="default", id="cancel"), id="quit_buttons", ), id="quit_dialog", ) def on_mount(self) -> None: cancel_button = self.query_one("#cancel", Button) cancel_button.focus() def on_key(self, event: events.Key) -> None: if event.key in ("left", "right", "up", "down"): focused = self.focused if focused and focused.id == "quit": cancel_button = self.query_one("#cancel", Button) cancel_button.focus() else: quit_button = self.query_one("#quit", Button) quit_button.focus() event.prevent_default() elif event.key == "enter": focused = self.focused if focused and isinstance(focused, Button): focused.press() event.prevent_default() elif event.key == "escape": self.app.pop_screen() event.prevent_default() def on_button_pressed(self, event: Button.Pressed) -> None: if event.button.id == "quit": self.app.action_custom_quit() else: self.app.pop_screen() class StrixTUIApp(App): # type: ignore[misc] CSS_PATH = "assets/tui_styles.tcss" ALLOW_SELECT = True SIDEBAR_MIN_WIDTH = 120 selected_agent_id: reactive[str | None] = reactive(default=None) show_splash: reactive[bool] = reactive(default=True) BINDINGS: ClassVar[list[Binding]] = [ Binding("f1", "toggle_help", "Help", priority=True), Binding("ctrl+q", "request_quit", "Quit", priority=True), Binding("ctrl+c", "request_quit", "Quit", priority=True), Binding("escape", "stop_selected_agent", "Stop Agent", priority=True), ] def __init__(self, args: argparse.Namespace): super().__init__() self.args = args self.scan_config = self._build_scan_config(args) self.agent_config = self._build_agent_config(args) self.tracer = Tracer(self.scan_config["run_name"]) self.tracer.set_scan_config(self.scan_config) set_global_tracer(self.tracer) self.agent_nodes: dict[str, TreeNode] = {} self._displayed_agents: set[str] = set() self._displayed_events: list[str] = [] self._streaming_render_cache: dict[str, tuple[int, Any]] = {} self._last_streaming_len: dict[str, int] = {} self._scan_thread: threading.Thread | None = None self._scan_stop_event = threading.Event() self._scan_completed = threading.Event() self._spinner_frame_index: int = 0 # Current animation frame index self._sweep_num_squares: int = 6 # Number of squares in sweep animation self._sweep_colors: list[str] = [ "#000000", # Dimmest (shows dot) "#031a09", "#052e16", "#0d4a2a", "#15803d", "#22c55e", "#4ade80", "#86efac", # Brightest ] self._dot_animation_timer: Any | None = None self._setup_cleanup_handlers() def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]: return { "scan_id": args.run_name, "targets": args.targets_info, "user_instructions": args.instruction or "", "run_name": args.run_name, "diff_scope": getattr(args, "diff_scope", {"active": False}), } def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]: scan_mode = getattr(args, "scan_mode", "deep") llm_config = LLMConfig( scan_mode=scan_mode, interactive=True, is_whitebox=bool(getattr(args, "local_sources", [])), ) config = { "llm_config": llm_config, "max_iterations": 300, } if getattr(args, "local_sources", None): config["local_sources"] = args.local_sources return config def _setup_cleanup_handlers(self) -> None: def cleanup_on_exit() -> None: from strix.runtime import cleanup_runtime self.tracer.cleanup() cleanup_runtime() def signal_handler(_signum: int, _frame: Any) -> None: self.tracer.cleanup() sys.exit(0) atexit.register(cleanup_on_exit) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) if hasattr(signal, "SIGHUP"): signal.signal(signal.SIGHUP, signal_handler) def compose(self) -> ComposeResult: if self.show_splash: yield SplashScreen(id="splash_screen") def watch_show_splash(self, show_splash: bool) -> None: if not show_splash and self.is_mounted: try: splash = self.query_one("#splash_screen") splash.remove() except ValueError: pass main_container = Vertical(id="main_container") self.mount(main_container) content_container = Horizontal(id="content_container") main_container.mount(content_container) chat_area_container = Vertical(id="chat_area_container") chat_display = Static("", id="chat_display") chat_history = VerticalScroll(chat_display, id="chat_history") chat_history.can_focus = True status_text = Static("", id="status_text") status_text.ALLOW_SELECT = False keymap_indicator = Static("", id="keymap_indicator") keymap_indicator.ALLOW_SELECT = False agent_status_display = Horizontal( status_text, keymap_indicator, id="agent_status_display", classes="hidden" ) chat_prompt = Static("> ", id="chat_prompt") chat_prompt.ALLOW_SELECT = False chat_input = ChatTextArea( "", id="chat_input", show_line_numbers=False, ) chat_input.set_app_reference(self) chat_input_container = Horizontal(chat_prompt, chat_input, id="chat_input_container") agents_tree = Tree("Agents", id="agents_tree") agents_tree.root.expand() agents_tree.show_root = False agents_tree.show_guide = True agents_tree.guide_depth = 3 agents_tree.guide_style = "dashed" stats_display = Static("", id="stats_display") stats_scroll = VerticalScroll(stats_display, id="stats_scroll") vulnerabilities_panel = VulnerabilitiesPanel(id="vulnerabilities_panel") sidebar = Vertical(agents_tree, vulnerabilities_panel, stats_scroll, id="sidebar") content_container.mount(chat_area_container) content_container.mount(sidebar) chat_area_container.mount(chat_history) chat_area_container.mount(agent_status_display) chat_area_container.mount(chat_input_container) self.call_after_refresh(self._focus_chat_input) def _focus_chat_input(self) -> None: if len(self.screen_stack) > 1 or self.show_splash: return if not self.is_mounted: return try: chat_input = self.query_one("#chat_input", ChatTextArea) chat_input.show_vertical_scrollbar = False chat_input.show_horizontal_scrollbar = False chat_input.focus() except (ValueError, Exception): self.call_after_refresh(self._focus_chat_input) def _focus_agents_tree(self) -> None: if len(self.screen_stack) > 1 or self.show_splash: return if not self.is_mounted: return try: agents_tree = self.query_one("#agents_tree", Tree) agents_tree.focus() if agents_tree.root.children: first_node = agents_tree.root.children[0] agents_tree.select_node(first_node) except (ValueError, Exception): self.call_after_refresh(self._focus_agents_tree) def on_mount(self) -> None: self.title = "strix" self.set_timer(4.5, self._hide_splash_screen) def _hide_splash_screen(self) -> None: self.show_splash = False self._start_scan_thread() self.set_interval(0.35, self._update_ui_from_tracer) def _update_ui_from_tracer(self) -> None: if self.show_splash: return if len(self.screen_stack) > 1: return if not self.is_mounted: return try: chat_history = self.query_one("#chat_history", VerticalScroll) agents_tree = self.query_one("#agents_tree", Tree) if not self._is_widget_safe(chat_history) or not self._is_widget_safe(agents_tree): return except (ValueError, Exception): return agent_updates = False for agent_id, agent_data in list(self.tracer.agents.items()): if agent_id not in self._displayed_agents: self._add_agent_node(agent_data) self._displayed_agents.add(agent_id) agent_updates = True elif self._update_agent_node(agent_id, agent_data): agent_updates = True if agent_updates: self._expand_new_agent_nodes() self._update_chat_view() self._update_agent_status_display() self._update_stats_display() self._update_vulnerabilities_panel() def _update_agent_node(self, agent_id: str, agent_data: dict[str, Any]) -> bool: if agent_id not in self.agent_nodes: return False try: agent_node = self.agent_nodes[agent_id] agent_name_raw = agent_data.get("name", "Agent") status = agent_data.get("status", "running") status_indicators = { "running": "⚪", "waiting": "⏸", "completed": "🟢", "failed": "🔴", "stopped": "■", "stopping": "○", "llm_failed": "🔴", } status_icon = status_indicators.get(status, "○") vuln_count = self._agent_vulnerability_count(agent_id) vuln_indicator = f" ({vuln_count})" if vuln_count > 0 else "" agent_name = f"{status_icon} {agent_name_raw}{vuln_indicator}" if agent_node.label != agent_name: agent_node.set_label(agent_name) return True except (KeyError, AttributeError, ValueError) as e: import logging logging.warning(f"Failed to update agent node label: {e}") return False def _get_chat_content( self, ) -> tuple[Any, str | None]: if not self.selected_agent_id: return self._get_chat_placeholder_content( "Select an agent from the tree to see its activity.", "placeholder-no-agent" ) events = self._gather_agent_events(self.selected_agent_id) streaming = self.tracer.get_streaming_content(self.selected_agent_id) if not events and not streaming: return self._get_chat_placeholder_content( "Starting agent...", "placeholder-no-activity" ) current_event_ids = [e["id"] for e in events] current_streaming_len = len(streaming) if streaming else 0 last_streaming_len = self._last_streaming_len.get(self.selected_agent_id, 0) if ( current_event_ids == self._displayed_events and current_streaming_len == last_streaming_len ): return None, None self._displayed_events = current_event_ids self._last_streaming_len[self.selected_agent_id] = current_streaming_len return self._get_rendered_events_content(events), "chat-content" def _update_chat_view(self) -> None: if len(self.screen_stack) > 1 or self.show_splash or not self.is_mounted: return try: chat_history = self.query_one("#chat_history", VerticalScroll) except (ValueError, Exception): return if not self._is_widget_safe(chat_history): return try: is_at_bottom = chat_history.scroll_y >= chat_history.max_scroll_y except (AttributeError, ValueError): is_at_bottom = True content, css_class = self._get_chat_content() if content is None: return chat_display = self.query_one("#chat_display", Static) self._safe_widget_operation(chat_display.update, content) chat_display.set_classes(css_class) if is_at_bottom: self.call_later(chat_history.scroll_end, animate=False) def _get_chat_placeholder_content( self, message: str, placeholder_class: str ) -> tuple[Text, str]: self._displayed_events = [placeholder_class] text = Text() text.append(message) return text, f"chat-placeholder {placeholder_class}" @staticmethod def _merge_renderables(renderables: list[Any]) -> Text: """Merge renderables into a single Text for mouse text selection support.""" combined = Text() for i, item in enumerate(renderables): if i > 0: combined.append("\n") StrixTUIApp._append_renderable(combined, item) return StrixTUIApp._sanitize_text(combined) @staticmethod def _sanitize_text(text: Text) -> Text: """Clamp spans so Rich/Textual can't crash on malformed offsets.""" plain = text.plain text_length = len(plain) sanitized_spans: list[Span] = [] for span in text.spans: start = max(0, min(span.start, text_length)) end = max(0, min(span.end, text_length)) if end > start: sanitized_spans.append(Span(start, end, span.style)) return Text( plain, style=text.style, justify=text.justify, overflow=text.overflow, no_wrap=text.no_wrap, end=text.end, tab_size=text.tab_size, spans=sanitized_spans, ) @staticmethod def _append_renderable(combined: Text, item: Any) -> None: """Recursively append a renderable's text content to a combined Text.""" if isinstance(item, Text): combined.append_text(StrixTUIApp._sanitize_text(item)) elif isinstance(item, Group): for j, sub in enumerate(item.renderables): if j > 0: combined.append("\n") StrixTUIApp._append_renderable(combined, sub) else: inner = getattr(item, "content", None) or getattr(item, "renderable", None) if inner is not None: StrixTUIApp._append_renderable(combined, inner) else: combined.append(str(item)) def _get_rendered_events_content(self, events: list[dict[str, Any]]) -> Any: renderables: list[Any] = [] if not events: return Text() for event in events: content: Any = None if event["type"] == "chat": content = self._render_chat_content(event["data"]) elif event["type"] == "tool": content = self._render_tool_content_simple(event["data"]) if content: if renderables: renderables.append(Text("")) renderables.append(content) if self.selected_agent_id: streaming = self.tracer.get_streaming_content(self.selected_agent_id) if streaming: streaming_text = self._render_streaming_content(streaming) if streaming_text: if renderables: renderables.append(Text("")) renderables.append(streaming_text) if not renderables: return Text() if len(renderables) == 1 and isinstance(renderables[0], Text): return self._sanitize_text(renderables[0]) return self._merge_renderables(renderables) def _render_streaming_content(self, content: str, agent_id: str | None = None) -> Any: cache_key = agent_id or self.selected_agent_id or "" content_len = len(content) if cache_key in self._streaming_render_cache: cached_len, cached_output = self._streaming_render_cache[cache_key] if cached_len == content_len: return cached_output renderables: list[Any] = [] segments = parse_streaming_content(content) for segment in segments: if segment.type == "text": text_content = AgentMessageRenderer.render_simple(segment.content) if renderables: renderables.append(Text("")) renderables.append(text_content) elif segment.type == "tool": tool_renderable = self._render_streaming_tool( segment.tool_name or "unknown", segment.args or {}, segment.is_complete, ) if renderables: renderables.append(Text("")) renderables.append(tool_renderable) if not renderables: result = Text() elif len(renderables) == 1 and isinstance(renderables[0], Text): result = self._sanitize_text(renderables[0]) else: result = self._merge_renderables(renderables) self._streaming_render_cache[cache_key] = (content_len, result) return result def _render_streaming_tool( self, tool_name: str, args: dict[str, str], is_complete: bool ) -> Any: tool_data = { "tool_name": tool_name, "args": args, "status": "completed" if is_complete else "running", "result": None, } renderer = get_tool_renderer(tool_name) if renderer: widget = renderer.render(tool_data) return widget.content return self._render_default_streaming_tool(tool_name, args, is_complete) def _render_default_streaming_tool( self, tool_name: str, args: dict[str, str], is_complete: bool ) -> Text: text = Text() if is_complete: text.append("✓ ", style="green") else: text.append("● ", style="yellow") text.append("Using tool ", style="dim") text.append(tool_name, style="bold blue") if args: for key, value in list(args.items())[:3]: text.append("\n ") text.append(key, style="dim") text.append(": ") display_value = value if len(value) <= 100 else value[:97] + "..." text.append(display_value, style="italic" if not is_complete else None) return text def _get_status_display_content( self, agent_id: str, agent_data: dict[str, Any] ) -> tuple[Text | None, Text, bool]: status = agent_data.get("status", "running") def keymap_styled(keys: list[tuple[str, str]]) -> Text: t = Text() for i, (key, action) in enumerate(keys): if i > 0: t.append(" · ", style="dim") t.append(key, style="white") t.append(" ", style="dim") t.append(action, style="dim") return t simple_statuses: dict[str, tuple[str, str]] = { "stopping": ("Agent stopping...", ""), "stopped": ("Agent stopped", ""), "completed": ("Agent completed", ""), } if status in simple_statuses: msg, _ = simple_statuses[status] text = Text() text.append(msg) return (text, Text(), False) if status == "llm_failed": error_msg = agent_data.get("error_message", "") text = Text() if error_msg: text.append(error_msg, style="red") else: text.append("LLM request failed", style="red") self._stop_dot_animation() keymap = Text() keymap.append("Send message to retry", style="dim") return (text, keymap, False) if status == "waiting": keymap = Text() keymap.append("Send message to resume", style="dim") return (Text(" "), keymap, False) if status == "running": if self._agent_has_real_activity(agent_id): animated_text = Text() animated_text.append_text(self._get_sweep_animation(self._sweep_colors)) animated_text.append("esc", style="white") animated_text.append(" ", style="dim") animated_text.append("stop", style="dim") return (animated_text, keymap_styled([("ctrl-q", "quit")]), True) animated_text = self._get_animated_verb_text(agent_id, "Initializing") return (animated_text, keymap_styled([("ctrl-q", "quit")]), True) return (None, Text(), False) def _update_agent_status_display(self) -> None: try: status_display = self.query_one("#agent_status_display", Horizontal) status_text = self.query_one("#status_text", Static) keymap_indicator = self.query_one("#keymap_indicator", Static) except (ValueError, Exception): return widgets = [status_display, status_text, keymap_indicator] if not all(self._is_widget_safe(w) for w in widgets): return if not self.selected_agent_id: self._safe_widget_operation(status_display.add_class, "hidden") return try: agent_data = self.tracer.agents[self.selected_agent_id] content, keymap, should_animate = self._get_status_display_content( self.selected_agent_id, agent_data ) if not content: self._safe_widget_operation(status_display.add_class, "hidden") return self._safe_widget_operation(status_text.update, content) self._safe_widget_operation(keymap_indicator.update, keymap) self._safe_widget_operation(status_display.remove_class, "hidden") if should_animate: self._start_dot_animation() except (KeyError, Exception): self._safe_widget_operation(status_display.add_class, "hidden") def _update_stats_display(self) -> None: try: stats_display = self.query_one("#stats_display", Static) except (ValueError, Exception): return if not self._is_widget_safe(stats_display): return if self.screen.selections: return stats_content = Text() stats_text = build_tui_stats_text(self.tracer, self.agent_config) if stats_text: stats_content.append(stats_text) version = get_package_version() stats_content.append(f"\nv{version}", style="white") self._safe_widget_operation(stats_display.update, stats_content) def _update_vulnerabilities_panel(self) -> None: """Update the vulnerabilities panel with current vulnerability data.""" try: vuln_panel = self.query_one("#vulnerabilities_panel", VulnerabilitiesPanel) except (ValueError, Exception): return if not self._is_widget_safe(vuln_panel): return vulnerabilities = self.tracer.vulnerability_reports if not vulnerabilities: self._safe_widget_operation(vuln_panel.add_class, "hidden") return enriched_vulns = [] for vuln in vulnerabilities: enriched = dict(vuln) report_id = vuln.get("id", "") agent_name = self._get_agent_name_for_vulnerability(report_id) if agent_name: enriched["agent_name"] = agent_name enriched_vulns.append(enriched) self._safe_widget_operation(vuln_panel.remove_class, "hidden") vuln_panel.update_vulnerabilities(enriched_vulns) def _get_agent_name_for_vulnerability(self, report_id: str) -> str | None: """Find the agent name that created a vulnerability report.""" for _exec_id, tool_data in list(self.tracer.tool_executions.items()): if tool_data.get("tool_name") == "create_vulnerability_report": result = tool_data.get("result", {}) if isinstance(result, dict) and result.get("report_id") == report_id: agent_id = tool_data.get("agent_id") if agent_id and agent_id in self.tracer.agents: name: str = self.tracer.agents[agent_id].get("name", "Unknown Agent") return name return None def _get_sweep_animation(self, color_palette: list[str]) -> Text: text = Text() num_squares = self._sweep_num_squares num_colors = len(color_palette) offset = num_colors - 1 max_pos = (num_squares - 1) + offset total_range = max_pos + offset cycle_length = total_range * 2 frame_in_cycle = self._spinner_frame_index % cycle_length wave_pos = total_range - abs(total_range - frame_in_cycle) sweep_pos = wave_pos - offset dot_color = "#0a3d1f" for i in range(num_squares): dist = abs(i - sweep_pos) color_idx = max(0, num_colors - 1 - dist) if color_idx == 0: text.append("·", style=Style(color=dot_color)) else: color = color_palette[color_idx] text.append("▪", style=Style(color=color)) text.append(" ") return text def _get_animated_verb_text(self, agent_id: str, verb: str) -> Text: # noqa: ARG002 text = Text() sweep = self._get_sweep_animation(self._sweep_colors) text.append_text(sweep) parts = verb.split(" ", 1) text.append(parts[0], style="white") if len(parts) > 1: text.append(" ", style="dim") text.append(parts[1], style="dim") return text def _start_dot_animation(self) -> None: if self._dot_animation_timer is None: self._dot_animation_timer = self.set_interval(0.06, self._animate_dots) def _stop_dot_animation(self) -> None: if self._dot_animation_timer is not None: self._dot_animation_timer.stop() self._dot_animation_timer = None def _animate_dots(self) -> None: has_active_agents = False if self.selected_agent_id and self.selected_agent_id in self.tracer.agents: agent_data = self.tracer.agents[self.selected_agent_id] status = agent_data.get("status", "running") if status in ["running", "waiting"]: has_active_agents = True num_colors = len(self._sweep_colors) offset = num_colors - 1 max_pos = (self._sweep_num_squares - 1) + offset total_range = max_pos + offset cycle_length = total_range * 2 self._spinner_frame_index = (self._spinner_frame_index + 1) % cycle_length self._update_agent_status_display() if not has_active_agents: has_active_agents = any( agent_data.get("status", "running") in ["running", "waiting"] for agent_data in self.tracer.agents.values() ) if not has_active_agents: self._stop_dot_animation() self._spinner_frame_index = 0 def _agent_has_real_activity(self, agent_id: str) -> bool: initial_tools = {"scan_start_info", "subagent_start_info"} for _exec_id, tool_data in list(self.tracer.tool_executions.items()): if tool_data.get("agent_id") == agent_id: tool_name = tool_data.get("tool_name", "") if tool_name not in initial_tools: return True streaming = self.tracer.get_streaming_content(agent_id) return bool(streaming and streaming.strip()) def _agent_vulnerability_count(self, agent_id: str) -> int: count = 0 for _exec_id, tool_data in list(self.tracer.tool_executions.items()): if tool_data.get("agent_id") == agent_id: tool_name = tool_data.get("tool_name", "") if tool_name == "create_vulnerability_report": status = tool_data.get("status", "") if status == "completed": result = tool_data.get("result", {}) if isinstance(result, dict) and result.get("success"): count += 1 return count def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]: chat_events = [ { "type": "chat", "timestamp": msg["timestamp"], "id": f"chat_{msg['message_id']}", "data": msg, } for msg in self.tracer.chat_messages if msg.get("agent_id") == agent_id ] tool_events = [ { "type": "tool", "timestamp": tool_data["timestamp"], "id": f"tool_{exec_id}", "data": tool_data, } for exec_id, tool_data in list(self.tracer.tool_executions.items()) if tool_data.get("agent_id") == agent_id ] events = chat_events + tool_events events.sort(key=lambda e: (e["timestamp"], e["id"])) return events def watch_selected_agent_id(self, _agent_id: str | None) -> None: if len(self.screen_stack) > 1 or self.show_splash: return if not self.is_mounted: return self._displayed_events.clear() self._streaming_render_cache.clear() self._last_streaming_len.clear() self.call_later(self._update_chat_view) self._update_agent_status_display() def _start_scan_thread(self) -> None: def scan_target() -> None: try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: agent = StrixAgent(self.agent_config) if not self._scan_stop_event.is_set(): loop.run_until_complete(agent.execute_scan(self.scan_config)) except (KeyboardInterrupt, asyncio.CancelledError): logging.info("Scan interrupted by user") except (ConnectionError, TimeoutError): logging.exception("Network error during scan") except RuntimeError: logging.exception("Runtime error during scan") except Exception: logging.exception("Unexpected error during scan") finally: loop.close() self._scan_completed.set() except Exception: logging.exception("Error setting up scan thread") self._scan_completed.set() self._scan_thread = threading.Thread(target=scan_target, daemon=True) self._scan_thread.start() def _add_agent_node(self, agent_data: dict[str, Any]) -> None: if len(self.screen_stack) > 1 or self.show_splash: return if not self.is_mounted: return agent_id = agent_data["id"] parent_id = agent_data.get("parent_id") status = agent_data.get("status", "running") try: agents_tree = self.query_one("#agents_tree", Tree) except (ValueError, Exception): return agent_name_raw = agent_data.get("name", "Agent") status_indicators = { "running": "⚪", "waiting": "⏸", "completed": "🟢", "failed": "🔴", "stopped": "■", "stopping": "○", "llm_failed": "🔴", } status_icon = status_indicators.get(status, "○") vuln_count = self._agent_vulnerability_count(agent_id) vuln_indicator = f" ({vuln_count})" if vuln_count > 0 else "" agent_name = f"{status_icon} {agent_name_raw}{vuln_indicator}" try: if parent_id and parent_id in self.agent_nodes: parent_node = self.agent_nodes[parent_id] agent_node = parent_node.add( agent_name, data={"agent_id": agent_id}, ) parent_node.allow_expand = True else: agent_node = agents_tree.root.add( agent_name, data={"agent_id": agent_id}, ) agent_node.allow_expand = False agent_node.expand() self.agent_nodes[agent_id] = agent_node if len(self.agent_nodes) == 1: agents_tree.select_node(agent_node) self.selected_agent_id = agent_id self._reorganize_orphaned_agents(agent_id) except (AttributeError, ValueError, RuntimeError) as e: import logging logging.warning(f"Failed to add agent node {agent_id}: {e}") def _expand_new_agent_nodes(self) -> None: if len(self.screen_stack) > 1 or self.show_splash: return if not self.is_mounted: return def _expand_all_agent_nodes(self) -> None: if len(self.screen_stack) > 1 or self.show_splash: return if not self.is_mounted: return try: agents_tree = self.query_one("#agents_tree", Tree) self._expand_node_recursively(agents_tree.root) except (ValueError, Exception): logging.debug("Tree not ready for expanding nodes") def _expand_node_recursively(self, node: TreeNode) -> None: if not node.is_expanded: node.expand() for child in node.children: self._expand_node_recursively(child) def _copy_node_under(self, node_to_copy: TreeNode, new_parent: TreeNode) -> None: agent_id = node_to_copy.data["agent_id"] agent_data = self.tracer.agents.get(agent_id, {}) agent_name_raw = agent_data.get("name", "Agent") status = agent_data.get("status", "running") status_indicators = { "running": "⚪", "waiting": "⏸", "completed": "🟢", "failed": "🔴", "stopped": "■", "stopping": "○", "llm_failed": "🔴", } status_icon = status_indicators.get(status, "○") vuln_count = self._agent_vulnerability_count(agent_id) vuln_indicator = f" ({vuln_count})" if vuln_count > 0 else "" agent_name = f"{status_icon} {agent_name_raw}{vuln_indicator}" new_node = new_parent.add( agent_name, data=node_to_copy.data, ) new_node.allow_expand = node_to_copy.allow_expand self.agent_nodes[agent_id] = new_node for child in node_to_copy.children: self._copy_node_under(child, new_node) if node_to_copy.is_expanded: new_node.expand() def _reorganize_orphaned_agents(self, new_parent_id: str) -> None: agents_to_move = [] for agent_id, agent_data in list(self.tracer.agents.items()): if ( agent_data.get("parent_id") == new_parent_id and agent_id in self.agent_nodes and agent_id != new_parent_id ): agents_to_move.append(agent_id) if not agents_to_move: return parent_node = self.agent_nodes[new_parent_id] for child_agent_id in agents_to_move: if child_agent_id in self.agent_nodes: old_node = self.agent_nodes[child_agent_id] if old_node.parent is parent_node: continue self._copy_node_under(old_node, parent_node) old_node.remove() parent_node.allow_expand = True parent_node.expand() def _render_chat_content(self, msg_data: dict[str, Any]) -> Any: role = msg_data.get("role") content = msg_data.get("content", "") metadata = msg_data.get("metadata", {}) if not content: return None if role == "user": return UserMessageRenderer.render_simple(content) if metadata.get("interrupted"): streaming_result = self._render_streaming_content(content) interrupted_text = Text() interrupted_text.append("\n") interrupted_text.append("⚠ ", style="yellow") interrupted_text.append("Interrupted by user", style="yellow dim") return self._merge_renderables([streaming_result, interrupted_text]) return AgentMessageRenderer.render_simple(content) def _render_tool_content_simple(self, tool_data: dict[str, Any]) -> Any: tool_name = tool_data.get("tool_name", "Unknown Tool") args = tool_data.get("args", {}) status = tool_data.get("status", "unknown") result = tool_data.get("result") renderer = get_tool_renderer(tool_name) if renderer: widget = renderer.render(tool_data) return widget.content text = Text() if tool_name in ("llm_error_details", "sandbox_error_details"): return self._render_error_details(text, tool_name, args) text.append("→ Using tool ") text.append(tool_name, style="bold blue") status_styles = { "running": ("●", "yellow"), "completed": ("✓", "green"), "failed": ("✗", "red"), "error": ("✗", "red"), } icon, style = status_styles.get(status, ("○", "dim")) text.append(" ") text.append(icon, style=style) if args: for k, v in list(args.items())[:5]: str_v = str(v) if len(str_v) > 500: str_v = str_v[:497] + "..." text.append("\n ") text.append(k, style="dim") text.append(": ") text.append(str_v) if status in ["completed", "failed", "error"] and result: result_str = str(result) if len(result_str) > 1000: result_str = result_str[:997] + "..." text.append("\n") text.append("Result: ", style="bold") text.append(result_str) return text def _render_error_details(self, text: Any, tool_name: str, args: dict[str, Any]) -> Any: if tool_name == "llm_error_details": text.append("✗ LLM Request Failed", style="red") else: text.append("✗ Sandbox Initialization Failed", style="red") if args.get("error"): text.append(f"\n{args['error']}", style="bold red") if args.get("details"): details = str(args["details"]) if len(details) > 1000: details = details[:997] + "..." text.append("\nDetails: ", style="dim") text.append(details) return text @on(Tree.NodeHighlighted) # type: ignore[misc] def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None: if len(self.screen_stack) > 1 or self.show_splash: return if not self.is_mounted: return node = event.node try: agents_tree = self.query_one("#agents_tree", Tree) except (ValueError, Exception): return if self.focused == agents_tree and node.data: agent_id = node.data.get("agent_id") if agent_id: self.selected_agent_id = agent_id @on(Tree.NodeSelected) # type: ignore[misc] def handle_tree_node_selected(self, event: Tree.NodeSelected) -> None: if len(self.screen_stack) > 1 or self.show_splash: return if not self.is_mounted: return node = event.node if node.allow_expand: if node.is_expanded: node.collapse() else: node.expand() def _send_user_message(self, message: str) -> None: if not self.selected_agent_id: return if self.tracer: streaming_content = self.tracer.get_streaming_content(self.selected_agent_id) if streaming_content and streaming_content.strip(): self.tracer.clear_streaming_content(self.selected_agent_id) self.tracer.interrupted_content[self.selected_agent_id] = streaming_content self.tracer.log_chat_message( content=streaming_content, role="assistant", agent_id=self.selected_agent_id, metadata={"interrupted": True}, ) try: from strix.tools.agents_graph.agents_graph_actions import _agent_instances if self.selected_agent_id in _agent_instances: agent_instance = _agent_instances[self.selected_agent_id] if hasattr(agent_instance, "cancel_current_execution"): agent_instance.cancel_current_execution() except (ImportError, AttributeError, KeyError): pass if self.tracer: self.tracer.log_chat_message( content=message, role="user", agent_id=self.selected_agent_id, ) try: from strix.tools.agents_graph.agents_graph_actions import send_user_message_to_agent send_user_message_to_agent(self.selected_agent_id, message) except (ImportError, AttributeError) as e: import logging logging.warning(f"Failed to send message to agent {self.selected_agent_id}: {e}") self._displayed_events.clear() self._update_chat_view() self.call_after_refresh(self._focus_chat_input) def _get_agent_name(self, agent_id: str) -> str: try: if self.tracer and agent_id in self.tracer.agents: agent_name = self.tracer.agents[agent_id].get("name") if isinstance(agent_name, str): return agent_name except (KeyError, AttributeError) as e: logging.warning(f"Could not retrieve agent name for {agent_id}: {e}") return "Unknown Agent" def action_toggle_help(self) -> None: if self.show_splash or not self.is_mounted: return try: self.query_one("#main_container") except (ValueError, Exception): return if isinstance(self.screen, HelpScreen): self.pop_screen() return if len(self.screen_stack) > 1: return self.push_screen(HelpScreen()) def action_request_quit(self) -> None: if self.show_splash or not self.is_mounted: self.action_custom_quit() return if len(self.screen_stack) > 1: return try: self.query_one("#main_container") except (ValueError, Exception): self.action_custom_quit() return self.push_screen(QuitScreen()) def action_stop_selected_agent(self) -> None: if self.show_splash or not self.is_mounted: return if len(self.screen_stack) > 1: self.pop_screen() return if not self.selected_agent_id: return agent_name, should_stop = self._validate_agent_for_stopping() if not should_stop: return try: self.query_one("#main_container") except (ValueError, Exception): return self.push_screen(StopAgentScreen(agent_name, self.selected_agent_id)) def _validate_agent_for_stopping(self) -> tuple[str, bool]: agent_name = "Unknown Agent" try: if self.tracer and self.selected_agent_id in self.tracer.agents: agent_data = self.tracer.agents[self.selected_agent_id] agent_name = agent_data.get("name", "Unknown Agent") agent_status = agent_data.get("status", "running") if agent_status not in ["running"]: return agent_name, False agent_events = self._gather_agent_events(self.selected_agent_id) if not agent_events: return agent_name, False return agent_name, True except (KeyError, AttributeError, ValueError) as e: import logging logging.warning(f"Failed to gather agent events: {e}") return agent_name, False def action_confirm_stop_agent(self, agent_id: str) -> None: try: from strix.tools.agents_graph.agents_graph_actions import stop_agent result = stop_agent(agent_id) import logging if result.get("success"): logging.info(f"Stop request sent to agent: {result.get('message', 'Unknown')}") else: logging.warning(f"Failed to stop agent: {result.get('error', 'Unknown error')}") except Exception: import logging logging.exception(f"Failed to stop agent {agent_id}") def action_custom_quit(self) -> None: if self._scan_thread and self._scan_thread.is_alive(): self._scan_stop_event.set() self._scan_thread.join(timeout=1.0) self.tracer.cleanup() self.exit() def _is_widget_safe(self, widget: Any) -> bool: try: _ = widget.screen except (AttributeError, ValueError, Exception): return False else: return bool(widget.is_mounted) def _safe_widget_operation( self, operation: Callable[..., Any], *args: Any, **kwargs: Any ) -> bool: try: operation(*args, **kwargs) except (AttributeError, ValueError, Exception): return False else: return True def on_resize(self, event: events.Resize) -> None: if self.show_splash or not self.is_mounted: return try: sidebar = self.query_one("#sidebar", Vertical) chat_area = self.query_one("#chat_area_container", Vertical) except (ValueError, Exception): return if event.size.width < self.SIDEBAR_MIN_WIDTH: sidebar.add_class("-hidden") chat_area.add_class("-full-width") else: sidebar.remove_class("-hidden") chat_area.remove_class("-full-width") def on_mouse_up(self, _event: events.MouseUp) -> None: self.set_timer(0.05, self._auto_copy_selection) _ICON_PREFIXES: ClassVar[tuple[str, ...]] = ( "🐞 ", "🌐 ", "📋 ", "🧠 ", "◆ ", "◇ ", "◈ ", "→ ", "○ ", "● ", "✓ ", "✗ ", "⚠ ", "▍ ", "▍", "┃ ", "• ", ">_ ", " ", "<~> ", "[ ] ", "[~] ", "[•] ", ) _DECORATIVE_LINES: ClassVar[frozenset[str]] = frozenset( { "● In progress...", "✓ Done", "✗ Failed", "✗ Error", "○ Unknown", } ) @staticmethod def _clean_copied_text(text: str) -> str: lines = text.split("\n") cleaned: list[str] = [] for line in lines: stripped = line.lstrip() if stripped in StrixTUIApp._DECORATIVE_LINES: continue if stripped and all(c == "─" for c in stripped): continue out = line for prefix in StrixTUIApp._ICON_PREFIXES: if stripped.startswith(prefix): leading = line[: len(line) - len(line.lstrip())] out = leading + stripped[len(prefix) :] break cleaned.append(out) return "\n".join(cleaned) def _auto_copy_selection(self) -> None: copied = False try: if self.screen.selections: selected = self.screen.get_selected_text() self.screen.clear_selection() if selected and selected.strip(): cleaned = self._clean_copied_text(selected) self.copy_to_clipboard(cleaned if cleaned.strip() else selected) copied = True except Exception: # noqa: BLE001 logger.debug("Failed to copy screen selection", exc_info=True) if not copied: try: chat_input = self.query_one("#chat_input", ChatTextArea) selected = chat_input.selected_text if selected and selected.strip(): self.copy_to_clipboard(selected) chat_input.move_cursor(chat_input.cursor_location) copied = True except Exception: # noqa: BLE001 logger.debug("Failed to copy chat input selection", exc_info=True) if copied: self.notify("Copied to clipboard", timeout=2) async def run_tui(args: argparse.Namespace) -> None: """Run strix in interactive TUI mode with textual.""" app = StrixTUIApp(args) await app.run_async()