2096 lines
72 KiB
Python
2096 lines
72 KiB
Python
import argparse
|
|
import asyncio
|
|
import atexit
|
|
import logging
|
|
import signal
|
|
import sys
|
|
import threading
|
|
from collections.abc import Callable
|
|
from importlib.metadata import PackageNotFoundError
|
|
from importlib.metadata import version as pkg_version
|
|
from typing import TYPE_CHECKING, Any, ClassVar
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from textual.timer import Timer
|
|
|
|
from rich.align import Align
|
|
from rich.console import Group
|
|
from rich.panel import Panel
|
|
from rich.style import Style
|
|
from rich.text import Span, Text
|
|
from textual import events, on
|
|
from textual.app import App, ComposeResult
|
|
from textual.binding import Binding
|
|
from textual.containers import Grid, Horizontal, Vertical, VerticalScroll
|
|
from textual.reactive import reactive
|
|
from textual.screen import ModalScreen
|
|
from textual.widgets import Button, Label, Static, TextArea, Tree
|
|
from textual.widgets.tree import TreeNode
|
|
|
|
from strix.agents.StrixAgent import StrixAgent
|
|
from strix.interface.streaming_parser import parse_streaming_content
|
|
from strix.interface.tool_components.agent_message_renderer import AgentMessageRenderer
|
|
from strix.interface.tool_components.registry import get_tool_renderer
|
|
from strix.interface.tool_components.user_message_renderer import UserMessageRenderer
|
|
from strix.interface.utils import build_tui_stats_text
|
|
from strix.llm.config import LLMConfig
|
|
from strix.telemetry.tracer import Tracer, set_global_tracer
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def get_package_version() -> str:
    """Return the installed ``strix-agent`` distribution version.

    Falls back to the literal ``"dev"`` when no distribution metadata for
    ``strix-agent`` can be found (``PackageNotFoundError``).
    """
    try:
        installed = pkg_version("strix-agent")
    except PackageNotFoundError:
        installed = "dev"
    return installed
|
|
|
|
|
|
class ChatTextArea(TextArea):  # type: ignore[misc]
    """Auto-growing chat input that submits its contents on Enter.

    Enter sends the stripped text through the attached ``StrixTUIApp``
    (see ``set_app_reference``) and clears the input. Shift+Enter inserts a
    newline. The parent container's height follows the line count, capped at
    eight visible lines plus two rows of chrome.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # App whose ``_send_user_message`` receives submissions; None until attached.
        self._app_reference: StrixTUIApp | None = None

    def set_app_reference(self, app: "StrixTUIApp") -> None:
        """Attach the app that handles submitted messages."""
        self._app_reference = app

    def on_mount(self) -> None:
        self._update_height()

    def _on_key(self, event: events.Key) -> None:
        """Intercept Enter / Shift+Enter; leave every other key to ``TextArea``.

        Textual calls the matching handler of each base class in the MRO after
        this one, until ``event.prevent_default()`` is called, so
        ``TextArea._on_key`` runs on its own for keys not consumed here. The
        previous explicit ``super()._on_key(event)`` was redundant. It could
        run the base handler a second time, and when ``TextArea._on_key`` is a
        coroutine function it created a coroutine that was never awaited
        (NOTE(review): confirm against the pinned Textual version).
        """
        if event.key == "shift+enter":
            self.insert("\n")
            event.prevent_default()
            return

        if event.key == "enter" and self._app_reference:
            message = str(self.text).strip()  # type: ignore[has-type]
            if message:
                self.text = ""
                self._app_reference._send_user_message(message)
                # Stop TextArea from also handling this Enter.
                event.prevent_default()
            # A blank Enter is not prevented, so TextArea's own handling applies.

    @on(TextArea.Changed)  # type: ignore[misc]
    def _update_height(self, _event: TextArea.Changed | None = None) -> None:
        """Resize the parent container to fit 1-8 lines of text plus 2 rows."""
        if not self.parent:
            return

        line_count = self.document.line_count
        target_lines = min(max(1, line_count), 8)
        new_height = target_lines + 2

        if self.parent.styles.height != new_height:
            self.parent.styles.height = new_height
            self.scroll_cursor_visible()
|
|
|
|
|
|
class SplashScreen(Static):  # type: ignore[misc]
    """Start-up banner with a highlight sweeping across "Starting Strix Agent".

    The highlight advances one character per tick of a 0.05 s interval timer
    that runs while the widget is mounted.
    """

    ALLOW_SELECT = False

    PRIMARY_GREEN = "#22c55e"

    BANNER = (
        " ███████╗████████╗██████╗ ██╗██╗ ██╗\n"
        " ██╔════╝╚══██╔══╝██╔══██╗██║╚██╗██╔╝\n"
        " ███████╗ ██║ ██████╔╝██║ ╚███╔╝\n"
        " ╚════██║ ██║ ██╔══██╗██║ ██╔██╗\n"
        " ███████║ ██║ ██║ ██║██║██╔╝ ██╗\n"
        " ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝"
    )

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Frame counter that positions the highlight.
        self._animation_step = 0
        self._animation_timer: Timer | None = None
        # Inner widget whose panel is re-rendered every frame.
        self._panel_static: Static | None = None
        self._version = "dev"

    def compose(self) -> ComposeResult:
        self._version = get_package_version()
        self._animation_step = 0
        first_frame = self._build_panel(self._build_start_line_text(self._animation_step))
        self._panel_static = Static(first_frame, id="splash_content")
        yield self._panel_static

    def on_mount(self) -> None:
        self._animation_timer = self.set_interval(0.05, self._animate_start_line)

    def on_unmount(self) -> None:
        if self._animation_timer is None:
            return
        self._animation_timer.stop()
        self._animation_timer = None

    def _animate_start_line(self) -> None:
        """Advance the highlight by one step and redraw the panel."""
        panel_widget = self._panel_static
        if not panel_widget:
            return

        self._animation_step += 1
        next_line = self._build_start_line_text(self._animation_step)
        panel_widget.update(self._build_panel(next_line))

    def _build_panel(self, start_line: Text) -> Panel:
        """Stack banner, greeting, version, tagline, start line and URL, centred."""
        rows = (
            Text(self.BANNER.strip("\n"), style=self.PRIMARY_GREEN, justify="center"),
            Text(" "),
            self._build_welcome_text(),
            self._build_version_text(),
            self._build_tagline_text(),
            Text(" "),
            start_line.copy(),
            Text(" "),
            self._build_url_text(),
        )
        content = Group(*(Align.center(row) for row in rows))
        return Panel.fit(content, border_style=self.PRIMARY_GREEN, padding=(1, 6))

    def _build_url_text(self) -> Text:
        return Text("strix.ai", style=Style(color=self.PRIMARY_GREEN, bold=True))

    def _build_welcome_text(self) -> Text:
        bold_white = Style(color="white", bold=True)
        welcome = Text("Welcome to ", style=bold_white)
        welcome.append("Strix", style=Style(color=self.PRIMARY_GREEN, bold=True))
        welcome.append("!", style=bold_white)
        return welcome

    def _build_version_text(self) -> Text:
        return Text(f"v{self._version}", style=Style(color="white", dim=True))

    def _build_tagline_text(self) -> Text:
        return Text("Open-source AI hackers for your apps", style=Style(color="white", dim=True))

    def _build_start_line_text(self, phase: int) -> Text:
        """Render the start line with a highlight centred at ``phase``.

        Characters within 1 cell of the highlight are brightest, then 3, then 5;
        the rest are dim. The position wraps 8 cells past the end of the label.
        """
        label = "Starting Strix Agent"
        shine_pos = phase % (len(label) + 8)

        line = Text()
        for index, character in enumerate(label):
            distance = abs(index - shine_pos)
            if distance > 5:
                style = Style(color="#525252")
            elif distance > 3:
                style = Style(color="#a3a3a3")
            elif distance > 1:
                style = Style(color="white", bold=True)
            else:
                style = Style(color="bright_white", bold=True)
            line.append(character, style=style)
        return line
|
|
|
|
|
|
class HelpScreen(ModalScreen):  # type: ignore[misc]
    """Modal overlay listing keyboard shortcuts; any key press closes it."""

    def compose(self) -> ComposeResult:
        title = Label("Strix Help", id="help_title")
        shortcuts = Label(
            "F1 Help\nCtrl+Q/C Quit\nESC Stop Agent\n"
            "Enter Send message to agent\nTab Switch panels\n↑/↓ Navigate tree",
            id="help_content",
        )
        yield Grid(title, shortcuts, id="dialog")

    def on_key(self, _event: events.Key) -> None:
        # Whatever key was pressed, dismiss the overlay.
        self.app.pop_screen()
|
|
|
|
|
|
class StopAgentScreen(ModalScreen):  # type: ignore[misc]
    """Yes/No confirmation shown before stopping one agent.

    Focus starts on "No". Any arrow key toggles focus between the buttons,
    Enter presses the focused button, and Escape closes the dialog.
    Choosing "Yes" calls ``app.action_confirm_stop_agent(agent_id)``.
    """

    def __init__(self, agent_name: str, agent_id: str):
        super().__init__()
        self.agent_name = agent_name
        self.agent_id = agent_id

    def compose(self) -> ComposeResult:
        title = Label(f"🛑 Stop '{self.agent_name}'?", id="stop_agent_title")
        buttons = Grid(
            Button("Yes", variant="error", id="stop_agent"),
            Button("No", variant="default", id="cancel_stop"),
            id="stop_agent_buttons",
        )
        yield Grid(title, buttons, id="stop_agent_dialog")

    def on_mount(self) -> None:
        self.query_one("#cancel_stop", Button).focus()

    def on_key(self, event: events.Key) -> None:
        key = event.key
        if key in ("left", "right", "up", "down"):
            focused = self.focused
            target = "#cancel_stop" if focused and focused.id == "stop_agent" else "#stop_agent"
            self.query_one(target, Button).focus()
            event.prevent_default()
        elif key == "enter":
            focused = self.focused
            if focused and isinstance(focused, Button):
                focused.press()
            event.prevent_default()
        elif key == "escape":
            self.app.pop_screen()
            event.prevent_default()

    def on_button_pressed(self, event: Button.Pressed) -> None:
        # Close the dialog first; only "Yes" goes on to stop the agent.
        self.app.pop_screen()
        if event.button.id != "stop_agent":
            return
        self.app.action_confirm_stop_agent(self.agent_id)
|
|
|
|
|
|
class VulnerabilityDetailScreen(ModalScreen):  # type: ignore[misc]
    """Modal screen to display vulnerability details.

    Renders one vulnerability report dict as styled, scrollable text. "Copy"
    puts a Markdown version of the same report on the clipboard; "Done" or
    Escape closes the screen.
    """

    SEVERITY_COLORS: ClassVar[dict[str, str]] = {
        "critical": "#dc2626",  # Red
        "high": "#ea580c",  # Orange
        "medium": "#d97706",  # Amber
        "low": "#22c55e",  # Green
        "info": "#3b82f6",  # Blue
    }

    FIELD_STYLE: ClassVar[str] = "bold #4ade80"

    # Gray used for unknown severities and for CVSS scores that are 0.0 or
    # cannot be parsed as a number.
    NEUTRAL_COLOR: ClassVar[str] = "#6b7280"

    # CVSS breakdown keys in canonical vector order, with their abbreviations.
    # Shared by the on-screen view and the Markdown export so both print the
    # same vector.
    CVSS_METRICS: ClassVar[tuple[tuple[str, str], ...]] = (
        ("attack_vector", "AV"),
        ("attack_complexity", "AC"),
        ("privileges_required", "PR"),
        ("user_interaction", "UI"),
        ("scope", "S"),
        ("confidentiality", "C"),
        ("integrity", "I"),
        ("availability", "A"),
    )

    def __init__(self, vulnerability: dict[str, Any]) -> None:
        super().__init__()
        # Report to display; this screen only reads it.
        self.vulnerability = vulnerability

    def compose(self) -> ComposeResult:
        content = self._render_vulnerability()
        yield Grid(
            VerticalScroll(Static(content, id="vuln_detail_content"), id="vuln_detail_scroll"),
            Horizontal(
                Button("Copy", variant="default", id="copy_vuln_detail"),
                Button("Done", variant="default", id="close_vuln_detail"),
                id="vuln_detail_buttons",
            ),
            id="vuln_detail_dialog",
        )

    def on_mount(self) -> None:
        # Start with the "Done" button focused.
        close_button = self.query_one("#close_vuln_detail", Button)
        close_button.focus()

    def _get_cvss_color(self, cvss_score: float) -> str:
        """Map a CVSS base score to a color; scores below 0.1 are neutral gray."""
        if cvss_score >= 9.0:
            return "#dc2626"
        if cvss_score >= 7.0:
            return "#ea580c"
        if cvss_score >= 4.0:
            return "#d97706"
        if cvss_score >= 0.1:
            return "#65a30d"
        return self.NEUTRAL_COLOR

    def _highlight_python(self, code: str) -> Text:
        """Syntax-highlight Python source using Pygments' ``native`` style.

        Returns ``Text(code)`` without highlighting if Pygments cannot be
        imported or raises ``KeyError``/``AttributeError`` while colors are
        built.
        """
        try:
            from pygments.lexers import PythonLexer
            from pygments.styles import get_style_by_name

            lexer = PythonLexer()
            style = get_style_by_name("native")
            colors = {
                token: f"#{style_def['color']}" for token, style_def in style if style_def["color"]
            }

            text = Text()
            for token_type, token_value in lexer.get_tokens(code):
                if not token_value:
                    continue
                # Walk up the token hierarchy until a colored ancestor is found.
                color = None
                tt = token_type
                while tt:
                    if tt in colors:
                        color = colors[tt]
                        break
                    tt = tt.parent
                text.append(token_value, style=color)
        except (ImportError, KeyError, AttributeError):
            return Text(code)
        else:
            return text

    @classmethod
    def _format_cvss_vector(cls, breakdown: Any) -> str:
        """Return ``"AV:x/AC:y/..."`` for the non-empty metrics in ``breakdown``.

        Metrics always follow ``CVSS_METRICS`` order, whatever the dict's key
        order; unknown keys are ignored. Returns ``""`` if ``breakdown`` is not
        a dict or has no non-empty known metric.
        """
        if not isinstance(breakdown, dict):
            return ""
        return "/".join(
            f"{abbrev}:{breakdown[key]}" for key, abbrev in cls.CVSS_METRICS if breakdown.get(key)
        )

    def _append_field(self, text: Text, label: str, value: Any, style: str | None = None) -> None:
        """Append a blank line, ``"<label>: "`` in FIELD_STYLE, then ``str(value)``."""
        text.append("\n\n")
        text.append(f"{label}: ", style=self.FIELD_STYLE)
        text.append(str(value), style=style)

    def _append_section(self, text: Text, heading: str, body: str | Text) -> None:
        """Append a blank line, a FIELD_STYLE heading line, then ``body``."""
        text.append("\n\n")
        text.append(heading, style=self.FIELD_STYLE)
        text.append("\n")
        if isinstance(body, Text):
            text.append_text(body)
        else:
            text.append(body)

    def _render_vulnerability(self) -> Text:
        """Build the styled on-screen report for ``self.vulnerability``.

        Missing or empty fields are skipped. Short fields render as
        ``"Label: value"`` lines. Long-form fields render as a heading line
        followed by their text, with the PoC code syntax-highlighted.
        """
        vuln = self.vulnerability
        text = Text()
        text.append("🐞 ")
        text.append("Vulnerability Report", style="bold #ea580c")

        for key, label in (("agent_name", "Agent"), ("title", "Title")):
            value = vuln.get(key, "")
            if value:
                self._append_field(text, label, value)

        severity = vuln.get("severity", "")
        if severity:
            severity = str(severity)
            severity_color = self.SEVERITY_COLORS.get(severity.lower(), self.NEUTRAL_COLOR)
            self._append_field(text, "Severity", severity.upper(), style=f"bold {severity_color}")

        cvss_score = vuln.get("cvss")
        if cvss_score is not None:
            try:
                cvss_color = self._get_cvss_color(float(cvss_score))
            except (TypeError, ValueError):
                # A non-numeric score (e.g. "N/A") is shown as-is in gray
                # instead of raising and breaking the whole screen.
                cvss_color = self.NEUTRAL_COLOR
            self._append_field(text, "CVSS Score", cvss_score, style=f"bold {cvss_color}")

        for key, label in (
            ("target", "Target"),
            ("endpoint", "Endpoint"),
            ("method", "Method"),
            ("cve", "CVE"),
        ):
            value = vuln.get(key, "")
            if value:
                self._append_field(text, label, value)

        cvss_vector = self._format_cvss_vector(vuln.get("cvss_breakdown", {}))
        if cvss_vector:
            self._append_field(text, "CVSS Vector", cvss_vector, style="dim")

        for key, heading in (
            ("description", "Description"),
            ("impact", "Impact"),
            ("technical_analysis", "Technical Analysis"),
            ("poc_description", "PoC Description"),
        ):
            value = vuln.get(key, "")
            if value:
                self._append_section(text, heading, str(value))

        poc_script_code = vuln.get("poc_script_code", "")
        if poc_script_code:
            self._append_section(text, "PoC Code", self._highlight_python(str(poc_script_code)))

        remediation_steps = vuln.get("remediation_steps", "")
        if remediation_steps:
            self._append_section(text, "Remediation", str(remediation_steps))

        return text

    def _get_markdown_report(self) -> str:  # noqa: PLR0912, PLR0915
        """Get Markdown version of vulnerability report for clipboard.

        Metadata lines and sections without data are omitted, except
        "Description", which falls back to "No description provided.".
        """
        vuln = self.vulnerability
        lines: list[str] = []

        # Title
        title = vuln.get("title", "Untitled Vulnerability")
        lines.append(f"# {title}")
        lines.append("")

        # Metadata
        if vuln.get("id"):
            lines.append(f"**ID:** {vuln['id']}")
        if vuln.get("severity"):
            lines.append(f"**Severity:** {str(vuln['severity']).upper()}")
        if vuln.get("timestamp"):
            lines.append(f"**Found:** {vuln['timestamp']}")
        if vuln.get("agent_name"):
            lines.append(f"**Agent:** {vuln['agent_name']}")
        if vuln.get("target"):
            lines.append(f"**Target:** {vuln['target']}")
        if vuln.get("endpoint"):
            lines.append(f"**Endpoint:** {vuln['endpoint']}")
        if vuln.get("method"):
            lines.append(f"**Method:** {vuln['method']}")
        if vuln.get("cve"):
            lines.append(f"**CVE:** {vuln['cve']}")
        if vuln.get("cvss") is not None:
            lines.append(f"**CVSS:** {vuln['cvss']}")

        # CVSS Vector, in the same canonical order as the on-screen view
        cvss_vector = self._format_cvss_vector(vuln.get("cvss_breakdown", {}))
        if cvss_vector:
            lines.append(f"**CVSS Vector:** {cvss_vector}")

        # Description
        lines.append("")
        lines.append("## Description")
        lines.append("")
        lines.append(vuln.get("description") or "No description provided.")

        # Impact
        if vuln.get("impact"):
            lines.extend(["", "## Impact", "", vuln["impact"]])

        # Technical Analysis
        if vuln.get("technical_analysis"):
            lines.extend(["", "## Technical Analysis", "", vuln["technical_analysis"]])

        # Proof of Concept
        if vuln.get("poc_description") or vuln.get("poc_script_code"):
            lines.extend(["", "## Proof of Concept", ""])
            if vuln.get("poc_description"):
                lines.append(vuln["poc_description"])
                lines.append("")
            if vuln.get("poc_script_code"):
                lines.append("```python")
                lines.append(vuln["poc_script_code"])
                lines.append("```")

        # Code Analysis
        if vuln.get("code_locations"):
            lines.extend(["", "## Code Analysis", ""])
            for i, loc in enumerate(vuln["code_locations"]):
                file_ref = loc.get("file", "unknown")
                line_ref = ""
                if loc.get("start_line") is not None:
                    if loc.get("end_line") and loc["end_line"] != loc["start_line"]:
                        line_ref = f" (lines {loc['start_line']}-{loc['end_line']})"
                    else:
                        line_ref = f" (line {loc['start_line']})"
                lines.append(f"**Location {i + 1}:** `{file_ref}`{line_ref}")
                if loc.get("label"):
                    lines.append(f" {loc['label']}")
                if loc.get("snippet"):
                    lines.append(f"```\n{loc['snippet']}\n```")
                if loc.get("fix_before") or loc.get("fix_after"):
                    lines.append("**Suggested Fix:**")
                    lines.append("```diff")
                    if loc.get("fix_before"):
                        lines.extend(f"- {line}" for line in loc["fix_before"].splitlines())
                    if loc.get("fix_after"):
                        lines.extend(f"+ {line}" for line in loc["fix_after"].splitlines())
                    lines.append("```")
                lines.append("")

        # Remediation
        if vuln.get("remediation_steps"):
            lines.extend(["", "## Remediation", "", vuln["remediation_steps"]])

        lines.append("")
        return "\n".join(lines)

    def on_key(self, event: events.Key) -> None:
        if event.key == "escape":
            self.app.pop_screen()
            event.prevent_default()

    def on_button_pressed(self, event: Button.Pressed) -> None:
        if event.button.id == "copy_vuln_detail":
            markdown_text = self._get_markdown_report()
            self.app.copy_to_clipboard(markdown_text)

            # Brief confirmation, then restore the label after 1.5 s.
            copy_button = self.query_one("#copy_vuln_detail", Button)
            copy_button.label = "Copied!"
            self.set_timer(1.5, lambda: setattr(copy_button, "label", "Copy"))
        elif event.button.id == "close_vuln_detail":
            self.app.pop_screen()
|
|
|
|
|
|
class VulnerabilityItem(Static):  # type: ignore[misc]
    """One clickable row of the vulnerabilities panel."""

    def __init__(self, label: Text, vuln_data: dict[str, Any], **kwargs: Any) -> None:
        super().__init__(label, **kwargs)
        # Report dict handed to the detail screen when the row is clicked.
        self.vuln_data = vuln_data

    def on_click(self, _event: events.Click) -> None:
        """Open the detail screen for this row's vulnerability."""
        detail_screen = VulnerabilityDetailScreen(self.vuln_data)
        self.app.push_screen(detail_screen)
|
|
|
|
|
|
class VulnerabilitiesPanel(VerticalScroll):  # type: ignore[misc]
    """A scrollable panel showing found vulnerabilities with severity-colored dots.

    Each report becomes a clickable ``VulnerabilityItem``. Items are rebuilt
    only when ``update_vulnerabilities`` receives data that differs from the
    stored snapshot.
    """

    # Same palette as the detail screen, so a severity looks identical in both views.
    SEVERITY_COLORS: ClassVar[dict[str, str]] = VulnerabilityDetailScreen.SEVERITY_COLORS

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Shallow per-report copies of the data currently rendered.
        self._vulnerabilities: list[dict[str, Any]] = []

    def compose(self) -> ComposeResult:
        return []

    def update_vulnerabilities(self, vulnerabilities: list[dict[str, Any]]) -> None:
        """Update the list of vulnerabilities and re-render.

        Each report dict is copied before it is stored. The previous code
        stored only a new list, which still referenced the caller's dicts. If
        a caller then edited a report in place, the stored snapshot changed
        with it, so the equality check below never detected the edit.
        """
        if self._vulnerabilities == vulnerabilities:
            return
        self._vulnerabilities = [dict(vuln) for vuln in vulnerabilities]
        self._render_panel()

    def _render_panel(self) -> None:
        """Replace the panel's items with one row per stored report."""
        for child in list(self.children):
            if isinstance(child, VulnerabilityItem):
                child.remove()

        for vuln in self._vulnerabilities:
            # ``or`` also applies the defaults when a key is present but None
            # (which used to raise on ``.lower()`` / ``Text.append``) or empty.
            severity = str(vuln.get("severity") or "info").lower()
            title = str(vuln.get("title") or "Unknown Vulnerability")
            color = self.SEVERITY_COLORS.get(severity, "#3b82f6")

            label = Text()
            label.append("● ", style=Style(color=color))
            label.append(title, style=Style(color="#d4d4d4"))

            item = VulnerabilityItem(label, vuln, classes="vuln-item")
            self.mount(item)
|
|
|
|
|
|
class QuitScreen(ModalScreen):  # type: ignore[misc]
    """Yes/No confirmation shown before quitting Strix.

    Focus starts on "No". Any arrow key toggles focus between the buttons,
    Enter presses the focused button, and Escape closes the dialog. "Yes"
    calls ``app.action_custom_quit()``; "No" just closes the dialog.
    """

    def compose(self) -> ComposeResult:
        title = Label("Quit Strix?", id="quit_title")
        buttons = Grid(
            Button("Yes", variant="error", id="quit"),
            Button("No", variant="default", id="cancel"),
            id="quit_buttons",
        )
        yield Grid(title, buttons, id="quit_dialog")

    def on_mount(self) -> None:
        self.query_one("#cancel", Button).focus()

    def on_key(self, event: events.Key) -> None:
        key = event.key
        if key in ("left", "right", "up", "down"):
            focused = self.focused
            target = "#cancel" if focused and focused.id == "quit" else "#quit"
            self.query_one(target, Button).focus()
            event.prevent_default()
        elif key == "enter":
            focused = self.focused
            if focused and isinstance(focused, Button):
                focused.press()
            event.prevent_default()
        elif key == "escape":
            self.app.pop_screen()
            event.prevent_default()

    def on_button_pressed(self, event: Button.Pressed) -> None:
        if event.button.id != "quit":
            self.app.pop_screen()
            return
        self.app.action_custom_quit()
|
|
|
|
|
|
class StrixTUIApp(App):  # type: ignore[misc]
    """Main Strix terminal UI.

    Shows ``SplashScreen`` while ``show_splash`` is True. The scan layout
    (chat, agent tree, vulnerabilities and stats panels) is mounted once it
    turns False, and the tracer is then polled to refresh it.
    """

    # Textual stylesheet; a relative CSS_PATH is resolved against the file
    # that defines the App class.
    CSS_PATH = "assets/tui_styles.tcss"
    # Enable mouse text selection app-wide; individual widgets opt out.
    ALLOW_SELECT = True

    # NOTE(review): not referenced in the visible part of this file; presumably
    # the terminal width below which the sidebar is hidden — confirm.
    SIDEBAR_MIN_WIDTH = 120

    # Agent whose activity the chat view shows; None when nothing is selected.
    selected_agent_id: reactive[str | None] = reactive(default=None)
    # True while the splash is visible; setting it to False runs
    # ``watch_show_splash``, which mounts the main layout.
    show_splash: reactive[bool] = reactive(default=True)

    # priority=True makes these bindings fire before the focused widget
    # (e.g. the chat input) gets the key.
    BINDINGS: ClassVar[list[Binding]] = [
        Binding("f1", "toggle_help", "Help", priority=True),
        Binding("ctrl+q", "request_quit", "Quit", priority=True),
        Binding("ctrl+c", "request_quit", "Quit", priority=True),
        Binding("escape", "stop_selected_agent", "Stop Agent", priority=True),
    ]
|
|
|
|
def __init__(self, args: argparse.Namespace):
    """Create the app for one scan run.

    Builds the scan and agent configuration from ``args``, creates the run's
    ``Tracer`` and installs it as the global tracer, initialises the state
    used by the UI polling code, and registers exit/signal cleanup hooks.
    """
    super().__init__()
    self.args = args
    self.scan_config = self._build_scan_config(args)
    self.agent_config = self._build_agent_config(args)

    # One tracer per run, also registered process-wide.
    self.tracer = Tracer(self.scan_config["run_name"])
    self.tracer.set_scan_config(self.scan_config)
    set_global_tracer(self.tracer)

    # Agent tree bookkeeping: tree node per agent id, and ids already added.
    self.agent_nodes: dict[str, TreeNode] = {}
    self._displayed_agents: set[str] = set()

    # Chat-view change detection state (compared in _get_chat_content).
    self._displayed_events: list[str] = []
    self._last_streaming_len: dict[str, int] = {}
    self._streaming_render_cache: dict[str, tuple[int, Any]] = {}

    # Scan thread handle plus threading.Event flags; presumably managed by
    # _start_scan_thread (not visible here).
    self._scan_thread: threading.Thread | None = None
    self._scan_stop_event = threading.Event()
    self._scan_completed = threading.Event()

    # Sweep animation: current frame index, square count, and the palette
    # ordered from dimmest to brightest.
    self._spinner_frame_index: int = 0
    self._sweep_num_squares: int = 6
    self._sweep_colors: list[str] = [
        "#000000",  # Dimmest (shows dot)
        "#031a09",
        "#052e16",
        "#0d4a2a",
        "#15803d",
        "#22c55e",
        "#4ade80",
        "#86efac",  # Brightest
    ]
    self._dot_animation_timer: Any | None = None

    self._setup_cleanup_handlers()
|
|
|
|
def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]:
    """Assemble the scan description later passed to ``Tracer.set_scan_config``.

    ``run_name`` doubles as the scan id. A falsy ``instruction`` becomes ``""``,
    and a missing ``diff_scope`` attribute defaults to ``{"active": False}``.
    """
    run_name = args.run_name
    return dict(
        scan_id=run_name,
        targets=args.targets_info,
        user_instructions=args.instruction or "",
        run_name=run_name,
        diff_scope=getattr(args, "diff_scope", {"active": False}),
    )
|
|
|
|
def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]:
    """Build the agent settings for this run from CLI ``args``.

    The interactive ``LLMConfig`` uses ``scan_mode`` (default ``"deep"``) and
    is marked white-box when ``local_sources`` is non-empty; in that case the
    sources are also passed through. ``max_iterations`` is fixed at 300.
    """
    scan_mode = getattr(args, "scan_mode", "deep")
    local_sources = getattr(args, "local_sources", None)

    agent_config: dict[str, Any] = {
        "llm_config": LLMConfig(
            scan_mode=scan_mode,
            interactive=True,
            is_whitebox=bool(local_sources),
        ),
        "max_iterations": 300,
    }
    if local_sources:
        agent_config["local_sources"] = local_sources
    return agent_config
|
|
|
|
def _setup_cleanup_handlers(self) -> None:
    """Register process-exit and signal hooks that clean up after the run.

    The ``atexit`` hook calls ``self.tracer.cleanup()`` and then
    ``strix.runtime.cleanup_runtime()``. Handlers for SIGINT, SIGTERM and,
    where the platform defines it, SIGHUP call ``self.tracer.cleanup()`` and
    then ``sys.exit(0)``. That raises ``SystemExit``; if it propagates to
    interpreter shutdown, the ``atexit`` hook runs as well.
    """

    def cleanup_on_exit() -> None:
        try:
            self.tracer.cleanup()
        finally:
            # Always attempt runtime teardown. Previously the runtime import
            # came first and a failing tracer cleanup skipped the teardown.
            from strix.runtime import cleanup_runtime

            cleanup_runtime()

    def signal_handler(_signum: int, _frame: Any) -> None:
        self.tracer.cleanup()
        sys.exit(0)

    atexit.register(cleanup_on_exit)
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    # SIGHUP does not exist on every platform (e.g. Windows).
    if hasattr(signal, "SIGHUP"):
        signal.signal(signal.SIGHUP, signal_handler)
|
|
|
|
def compose(self) -> ComposeResult:
    """Yield the splash screen while ``show_splash`` is set, otherwise nothing.

    The main layout is mounted later, by ``watch_show_splash``.
    """
    if not self.show_splash:
        return
    yield SplashScreen(id="splash_screen")
|
|
|
|
def watch_show_splash(self, show_splash: bool) -> None:
    """Replace the splash screen with the main scan layout.

    Textual watcher for the ``show_splash`` reactive. It acts only when the
    value becomes False and the app is mounted. It removes the splash widget,
    mounts the chat area and the sidebar (agent tree, vulnerabilities panel,
    stats), and then schedules focus on the chat input.
    """
    if show_splash or not self.is_mounted:
        return

    # ``query(...).remove()`` is a no-op when nothing matches. The old
    # ``query_one`` + ``except ValueError`` could raise if the splash was
    # already gone, because Textual's ``NoMatches`` is not guaranteed to be
    # a ``ValueError``.
    self.query("#splash_screen").remove()

    main_container = Vertical(id="main_container")
    self.mount(main_container)

    content_container = Horizontal(id="content_container")
    main_container.mount(content_container)

    chat_area_container = Vertical(id="chat_area_container")

    # Chat transcript: a Static inside a focusable scroll container.
    chat_display = Static("", id="chat_display")
    chat_history = VerticalScroll(chat_display, id="chat_history")
    chat_history.can_focus = True

    # Status line below the transcript; starts hidden via the "hidden" class.
    status_text = Static("", id="status_text")
    status_text.ALLOW_SELECT = False
    keymap_indicator = Static("", id="keymap_indicator")
    keymap_indicator.ALLOW_SELECT = False
    agent_status_display = Horizontal(
        status_text, keymap_indicator, id="agent_status_display", classes="hidden"
    )

    # Prompt + input row; the input submits messages through this app.
    chat_prompt = Static("> ", id="chat_prompt")
    chat_prompt.ALLOW_SELECT = False
    chat_input = ChatTextArea(
        "",
        id="chat_input",
        show_line_numbers=False,
    )
    chat_input.set_app_reference(self)
    chat_input_container = Horizontal(chat_prompt, chat_input, id="chat_input_container")

    # Agent tree with a hidden, expanded root so agents appear as top-level nodes.
    agents_tree = Tree("Agents", id="agents_tree")
    agents_tree.root.expand()
    agents_tree.show_root = False
    agents_tree.show_guide = True
    agents_tree.guide_depth = 3
    agents_tree.guide_style = "dashed"

    stats_display = Static("", id="stats_display")
    stats_scroll = VerticalScroll(stats_display, id="stats_scroll")

    vulnerabilities_panel = VulnerabilitiesPanel(id="vulnerabilities_panel")

    sidebar = Vertical(agents_tree, vulnerabilities_panel, stats_scroll, id="sidebar")

    content_container.mount(chat_area_container)
    content_container.mount(sidebar)

    chat_area_container.mount(chat_history)
    chat_area_container.mount(agent_status_display)
    chat_area_container.mount(chat_input_container)

    self.call_after_refresh(self._focus_chat_input)
|
|
|
|
def _focus_chat_input(self) -> None:
    """Focus the chat input, hiding its scrollbars; retry after a refresh if needed.

    Does nothing while a modal or the splash screen is showing, or before the
    app is mounted. If the input cannot be looked up or focused yet, this
    method is rescheduled with ``call_after_refresh``.
    """
    if len(self.screen_stack) > 1 or self.show_splash or not self.is_mounted:
        return

    try:
        chat_input = self.query_one("#chat_input", ChatTextArea)
        chat_input.show_vertical_scrollbar = False
        chat_input.show_horizontal_scrollbar = False
        chat_input.focus()
    except Exception:  # was ``(ValueError, Exception)``; ValueError is already an Exception
        self.call_after_refresh(self._focus_chat_input)
|
|
|
|
def _focus_agents_tree(self) -> None:
    """Focus the agent tree and select its first top-level node, if any.

    Does nothing while a modal or the splash screen is showing, or before the
    app is mounted. If the tree cannot be looked up or focused yet, this
    method is rescheduled with ``call_after_refresh``.
    """
    if len(self.screen_stack) > 1 or self.show_splash or not self.is_mounted:
        return

    try:
        agents_tree = self.query_one("#agents_tree", Tree)
        agents_tree.focus()
        if agents_tree.root.children:
            agents_tree.select_node(agents_tree.root.children[0])
    except Exception:  # was ``(ValueError, Exception)``; ValueError is already an Exception
        self.call_after_refresh(self._focus_agents_tree)
|
|
|
|
def on_mount(self) -> None:
    """Set the window title and schedule the splash screen to close after 4.5 s."""
    self.title = "strix"
    splash_seconds = 4.5
    self.set_timer(splash_seconds, self._hide_splash_screen)
|
|
|
|
def _hide_splash_screen(self) -> None:
    """Leave the splash screen, start the scan, and begin polling the tracer.

    Clearing ``show_splash`` triggers ``watch_show_splash``. After that, the
    UI is refreshed from the tracer every 0.35 seconds.
    """
    self.show_splash = False
    self._start_scan_thread()
    poll_seconds = 0.35
    self.set_interval(poll_seconds, self._update_ui_from_tracer)
|
|
|
|
def _update_ui_from_tracer(self) -> None:
    """Poll the tracer and refresh every dynamic part of the main layout.

    Skipped while the splash screen or a modal is showing, before the app is
    mounted, or when the chat history / agent tree widgets are missing or
    fail ``_is_widget_safe``. New agents get tree nodes and existing nodes get
    relabelled. If any node changed, the tree expansion is refreshed. The
    chat view, status line, stats and vulnerabilities panel are then updated.
    """
    if self.show_splash or len(self.screen_stack) > 1 or not self.is_mounted:
        return

    try:
        chat_history = self.query_one("#chat_history", VerticalScroll)
        agents_tree = self.query_one("#agents_tree", Tree)
        if not (self._is_widget_safe(chat_history) and self._is_widget_safe(agents_tree)):
            return
    except Exception:  # was ``(ValueError, Exception)``; ValueError is already an Exception
        return

    agent_updates = False
    # Snapshot the items so the loop does not iterate the live dict directly.
    for agent_id, agent_data in list(self.tracer.agents.items()):
        if agent_id not in self._displayed_agents:
            self._add_agent_node(agent_data)
            self._displayed_agents.add(agent_id)
            agent_updates = True
        elif self._update_agent_node(agent_id, agent_data):
            agent_updates = True

    if agent_updates:
        self._expand_new_agent_nodes()

    self._update_chat_view()
    self._update_agent_status_display()
    self._update_stats_display()
    self._update_vulnerabilities_panel()
|
|
|
|
def _update_agent_node(self, agent_id: str, agent_data: dict[str, Any]) -> bool:
    """Refresh an agent's tree label from its latest tracer record.

    The label is ``"<status icon> <name>"``, followed by ``" (<n>)"`` when
    ``_agent_vulnerability_count`` reports ``n > 0``.

    Args:
        agent_id: Key of the agent's node in ``self.agent_nodes``.
        agent_data: Tracer record; ``"name"`` (default ``"Agent"``) and
            ``"status"`` (default ``"running"``) are read.

    Returns:
        True if the label text changed and was updated. False for unknown
        agents, unchanged labels, and ``KeyError``/``AttributeError``/
        ``ValueError`` failures (which are logged as warnings).
    """
    if agent_id not in self.agent_nodes:
        return False

    try:
        agent_node = self.agent_nodes[agent_id]
        agent_name_raw = agent_data.get("name", "Agent")
        status = agent_data.get("status", "running")

        status_indicators = {
            "running": "⚪",
            "waiting": "⏸",
            "completed": "🟢",
            "failed": "🔴",
            "stopped": "■",
            "stopping": "○",
            "llm_failed": "🔴",
        }

        status_icon = status_indicators.get(status, "○")
        vuln_count = self._agent_vulnerability_count(agent_id)
        vuln_indicator = f" ({vuln_count})" if vuln_count > 0 else ""
        agent_name = f"{status_icon} {agent_name_raw}{vuln_indicator}"

        # Textual keeps node labels as rich ``Text`` (``set_label`` converts
        # strings), and ``Text`` never compares equal to a ``str``. Comparing
        # the objects therefore reset the label and reported a change on every
        # poll, so compare the plain text instead.
        if str(agent_node.label) != agent_name:
            agent_node.set_label(agent_name)
            return True

    except (KeyError, AttributeError, ValueError) as e:
        logger.warning("Failed to update agent node label: %s", e)

    return False
|
|
|
|
def _get_chat_content(
|
|
self,
|
|
) -> tuple[Any, str | None]:
|
|
if not self.selected_agent_id:
|
|
return self._get_chat_placeholder_content(
|
|
"Select an agent from the tree to see its activity.", "placeholder-no-agent"
|
|
)
|
|
|
|
events = self._gather_agent_events(self.selected_agent_id)
|
|
streaming = self.tracer.get_streaming_content(self.selected_agent_id)
|
|
|
|
if not events and not streaming:
|
|
return self._get_chat_placeholder_content(
|
|
"Starting agent...", "placeholder-no-activity"
|
|
)
|
|
|
|
current_event_ids = [e["id"] for e in events]
|
|
current_streaming_len = len(streaming) if streaming else 0
|
|
last_streaming_len = self._last_streaming_len.get(self.selected_agent_id, 0)
|
|
|
|
if (
|
|
current_event_ids == self._displayed_events
|
|
and current_streaming_len == last_streaming_len
|
|
):
|
|
return None, None
|
|
|
|
self._displayed_events = current_event_ids
|
|
self._last_streaming_len[self.selected_agent_id] = current_streaming_len
|
|
return self._get_rendered_events_content(events), "chat-content"
|
|
|
|
def _update_chat_view(self) -> None:
    """Re-render the chat transcript for the selected agent when it changed.

    Skipped while a modal or the splash screen is showing, before the app is
    mounted, or when the chat widgets are missing or unsafe. If the history
    was scrolled to the bottom before the update, it is scrolled to the end
    again afterwards.
    """
    if len(self.screen_stack) > 1 or self.show_splash or not self.is_mounted:
        return

    # Look up both widgets before ``_get_chat_content`` runs. That call
    # records what it returned as "displayed", so a lookup failure after it
    # used to crash this tick and leave the cache believing stale content
    # was on screen.
    try:
        chat_history = self.query_one("#chat_history", VerticalScroll)
        chat_display = self.query_one("#chat_display", Static)
    except Exception:  # was ``(ValueError, Exception)``; ValueError is already an Exception
        return

    if not self._is_widget_safe(chat_history):
        return

    try:
        is_at_bottom = chat_history.scroll_y >= chat_history.max_scroll_y
    except (AttributeError, ValueError):
        is_at_bottom = True

    content, css_class = self._get_chat_content()
    if content is None:
        return

    self._safe_widget_operation(chat_display.update, content)
    chat_display.set_classes(css_class)

    if is_at_bottom:
        self.call_later(chat_history.scroll_end, animate=False)
|
|
|
|
def _get_chat_placeholder_content(
|
|
self, message: str, placeholder_class: str
|
|
) -> tuple[Text, str]:
|
|
self._displayed_events = [placeholder_class]
|
|
text = Text()
|
|
text.append(message)
|
|
return text, f"chat-placeholder {placeholder_class}"
|
|
|
|
@staticmethod
|
|
def _merge_renderables(renderables: list[Any]) -> Text:
|
|
"""Merge renderables into a single Text for mouse text selection support."""
|
|
combined = Text()
|
|
for i, item in enumerate(renderables):
|
|
if i > 0:
|
|
combined.append("\n")
|
|
StrixTUIApp._append_renderable(combined, item)
|
|
return StrixTUIApp._sanitize_text(combined)
|
|
|
|
@staticmethod
|
|
def _sanitize_text(text: Text) -> Text:
|
|
"""Clamp spans so Rich/Textual can't crash on malformed offsets."""
|
|
plain = text.plain
|
|
text_length = len(plain)
|
|
sanitized_spans: list[Span] = []
|
|
|
|
for span in text.spans:
|
|
start = max(0, min(span.start, text_length))
|
|
end = max(0, min(span.end, text_length))
|
|
if end > start:
|
|
sanitized_spans.append(Span(start, end, span.style))
|
|
|
|
return Text(
|
|
plain,
|
|
style=text.style,
|
|
justify=text.justify,
|
|
overflow=text.overflow,
|
|
no_wrap=text.no_wrap,
|
|
end=text.end,
|
|
tab_size=text.tab_size,
|
|
spans=sanitized_spans,
|
|
)
|
|
|
|
@staticmethod
|
|
def _append_renderable(combined: Text, item: Any) -> None:
|
|
"""Recursively append a renderable's text content to a combined Text."""
|
|
if isinstance(item, Text):
|
|
combined.append_text(StrixTUIApp._sanitize_text(item))
|
|
elif isinstance(item, Group):
|
|
for j, sub in enumerate(item.renderables):
|
|
if j > 0:
|
|
combined.append("\n")
|
|
StrixTUIApp._append_renderable(combined, sub)
|
|
else:
|
|
inner = getattr(item, "content", None) or getattr(item, "renderable", None)
|
|
if inner is not None:
|
|
StrixTUIApp._append_renderable(combined, inner)
|
|
else:
|
|
combined.append(str(item))
|
|
|
|
def _get_rendered_events_content(self, events: list[dict[str, Any]]) -> Any:
|
|
renderables: list[Any] = []
|
|
|
|
if not events:
|
|
return Text()
|
|
|
|
for event in events:
|
|
content: Any = None
|
|
|
|
if event["type"] == "chat":
|
|
content = self._render_chat_content(event["data"])
|
|
elif event["type"] == "tool":
|
|
content = self._render_tool_content_simple(event["data"])
|
|
|
|
if content:
|
|
if renderables:
|
|
renderables.append(Text(""))
|
|
renderables.append(content)
|
|
|
|
if self.selected_agent_id:
|
|
streaming = self.tracer.get_streaming_content(self.selected_agent_id)
|
|
if streaming:
|
|
streaming_text = self._render_streaming_content(streaming)
|
|
if streaming_text:
|
|
if renderables:
|
|
renderables.append(Text(""))
|
|
renderables.append(streaming_text)
|
|
|
|
if not renderables:
|
|
return Text()
|
|
|
|
if len(renderables) == 1 and isinstance(renderables[0], Text):
|
|
return self._sanitize_text(renderables[0])
|
|
|
|
return self._merge_renderables(renderables)
|
|
|
|
def _render_streaming_content(self, content: str, agent_id: str | None = None) -> Any:
|
|
cache_key = agent_id or self.selected_agent_id or ""
|
|
content_len = len(content)
|
|
|
|
if cache_key in self._streaming_render_cache:
|
|
cached_len, cached_output = self._streaming_render_cache[cache_key]
|
|
if cached_len == content_len:
|
|
return cached_output
|
|
|
|
renderables: list[Any] = []
|
|
segments = parse_streaming_content(content)
|
|
|
|
for segment in segments:
|
|
if segment.type == "text":
|
|
text_content = AgentMessageRenderer.render_simple(segment.content)
|
|
if renderables:
|
|
renderables.append(Text(""))
|
|
renderables.append(text_content)
|
|
|
|
elif segment.type == "tool":
|
|
tool_renderable = self._render_streaming_tool(
|
|
segment.tool_name or "unknown",
|
|
segment.args or {},
|
|
segment.is_complete,
|
|
)
|
|
if renderables:
|
|
renderables.append(Text(""))
|
|
renderables.append(tool_renderable)
|
|
|
|
if not renderables:
|
|
result = Text()
|
|
elif len(renderables) == 1 and isinstance(renderables[0], Text):
|
|
result = self._sanitize_text(renderables[0])
|
|
else:
|
|
result = self._merge_renderables(renderables)
|
|
|
|
self._streaming_render_cache[cache_key] = (content_len, result)
|
|
return result
|
|
|
|
def _render_streaming_tool(
|
|
self, tool_name: str, args: dict[str, str], is_complete: bool
|
|
) -> Any:
|
|
tool_data = {
|
|
"tool_name": tool_name,
|
|
"args": args,
|
|
"status": "completed" if is_complete else "running",
|
|
"result": None,
|
|
}
|
|
|
|
renderer = get_tool_renderer(tool_name)
|
|
if renderer:
|
|
widget = renderer.render(tool_data)
|
|
return widget.content
|
|
|
|
return self._render_default_streaming_tool(tool_name, args, is_complete)
|
|
|
|
def _render_default_streaming_tool(
|
|
self, tool_name: str, args: dict[str, str], is_complete: bool
|
|
) -> Text:
|
|
text = Text()
|
|
|
|
if is_complete:
|
|
text.append("✓ ", style="green")
|
|
else:
|
|
text.append("● ", style="yellow")
|
|
|
|
text.append("Using tool ", style="dim")
|
|
text.append(tool_name, style="bold blue")
|
|
|
|
if args:
|
|
for key, value in list(args.items())[:3]:
|
|
text.append("\n ")
|
|
text.append(key, style="dim")
|
|
text.append(": ")
|
|
display_value = value if len(value) <= 100 else value[:97] + "..."
|
|
text.append(display_value, style="italic" if not is_complete else None)
|
|
|
|
return text
|
|
|
|
def _get_status_display_content(
|
|
self, agent_id: str, agent_data: dict[str, Any]
|
|
) -> tuple[Text | None, Text, bool]:
|
|
status = agent_data.get("status", "running")
|
|
|
|
def keymap_styled(keys: list[tuple[str, str]]) -> Text:
|
|
t = Text()
|
|
for i, (key, action) in enumerate(keys):
|
|
if i > 0:
|
|
t.append(" · ", style="dim")
|
|
t.append(key, style="white")
|
|
t.append(" ", style="dim")
|
|
t.append(action, style="dim")
|
|
return t
|
|
|
|
simple_statuses: dict[str, tuple[str, str]] = {
|
|
"stopping": ("Agent stopping...", ""),
|
|
"stopped": ("Agent stopped", ""),
|
|
"completed": ("Agent completed", ""),
|
|
}
|
|
|
|
if status in simple_statuses:
|
|
msg, _ = simple_statuses[status]
|
|
text = Text()
|
|
text.append(msg)
|
|
return (text, Text(), False)
|
|
|
|
if status == "llm_failed":
|
|
error_msg = agent_data.get("error_message", "")
|
|
text = Text()
|
|
if error_msg:
|
|
text.append(error_msg, style="red")
|
|
else:
|
|
text.append("LLM request failed", style="red")
|
|
self._stop_dot_animation()
|
|
keymap = Text()
|
|
keymap.append("Send message to retry", style="dim")
|
|
return (text, keymap, False)
|
|
|
|
if status == "waiting":
|
|
keymap = Text()
|
|
keymap.append("Send message to resume", style="dim")
|
|
return (Text(" "), keymap, False)
|
|
|
|
if status == "running":
|
|
if self._agent_has_real_activity(agent_id):
|
|
animated_text = Text()
|
|
animated_text.append_text(self._get_sweep_animation(self._sweep_colors))
|
|
animated_text.append("esc", style="white")
|
|
animated_text.append(" ", style="dim")
|
|
animated_text.append("stop", style="dim")
|
|
return (animated_text, keymap_styled([("ctrl-q", "quit")]), True)
|
|
animated_text = self._get_animated_verb_text(agent_id, "Initializing")
|
|
return (animated_text, keymap_styled([("ctrl-q", "quit")]), True)
|
|
|
|
return (None, Text(), False)
|
|
|
|
def _update_agent_status_display(self) -> None:
|
|
try:
|
|
status_display = self.query_one("#agent_status_display", Horizontal)
|
|
status_text = self.query_one("#status_text", Static)
|
|
keymap_indicator = self.query_one("#keymap_indicator", Static)
|
|
except (ValueError, Exception):
|
|
return
|
|
|
|
widgets = [status_display, status_text, keymap_indicator]
|
|
if not all(self._is_widget_safe(w) for w in widgets):
|
|
return
|
|
|
|
if not self.selected_agent_id:
|
|
self._safe_widget_operation(status_display.add_class, "hidden")
|
|
return
|
|
|
|
try:
|
|
agent_data = self.tracer.agents[self.selected_agent_id]
|
|
content, keymap, should_animate = self._get_status_display_content(
|
|
self.selected_agent_id, agent_data
|
|
)
|
|
|
|
if not content:
|
|
self._safe_widget_operation(status_display.add_class, "hidden")
|
|
return
|
|
|
|
self._safe_widget_operation(status_text.update, content)
|
|
self._safe_widget_operation(keymap_indicator.update, keymap)
|
|
self._safe_widget_operation(status_display.remove_class, "hidden")
|
|
|
|
if should_animate:
|
|
self._start_dot_animation()
|
|
|
|
except (KeyError, Exception):
|
|
self._safe_widget_operation(status_display.add_class, "hidden")
|
|
|
|
def _update_stats_display(self) -> None:
|
|
try:
|
|
stats_display = self.query_one("#stats_display", Static)
|
|
except (ValueError, Exception):
|
|
return
|
|
|
|
if not self._is_widget_safe(stats_display):
|
|
return
|
|
|
|
if self.screen.selections:
|
|
return
|
|
|
|
stats_content = Text()
|
|
|
|
stats_text = build_tui_stats_text(self.tracer, self.agent_config)
|
|
if stats_text:
|
|
stats_content.append(stats_text)
|
|
|
|
version = get_package_version()
|
|
stats_content.append(f"\nv{version}", style="white")
|
|
|
|
self._safe_widget_operation(stats_display.update, stats_content)
|
|
|
|
def _update_vulnerabilities_panel(self) -> None:
|
|
"""Update the vulnerabilities panel with current vulnerability data."""
|
|
try:
|
|
vuln_panel = self.query_one("#vulnerabilities_panel", VulnerabilitiesPanel)
|
|
except (ValueError, Exception):
|
|
return
|
|
|
|
if not self._is_widget_safe(vuln_panel):
|
|
return
|
|
|
|
vulnerabilities = self.tracer.vulnerability_reports
|
|
|
|
if not vulnerabilities:
|
|
self._safe_widget_operation(vuln_panel.add_class, "hidden")
|
|
return
|
|
|
|
enriched_vulns = []
|
|
for vuln in vulnerabilities:
|
|
enriched = dict(vuln)
|
|
report_id = vuln.get("id", "")
|
|
agent_name = self._get_agent_name_for_vulnerability(report_id)
|
|
if agent_name:
|
|
enriched["agent_name"] = agent_name
|
|
enriched_vulns.append(enriched)
|
|
|
|
self._safe_widget_operation(vuln_panel.remove_class, "hidden")
|
|
vuln_panel.update_vulnerabilities(enriched_vulns)
|
|
|
|
def _get_agent_name_for_vulnerability(self, report_id: str) -> str | None:
|
|
"""Find the agent name that created a vulnerability report."""
|
|
for _exec_id, tool_data in list(self.tracer.tool_executions.items()):
|
|
if tool_data.get("tool_name") == "create_vulnerability_report":
|
|
result = tool_data.get("result", {})
|
|
if isinstance(result, dict) and result.get("report_id") == report_id:
|
|
agent_id = tool_data.get("agent_id")
|
|
if agent_id and agent_id in self.tracer.agents:
|
|
name: str = self.tracer.agents[agent_id].get("name", "Unknown Agent")
|
|
return name
|
|
return None
|
|
|
|
def _get_sweep_animation(self, color_palette: list[str]) -> Text:
|
|
text = Text()
|
|
num_squares = self._sweep_num_squares
|
|
num_colors = len(color_palette)
|
|
|
|
offset = num_colors - 1
|
|
max_pos = (num_squares - 1) + offset
|
|
total_range = max_pos + offset
|
|
cycle_length = total_range * 2
|
|
frame_in_cycle = self._spinner_frame_index % cycle_length
|
|
|
|
wave_pos = total_range - abs(total_range - frame_in_cycle)
|
|
sweep_pos = wave_pos - offset
|
|
|
|
dot_color = "#0a3d1f"
|
|
|
|
for i in range(num_squares):
|
|
dist = abs(i - sweep_pos)
|
|
color_idx = max(0, num_colors - 1 - dist)
|
|
|
|
if color_idx == 0:
|
|
text.append("·", style=Style(color=dot_color))
|
|
else:
|
|
color = color_palette[color_idx]
|
|
text.append("▪", style=Style(color=color))
|
|
|
|
text.append(" ")
|
|
return text
|
|
|
|
def _get_animated_verb_text(self, agent_id: str, verb: str) -> Text: # noqa: ARG002
|
|
text = Text()
|
|
sweep = self._get_sweep_animation(self._sweep_colors)
|
|
text.append_text(sweep)
|
|
parts = verb.split(" ", 1)
|
|
text.append(parts[0], style="white")
|
|
if len(parts) > 1:
|
|
text.append(" ", style="dim")
|
|
text.append(parts[1], style="dim")
|
|
return text
|
|
|
|
def _start_dot_animation(self) -> None:
|
|
if self._dot_animation_timer is None:
|
|
self._dot_animation_timer = self.set_interval(0.06, self._animate_dots)
|
|
|
|
def _stop_dot_animation(self) -> None:
|
|
if self._dot_animation_timer is not None:
|
|
self._dot_animation_timer.stop()
|
|
self._dot_animation_timer = None
|
|
|
|
def _animate_dots(self) -> None:
|
|
has_active_agents = False
|
|
|
|
if self.selected_agent_id and self.selected_agent_id in self.tracer.agents:
|
|
agent_data = self.tracer.agents[self.selected_agent_id]
|
|
status = agent_data.get("status", "running")
|
|
if status in ["running", "waiting"]:
|
|
has_active_agents = True
|
|
num_colors = len(self._sweep_colors)
|
|
offset = num_colors - 1
|
|
max_pos = (self._sweep_num_squares - 1) + offset
|
|
total_range = max_pos + offset
|
|
cycle_length = total_range * 2
|
|
self._spinner_frame_index = (self._spinner_frame_index + 1) % cycle_length
|
|
self._update_agent_status_display()
|
|
|
|
if not has_active_agents:
|
|
has_active_agents = any(
|
|
agent_data.get("status", "running") in ["running", "waiting"]
|
|
for agent_data in self.tracer.agents.values()
|
|
)
|
|
|
|
if not has_active_agents:
|
|
self._stop_dot_animation()
|
|
self._spinner_frame_index = 0
|
|
|
|
def _agent_has_real_activity(self, agent_id: str) -> bool:
|
|
initial_tools = {"scan_start_info", "subagent_start_info"}
|
|
|
|
for _exec_id, tool_data in list(self.tracer.tool_executions.items()):
|
|
if tool_data.get("agent_id") == agent_id:
|
|
tool_name = tool_data.get("tool_name", "")
|
|
if tool_name not in initial_tools:
|
|
return True
|
|
|
|
streaming = self.tracer.get_streaming_content(agent_id)
|
|
return bool(streaming and streaming.strip())
|
|
|
|
def _agent_vulnerability_count(self, agent_id: str) -> int:
|
|
count = 0
|
|
for _exec_id, tool_data in list(self.tracer.tool_executions.items()):
|
|
if tool_data.get("agent_id") == agent_id:
|
|
tool_name = tool_data.get("tool_name", "")
|
|
if tool_name == "create_vulnerability_report":
|
|
status = tool_data.get("status", "")
|
|
if status == "completed":
|
|
result = tool_data.get("result", {})
|
|
if isinstance(result, dict) and result.get("success"):
|
|
count += 1
|
|
return count
|
|
|
|
def _gather_agent_events(self, agent_id: str) -> list[dict[str, Any]]:
|
|
chat_events = [
|
|
{
|
|
"type": "chat",
|
|
"timestamp": msg["timestamp"],
|
|
"id": f"chat_{msg['message_id']}",
|
|
"data": msg,
|
|
}
|
|
for msg in self.tracer.chat_messages
|
|
if msg.get("agent_id") == agent_id
|
|
]
|
|
|
|
tool_events = [
|
|
{
|
|
"type": "tool",
|
|
"timestamp": tool_data["timestamp"],
|
|
"id": f"tool_{exec_id}",
|
|
"data": tool_data,
|
|
}
|
|
for exec_id, tool_data in list(self.tracer.tool_executions.items())
|
|
if tool_data.get("agent_id") == agent_id
|
|
]
|
|
|
|
events = chat_events + tool_events
|
|
events.sort(key=lambda e: (e["timestamp"], e["id"]))
|
|
return events
|
|
|
|
def watch_selected_agent_id(self, _agent_id: str | None) -> None:
|
|
if len(self.screen_stack) > 1 or self.show_splash:
|
|
return
|
|
|
|
if not self.is_mounted:
|
|
return
|
|
|
|
self._displayed_events.clear()
|
|
self._streaming_render_cache.clear()
|
|
self._last_streaming_len.clear()
|
|
|
|
self.call_later(self._update_chat_view)
|
|
self._update_agent_status_display()
|
|
|
|
def _start_scan_thread(self) -> None:
|
|
def scan_target() -> None:
|
|
try:
|
|
loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(loop)
|
|
|
|
try:
|
|
agent = StrixAgent(self.agent_config)
|
|
|
|
if not self._scan_stop_event.is_set():
|
|
loop.run_until_complete(agent.execute_scan(self.scan_config))
|
|
|
|
except (KeyboardInterrupt, asyncio.CancelledError):
|
|
logging.info("Scan interrupted by user")
|
|
except (ConnectionError, TimeoutError):
|
|
logging.exception("Network error during scan")
|
|
except RuntimeError:
|
|
logging.exception("Runtime error during scan")
|
|
except Exception:
|
|
logging.exception("Unexpected error during scan")
|
|
finally:
|
|
loop.close()
|
|
self._scan_completed.set()
|
|
|
|
except Exception:
|
|
logging.exception("Error setting up scan thread")
|
|
self._scan_completed.set()
|
|
|
|
self._scan_thread = threading.Thread(target=scan_target, daemon=True)
|
|
self._scan_thread.start()
|
|
|
|
def _add_agent_node(self, agent_data: dict[str, Any]) -> None:
|
|
if len(self.screen_stack) > 1 or self.show_splash:
|
|
return
|
|
|
|
if not self.is_mounted:
|
|
return
|
|
|
|
agent_id = agent_data["id"]
|
|
parent_id = agent_data.get("parent_id")
|
|
status = agent_data.get("status", "running")
|
|
|
|
try:
|
|
agents_tree = self.query_one("#agents_tree", Tree)
|
|
except (ValueError, Exception):
|
|
return
|
|
|
|
agent_name_raw = agent_data.get("name", "Agent")
|
|
|
|
status_indicators = {
|
|
"running": "⚪",
|
|
"waiting": "⏸",
|
|
"completed": "🟢",
|
|
"failed": "🔴",
|
|
"stopped": "■",
|
|
"stopping": "○",
|
|
"llm_failed": "🔴",
|
|
}
|
|
|
|
status_icon = status_indicators.get(status, "○")
|
|
vuln_count = self._agent_vulnerability_count(agent_id)
|
|
vuln_indicator = f" ({vuln_count})" if vuln_count > 0 else ""
|
|
agent_name = f"{status_icon} {agent_name_raw}{vuln_indicator}"
|
|
|
|
try:
|
|
if parent_id and parent_id in self.agent_nodes:
|
|
parent_node = self.agent_nodes[parent_id]
|
|
agent_node = parent_node.add(
|
|
agent_name,
|
|
data={"agent_id": agent_id},
|
|
)
|
|
parent_node.allow_expand = True
|
|
else:
|
|
agent_node = agents_tree.root.add(
|
|
agent_name,
|
|
data={"agent_id": agent_id},
|
|
)
|
|
|
|
agent_node.allow_expand = False
|
|
agent_node.expand()
|
|
self.agent_nodes[agent_id] = agent_node
|
|
|
|
if len(self.agent_nodes) == 1:
|
|
agents_tree.select_node(agent_node)
|
|
self.selected_agent_id = agent_id
|
|
|
|
self._reorganize_orphaned_agents(agent_id)
|
|
except (AttributeError, ValueError, RuntimeError) as e:
|
|
import logging
|
|
|
|
logging.warning(f"Failed to add agent node {agent_id}: {e}")
|
|
|
|
def _expand_new_agent_nodes(self) -> None:
|
|
if len(self.screen_stack) > 1 or self.show_splash:
|
|
return
|
|
|
|
if not self.is_mounted:
|
|
return
|
|
|
|
def _expand_all_agent_nodes(self) -> None:
|
|
if len(self.screen_stack) > 1 or self.show_splash:
|
|
return
|
|
|
|
if not self.is_mounted:
|
|
return
|
|
|
|
try:
|
|
agents_tree = self.query_one("#agents_tree", Tree)
|
|
self._expand_node_recursively(agents_tree.root)
|
|
except (ValueError, Exception):
|
|
logging.debug("Tree not ready for expanding nodes")
|
|
|
|
def _expand_node_recursively(self, node: TreeNode) -> None:
|
|
if not node.is_expanded:
|
|
node.expand()
|
|
for child in node.children:
|
|
self._expand_node_recursively(child)
|
|
|
|
def _copy_node_under(self, node_to_copy: TreeNode, new_parent: TreeNode) -> None:
|
|
agent_id = node_to_copy.data["agent_id"]
|
|
agent_data = self.tracer.agents.get(agent_id, {})
|
|
agent_name_raw = agent_data.get("name", "Agent")
|
|
status = agent_data.get("status", "running")
|
|
|
|
status_indicators = {
|
|
"running": "⚪",
|
|
"waiting": "⏸",
|
|
"completed": "🟢",
|
|
"failed": "🔴",
|
|
"stopped": "■",
|
|
"stopping": "○",
|
|
"llm_failed": "🔴",
|
|
}
|
|
|
|
status_icon = status_indicators.get(status, "○")
|
|
vuln_count = self._agent_vulnerability_count(agent_id)
|
|
vuln_indicator = f" ({vuln_count})" if vuln_count > 0 else ""
|
|
agent_name = f"{status_icon} {agent_name_raw}{vuln_indicator}"
|
|
|
|
new_node = new_parent.add(
|
|
agent_name,
|
|
data=node_to_copy.data,
|
|
)
|
|
new_node.allow_expand = node_to_copy.allow_expand
|
|
|
|
self.agent_nodes[agent_id] = new_node
|
|
|
|
for child in node_to_copy.children:
|
|
self._copy_node_under(child, new_node)
|
|
|
|
if node_to_copy.is_expanded:
|
|
new_node.expand()
|
|
|
|
def _reorganize_orphaned_agents(self, new_parent_id: str) -> None:
|
|
agents_to_move = []
|
|
|
|
for agent_id, agent_data in list(self.tracer.agents.items()):
|
|
if (
|
|
agent_data.get("parent_id") == new_parent_id
|
|
and agent_id in self.agent_nodes
|
|
and agent_id != new_parent_id
|
|
):
|
|
agents_to_move.append(agent_id)
|
|
|
|
if not agents_to_move:
|
|
return
|
|
|
|
parent_node = self.agent_nodes[new_parent_id]
|
|
|
|
for child_agent_id in agents_to_move:
|
|
if child_agent_id in self.agent_nodes:
|
|
old_node = self.agent_nodes[child_agent_id]
|
|
|
|
if old_node.parent is parent_node:
|
|
continue
|
|
|
|
self._copy_node_under(old_node, parent_node)
|
|
|
|
old_node.remove()
|
|
|
|
parent_node.allow_expand = True
|
|
parent_node.expand()
|
|
|
|
def _render_chat_content(self, msg_data: dict[str, Any]) -> Any:
|
|
role = msg_data.get("role")
|
|
content = msg_data.get("content", "")
|
|
metadata = msg_data.get("metadata", {})
|
|
|
|
if not content:
|
|
return None
|
|
|
|
if role == "user":
|
|
return UserMessageRenderer.render_simple(content)
|
|
|
|
if metadata.get("interrupted"):
|
|
streaming_result = self._render_streaming_content(content)
|
|
interrupted_text = Text()
|
|
interrupted_text.append("\n")
|
|
interrupted_text.append("⚠ ", style="yellow")
|
|
interrupted_text.append("Interrupted by user", style="yellow dim")
|
|
return self._merge_renderables([streaming_result, interrupted_text])
|
|
|
|
return AgentMessageRenderer.render_simple(content)
|
|
|
|
def _render_tool_content_simple(self, tool_data: dict[str, Any]) -> Any:
|
|
tool_name = tool_data.get("tool_name", "Unknown Tool")
|
|
args = tool_data.get("args", {})
|
|
status = tool_data.get("status", "unknown")
|
|
result = tool_data.get("result")
|
|
|
|
renderer = get_tool_renderer(tool_name)
|
|
|
|
if renderer:
|
|
widget = renderer.render(tool_data)
|
|
return widget.content
|
|
|
|
text = Text()
|
|
|
|
if tool_name in ("llm_error_details", "sandbox_error_details"):
|
|
return self._render_error_details(text, tool_name, args)
|
|
|
|
text.append("→ Using tool ")
|
|
text.append(tool_name, style="bold blue")
|
|
|
|
status_styles = {
|
|
"running": ("●", "yellow"),
|
|
"completed": ("✓", "green"),
|
|
"failed": ("✗", "red"),
|
|
"error": ("✗", "red"),
|
|
}
|
|
icon, style = status_styles.get(status, ("○", "dim"))
|
|
text.append(" ")
|
|
text.append(icon, style=style)
|
|
|
|
if args:
|
|
for k, v in list(args.items())[:5]:
|
|
str_v = str(v)
|
|
if len(str_v) > 500:
|
|
str_v = str_v[:497] + "..."
|
|
text.append("\n ")
|
|
text.append(k, style="dim")
|
|
text.append(": ")
|
|
text.append(str_v)
|
|
|
|
if status in ["completed", "failed", "error"] and result:
|
|
result_str = str(result)
|
|
if len(result_str) > 1000:
|
|
result_str = result_str[:997] + "..."
|
|
text.append("\n")
|
|
text.append("Result: ", style="bold")
|
|
text.append(result_str)
|
|
|
|
return text
|
|
|
|
def _render_error_details(self, text: Any, tool_name: str, args: dict[str, Any]) -> Any:
|
|
if tool_name == "llm_error_details":
|
|
text.append("✗ LLM Request Failed", style="red")
|
|
else:
|
|
text.append("✗ Sandbox Initialization Failed", style="red")
|
|
if args.get("error"):
|
|
text.append(f"\n{args['error']}", style="bold red")
|
|
if args.get("details"):
|
|
details = str(args["details"])
|
|
if len(details) > 1000:
|
|
details = details[:997] + "..."
|
|
text.append("\nDetails: ", style="dim")
|
|
text.append(details)
|
|
return text
|
|
|
|
@on(Tree.NodeHighlighted) # type: ignore[misc]
|
|
def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None:
|
|
if len(self.screen_stack) > 1 or self.show_splash:
|
|
return
|
|
|
|
if not self.is_mounted:
|
|
return
|
|
|
|
node = event.node
|
|
|
|
try:
|
|
agents_tree = self.query_one("#agents_tree", Tree)
|
|
except (ValueError, Exception):
|
|
return
|
|
|
|
if self.focused == agents_tree and node.data:
|
|
agent_id = node.data.get("agent_id")
|
|
if agent_id:
|
|
self.selected_agent_id = agent_id
|
|
|
|
@on(Tree.NodeSelected) # type: ignore[misc]
|
|
def handle_tree_node_selected(self, event: Tree.NodeSelected) -> None:
|
|
if len(self.screen_stack) > 1 or self.show_splash:
|
|
return
|
|
|
|
if not self.is_mounted:
|
|
return
|
|
|
|
node = event.node
|
|
|
|
if node.allow_expand:
|
|
if node.is_expanded:
|
|
node.collapse()
|
|
else:
|
|
node.expand()
|
|
|
|
def _send_user_message(self, message: str) -> None:
|
|
if not self.selected_agent_id:
|
|
return
|
|
|
|
if self.tracer:
|
|
streaming_content = self.tracer.get_streaming_content(self.selected_agent_id)
|
|
if streaming_content and streaming_content.strip():
|
|
self.tracer.clear_streaming_content(self.selected_agent_id)
|
|
self.tracer.interrupted_content[self.selected_agent_id] = streaming_content
|
|
self.tracer.log_chat_message(
|
|
content=streaming_content,
|
|
role="assistant",
|
|
agent_id=self.selected_agent_id,
|
|
metadata={"interrupted": True},
|
|
)
|
|
|
|
try:
|
|
from strix.tools.agents_graph.agents_graph_actions import _agent_instances
|
|
|
|
if self.selected_agent_id in _agent_instances:
|
|
agent_instance = _agent_instances[self.selected_agent_id]
|
|
if hasattr(agent_instance, "cancel_current_execution"):
|
|
agent_instance.cancel_current_execution()
|
|
except (ImportError, AttributeError, KeyError):
|
|
pass
|
|
|
|
if self.tracer:
|
|
self.tracer.log_chat_message(
|
|
content=message,
|
|
role="user",
|
|
agent_id=self.selected_agent_id,
|
|
)
|
|
|
|
try:
|
|
from strix.tools.agents_graph.agents_graph_actions import send_user_message_to_agent
|
|
|
|
send_user_message_to_agent(self.selected_agent_id, message)
|
|
|
|
except (ImportError, AttributeError) as e:
|
|
import logging
|
|
|
|
logging.warning(f"Failed to send message to agent {self.selected_agent_id}: {e}")
|
|
|
|
self._displayed_events.clear()
|
|
self._update_chat_view()
|
|
|
|
self.call_after_refresh(self._focus_chat_input)
|
|
|
|
def _get_agent_name(self, agent_id: str) -> str:
|
|
try:
|
|
if self.tracer and agent_id in self.tracer.agents:
|
|
agent_name = self.tracer.agents[agent_id].get("name")
|
|
if isinstance(agent_name, str):
|
|
return agent_name
|
|
except (KeyError, AttributeError) as e:
|
|
logging.warning(f"Could not retrieve agent name for {agent_id}: {e}")
|
|
return "Unknown Agent"
|
|
|
|
def action_toggle_help(self) -> None:
|
|
if self.show_splash or not self.is_mounted:
|
|
return
|
|
|
|
try:
|
|
self.query_one("#main_container")
|
|
except (ValueError, Exception):
|
|
return
|
|
|
|
if isinstance(self.screen, HelpScreen):
|
|
self.pop_screen()
|
|
return
|
|
|
|
if len(self.screen_stack) > 1:
|
|
return
|
|
|
|
self.push_screen(HelpScreen())
|
|
|
|
def action_request_quit(self) -> None:
|
|
if self.show_splash or not self.is_mounted:
|
|
self.action_custom_quit()
|
|
return
|
|
|
|
if len(self.screen_stack) > 1:
|
|
return
|
|
|
|
try:
|
|
self.query_one("#main_container")
|
|
except (ValueError, Exception):
|
|
self.action_custom_quit()
|
|
return
|
|
|
|
self.push_screen(QuitScreen())
|
|
|
|
def action_stop_selected_agent(self) -> None:
|
|
if self.show_splash or not self.is_mounted:
|
|
return
|
|
|
|
if len(self.screen_stack) > 1:
|
|
self.pop_screen()
|
|
return
|
|
|
|
if not self.selected_agent_id:
|
|
return
|
|
|
|
agent_name, should_stop = self._validate_agent_for_stopping()
|
|
if not should_stop:
|
|
return
|
|
|
|
try:
|
|
self.query_one("#main_container")
|
|
except (ValueError, Exception):
|
|
return
|
|
|
|
self.push_screen(StopAgentScreen(agent_name, self.selected_agent_id))
|
|
|
|
def _validate_agent_for_stopping(self) -> tuple[str, bool]:
|
|
agent_name = "Unknown Agent"
|
|
|
|
try:
|
|
if self.tracer and self.selected_agent_id in self.tracer.agents:
|
|
agent_data = self.tracer.agents[self.selected_agent_id]
|
|
agent_name = agent_data.get("name", "Unknown Agent")
|
|
|
|
agent_status = agent_data.get("status", "running")
|
|
if agent_status not in ["running"]:
|
|
return agent_name, False
|
|
|
|
agent_events = self._gather_agent_events(self.selected_agent_id)
|
|
if not agent_events:
|
|
return agent_name, False
|
|
|
|
return agent_name, True
|
|
|
|
except (KeyError, AttributeError, ValueError) as e:
|
|
import logging
|
|
|
|
logging.warning(f"Failed to gather agent events: {e}")
|
|
|
|
return agent_name, False
|
|
|
|
def action_confirm_stop_agent(self, agent_id: str) -> None:
|
|
try:
|
|
from strix.tools.agents_graph.agents_graph_actions import stop_agent
|
|
|
|
result = stop_agent(agent_id)
|
|
|
|
import logging
|
|
|
|
if result.get("success"):
|
|
logging.info(f"Stop request sent to agent: {result.get('message', 'Unknown')}")
|
|
else:
|
|
logging.warning(f"Failed to stop agent: {result.get('error', 'Unknown error')}")
|
|
|
|
except Exception:
|
|
import logging
|
|
|
|
logging.exception(f"Failed to stop agent {agent_id}")
|
|
|
|
def action_custom_quit(self) -> None:
|
|
if self._scan_thread and self._scan_thread.is_alive():
|
|
self._scan_stop_event.set()
|
|
|
|
self._scan_thread.join(timeout=1.0)
|
|
|
|
self.tracer.cleanup()
|
|
|
|
self.exit()
|
|
|
|
def _is_widget_safe(self, widget: Any) -> bool:
|
|
try:
|
|
_ = widget.screen
|
|
except (AttributeError, ValueError, Exception):
|
|
return False
|
|
else:
|
|
return bool(widget.is_mounted)
|
|
|
|
def _safe_widget_operation(
|
|
self, operation: Callable[..., Any], *args: Any, **kwargs: Any
|
|
) -> bool:
|
|
try:
|
|
operation(*args, **kwargs)
|
|
except (AttributeError, ValueError, Exception):
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
def on_resize(self, event: events.Resize) -> None:
|
|
if self.show_splash or not self.is_mounted:
|
|
return
|
|
|
|
try:
|
|
sidebar = self.query_one("#sidebar", Vertical)
|
|
chat_area = self.query_one("#chat_area_container", Vertical)
|
|
except (ValueError, Exception):
|
|
return
|
|
|
|
if event.size.width < self.SIDEBAR_MIN_WIDTH:
|
|
sidebar.add_class("-hidden")
|
|
chat_area.add_class("-full-width")
|
|
else:
|
|
sidebar.remove_class("-hidden")
|
|
chat_area.remove_class("-full-width")
|
|
|
|
def on_mouse_up(self, _event: events.MouseUp) -> None:
|
|
self.set_timer(0.05, self._auto_copy_selection)
|
|
|
|
_ICON_PREFIXES: ClassVar[tuple[str, ...]] = (
|
|
"🐞 ",
|
|
"🌐 ",
|
|
"📋 ",
|
|
"🧠 ",
|
|
"◆ ",
|
|
"◇ ",
|
|
"◈ ",
|
|
"→ ",
|
|
"○ ",
|
|
"● ",
|
|
"✓ ",
|
|
"✗ ",
|
|
"⚠ ",
|
|
"▍ ",
|
|
"▍",
|
|
"┃ ",
|
|
"• ",
|
|
">_ ",
|
|
"</> ",
|
|
"<~> ",
|
|
"[ ] ",
|
|
"[~] ",
|
|
"[•] ",
|
|
)
|
|
|
|
_DECORATIVE_LINES: ClassVar[frozenset[str]] = frozenset(
|
|
{
|
|
"● In progress...",
|
|
"✓ Done",
|
|
"✗ Failed",
|
|
"✗ Error",
|
|
"○ Unknown",
|
|
}
|
|
)
|
|
|
|
@staticmethod
|
|
def _clean_copied_text(text: str) -> str:
|
|
lines = text.split("\n")
|
|
cleaned: list[str] = []
|
|
for line in lines:
|
|
stripped = line.lstrip()
|
|
if stripped in StrixTUIApp._DECORATIVE_LINES:
|
|
continue
|
|
if stripped and all(c == "─" for c in stripped):
|
|
continue
|
|
out = line
|
|
for prefix in StrixTUIApp._ICON_PREFIXES:
|
|
if stripped.startswith(prefix):
|
|
leading = line[: len(line) - len(line.lstrip())]
|
|
out = leading + stripped[len(prefix) :]
|
|
break
|
|
cleaned.append(out)
|
|
return "\n".join(cleaned)
|
|
|
|
def _auto_copy_selection(self) -> None:
|
|
copied = False
|
|
|
|
try:
|
|
if self.screen.selections:
|
|
selected = self.screen.get_selected_text()
|
|
self.screen.clear_selection()
|
|
if selected and selected.strip():
|
|
cleaned = self._clean_copied_text(selected)
|
|
self.copy_to_clipboard(cleaned if cleaned.strip() else selected)
|
|
copied = True
|
|
except Exception: # noqa: BLE001
|
|
logger.debug("Failed to copy screen selection", exc_info=True)
|
|
|
|
if not copied:
|
|
try:
|
|
chat_input = self.query_one("#chat_input", ChatTextArea)
|
|
selected = chat_input.selected_text
|
|
if selected and selected.strip():
|
|
self.copy_to_clipboard(selected)
|
|
chat_input.move_cursor(chat_input.cursor_location)
|
|
copied = True
|
|
except Exception: # noqa: BLE001
|
|
logger.debug("Failed to copy chat input selection", exc_info=True)
|
|
|
|
if copied:
|
|
self.notify("Copied to clipboard", timeout=2)
|
|
|
|
|
|
async def run_tui(args: argparse.Namespace) -> None:
    """Run strix in interactive TUI mode with textual.

    Builds a ``StrixTUIApp`` from the parsed command-line *args* and awaits
    it until the application exits.
    """
    await StrixTUIApp(args).run_async()
|