From c0e547928e90183fb6fa9154660bbcd5ed083950 Mon Sep 17 00:00:00 2001 From: Alexander De Battista Kvamme Date: Tue, 25 Nov 2025 13:06:20 +0100 Subject: [PATCH] Real-time display panel for agent stats (#134) Co-authored-by: Ahmed Allam --- strix/interface/assets/tui_styles.tcss | 20 +++- strix/interface/cli.py | 81 +++++++++++++--- strix/interface/main.py | 9 +- strix/interface/tui.py | 36 ++++++- strix/interface/utils.py | 129 +++++++++++++++++++------ 5 files changed, 226 insertions(+), 49 deletions(-) diff --git a/strix/interface/assets/tui_styles.tcss b/strix/interface/assets/tui_styles.tcss index 2707ec5..c9424f4 100644 --- a/strix/interface/assets/tui_styles.tcss +++ b/strix/interface/assets/tui_styles.tcss @@ -33,18 +33,32 @@ Screen { background: transparent; } +#sidebar { + width: 25%; + background: transparent; + margin-left: 1; +} + #agents_tree { - width: 20%; + height: 1fr; background: transparent; border: round #262626; border-title-color: #a8a29e; border-title-style: bold; - margin-left: 1; padding: 1; + margin-bottom: 0; +} + +#stats_display { + height: auto; + max-height: 15; + background: transparent; + padding: 0; + margin: 0; } #chat_area_container { - width: 80%; + width: 75%; background: transparent; } diff --git a/strix/interface/cli.py b/strix/interface/cli.py index 93952a6..626cbde 100644 --- a/strix/interface/cli.py +++ b/strix/interface/cli.py @@ -1,9 +1,12 @@ import atexit import signal import sys +import threading +import time from typing import Any from rich.console import Console +from rich.live import Live from rich.panel import Panel from rich.text import Text @@ -11,7 +14,7 @@ from strix.agents.StrixAgent import StrixAgent from strix.llm.config import LLMConfig from strix.telemetry.tracer import Tracer, set_global_tracer -from .utils import get_severity_color +from .utils import build_final_stats_text, build_live_stats_text, get_severity_color async def run_cli(args: Any) -> None: # noqa: PLR0915 @@ -130,24 +133,80 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 set_global_tracer(tracer) + def create_live_status() -> Panel: + status_text = Text() + status_text.append("🦉 ", style="bold white") + status_text.append("Running penetration test...", style="bold #22c55e") + status_text.append("\n\n") + + stats_text = build_live_stats_text(tracer) + if stats_text: + status_text.append(stats_text) + + return Panel( + status_text, + title="[bold #22c55e]🔍 Live Penetration Test Status", + title_align="center", + border_style="#22c55e", + padding=(1, 2), + ) + try: console.print() - with console.status("[bold cyan]Running penetration test...", spinner="dots") as status: - agent = StrixAgent(agent_config) - result = await agent.execute_scan(scan_config) - status.stop() - if isinstance(result, dict) and not result.get("success", True): - error_msg = result.get("error", "Unknown error") - console.print() - console.print(f"[bold red]❌ Penetration test failed:[/] {error_msg}") - console.print() - sys.exit(1) + with Live( + create_live_status(), console=console, refresh_per_second=2, transient=False + ) as live: + stop_updates = threading.Event() + + def update_status() -> None: + while not stop_updates.is_set(): + try: + live.update(create_live_status()) + time.sleep(2) + except Exception: # noqa: BLE001 + break + + update_thread = threading.Thread(target=update_status, daemon=True) + update_thread.start() + + try: + agent = StrixAgent(agent_config) + result = await agent.execute_scan(scan_config) + + if isinstance(result, dict) and not result.get("success", True): + error_msg = result.get("error", "Unknown error") + console.print() + console.print(f"[bold red]❌ Penetration test failed:[/] {error_msg}") + console.print() + sys.exit(1) + finally: + stop_updates.set() + update_thread.join(timeout=1) except Exception as e: console.print(f"[bold red]Error during penetration test:[/] {e}") raise + console.print() + final_stats_text = Text() + final_stats_text.append("📊 ", style="bold cyan") + final_stats_text.append("PENETRATION TEST COMPLETED", style="bold green") + final_stats_text.append("\n\n") + + stats_text = build_final_stats_text(tracer) + if stats_text: + final_stats_text.append(stats_text) + + final_stats_panel = Panel( + final_stats_text, + title="[bold green]✅ Final Statistics", + title_align="center", + border_style="green", + padding=(1, 2), + ) + console.print(final_stats_panel) + if tracer.final_scan_result: console.print() diff --git a/strix/interface/main.py b/strix/interface/main.py index 2b2166f..6c244cf 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -21,8 +21,7 @@ from strix.interface.cli import run_cli from strix.interface.tui import run_tui from strix.interface.utils import ( assign_workspace_subdirs, - build_llm_stats_text, - build_stats_text, + build_final_stats_text, check_docker_connection, clone_repository, collect_local_sources, @@ -370,8 +369,7 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) -> completion_text.append(" • ", style="dim white") completion_text.append("Penetration test interrupted by user", style="white") - stats_text = build_stats_text(tracer) - llm_stats_text = build_llm_stats_text(tracer) + stats_text = build_final_stats_text(tracer) target_text = Text() if len(args.targets_info) == 1: @@ -391,9 +389,6 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) -> if stats_text.plain: panel_parts.extend(["\n", stats_text]) - if llm_stats_text.plain: - panel_parts.extend(["\n", llm_stats_text]) - if scan_completed or has_vulnerabilities: results_text = Text() results_text.append("📊 Results Saved To: ", style="bold cyan") diff --git a/strix/interface/tui.py b/strix/interface/tui.py index ff0a255..1b0bc37 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -31,6 +31,7 @@ from textual.widgets import Button, Label, Static, TextArea, Tree from textual.widgets.tree import TreeNode from strix.agents.StrixAgent import StrixAgent +from strix.interface.utils import build_live_stats_text from strix.llm.config import LLMConfig from strix.telemetry.tracer import Tracer, set_global_tracer @@ -393,8 +394,12 @@ class StrixTUIApp(App): # type: ignore[misc] agents_tree.guide_depth = 3 agents_tree.guide_style = "dashed" + stats_display = Static("", id="stats_display") + + sidebar = Vertical(agents_tree, stats_display, id="sidebar") + content_container.mount(chat_area_container) - content_container.mount(agents_tree) + content_container.mount(sidebar) chat_area_container.mount(chat_history) chat_area_container.mount(agent_status_display) @@ -481,6 +486,8 @@ class StrixTUIApp(App): # type: ignore[misc] self._update_agent_status_display() + self._update_stats_display() + def _update_agent_node(self, agent_id: str, agent_data: dict[str, Any]) -> bool: if agent_id not in self.agent_nodes: return False @@ -658,6 +665,33 @@ class StrixTUIApp(App): # type: ignore[misc] except (KeyError, Exception): self._safe_widget_operation(status_display.add_class, "hidden") + def _update_stats_display(self) -> None: + try: + stats_display = self.query_one("#stats_display", Static) + except (ValueError, Exception): + return + + if not self._is_widget_safe(stats_display): + return + + stats_content = Text() + + stats_text = build_live_stats_text(self.tracer) + if stats_text: + stats_content.append(stats_text) + + from rich.panel import Panel + + stats_panel = Panel( + stats_content, + title="📊 Live Stats", + title_align="left", + border_style="#22c55e", + padding=(0, 1), + ) + + self._safe_widget_operation(stats_display.update, stats_panel) + def _get_agent_verb(self, agent_id: str) -> str: if agent_id not in self._agent_verbs: self._agent_verbs[agent_id] = random.choice(self._action_verbs) # nosec B311 # noqa: S311 diff --git a/strix/interface/utils.py b/strix/interface/utils.py index 8e973d6..6e3c56d 100644 --- a/strix/interface/utils.py +++ b/strix/interface/utils.py @@ -38,14 +38,9 @@ def get_severity_color(severity: str) -> str: return severity_colors.get(severity, "#6b7280") -def build_stats_text(tracer: Any) -> Text: - stats_text = Text() - if not tracer: - return stats_text - +def _build_vulnerability_stats(stats_text: Text, tracer: Any) -> None: + """Build vulnerability section of stats text.""" vuln_count = len(tracer.vulnerability_reports) - tool_count = tracer.get_real_tool_count() - agent_count = len(tracer.agents) if vuln_count > 0: severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} @@ -81,44 +76,124 @@ def build_stats_text(tracer: Any) -> Text: stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green") stats_text.append("\n") + +def _build_llm_stats(stats_text: Text, total_stats: dict[str, Any]) -> None: + """Build LLM usage section of stats text.""" + if total_stats["requests"] > 0: + stats_text.append("\n") + stats_text.append("📥 Input Tokens: ", style="bold cyan") + stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white") + + if total_stats["cached_tokens"] > 0: + stats_text.append(" • ", style="dim white") + stats_text.append("⚡ Cached Tokens: ", style="bold green") + stats_text.append(format_token_count(total_stats["cached_tokens"]), style="bold white") + + stats_text.append(" • ", style="dim white") + stats_text.append("📤 Output Tokens: ", style="bold cyan") + stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white") + + if total_stats["cost"] > 0: + stats_text.append(" • ", style="dim white") + stats_text.append("💰 Total Cost: ", style="bold cyan") + stats_text.append(f"${total_stats['cost']:.4f}", style="bold yellow") + else: + stats_text.append("\n") + stats_text.append("💰 Total Cost: ", style="bold cyan") + stats_text.append("$0.0000 ", style="bold yellow") + stats_text.append("• ", style="bold white") + stats_text.append("📊 Tokens: ", style="bold cyan") + stats_text.append("0", style="bold white") + + +def build_final_stats_text(tracer: Any) -> Text: + """Build stats text for final output with detailed messages and LLM usage.""" + stats_text = Text() + if not tracer: + return stats_text + + _build_vulnerability_stats(stats_text, tracer) + + tool_count = tracer.get_real_tool_count() + agent_count = len(tracer.agents) + stats_text.append("🤖 Agents Used: ", style="bold cyan") stats_text.append(str(agent_count), style="bold white") stats_text.append(" • ", style="dim white") stats_text.append("🛠️ Tools Called: ", style="bold cyan") stats_text.append(str(tool_count), style="bold white") + llm_stats = tracer.get_total_llm_stats() + _build_llm_stats(stats_text, llm_stats["total"]) + return stats_text -def build_llm_stats_text(tracer: Any) -> Text: - llm_stats_text = Text() +def build_live_stats_text(tracer: Any) -> Text: + stats_text = Text() if not tracer: - return llm_stats_text + return stats_text + + vuln_count = len(tracer.vulnerability_reports) + tool_count = tracer.get_real_tool_count() + agent_count = len(tracer.agents) + + stats_text.append("🔍 Vulnerabilities: ", style="bold white") + stats_text.append(f"{vuln_count}", style="dim white") + stats_text.append("\n") + if vuln_count > 0: + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} + for report in tracer.vulnerability_reports: + severity = report.get("severity", "").lower() + if severity in severity_counts: + severity_counts[severity] += 1 + + severity_parts = [] + for severity in ["critical", "high", "medium", "low", "info"]: + count = severity_counts[severity] + if count > 0: + severity_color = get_severity_color(severity) + severity_text = Text() + severity_text.append(f"{severity.upper()}: ", style=severity_color) + severity_text.append(str(count), style=f"bold {severity_color}") + severity_parts.append(severity_text) + + for i, part in enumerate(severity_parts): + stats_text.append(part) + if i < len(severity_parts) - 1: + stats_text.append(" | ", style="dim white") + + stats_text.append("\n") + + stats_text.append("🤖 Agents: ", style="bold white") + stats_text.append(str(agent_count), style="dim white") + stats_text.append(" • ", style="dim white") + stats_text.append("🛠️ Tools: ", style="bold white") + stats_text.append(str(tool_count), style="dim white") llm_stats = tracer.get_total_llm_stats() total_stats = llm_stats["total"] - if total_stats["requests"] > 0: - llm_stats_text.append("📥 Input Tokens: ", style="bold cyan") - llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white") + stats_text.append("\n") - if total_stats["cached_tokens"] > 0: - llm_stats_text.append(" • ", style="dim white") - llm_stats_text.append("⚡ Cached: ", style="bold green") - llm_stats_text.append( - format_token_count(total_stats["cached_tokens"]), style="bold green" - ) + stats_text.append("📥 Input: ", style="bold white") + stats_text.append(format_token_count(total_stats["input_tokens"]), style="dim white") - llm_stats_text.append(" • ", style="dim white") - llm_stats_text.append("📤 Output Tokens: ", style="bold cyan") - llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white") + stats_text.append(" • ", style="dim white") + stats_text.append("⚡ ", style="bold white") + stats_text.append("Cached: ", style="bold white") + stats_text.append(format_token_count(total_stats["cached_tokens"]), style="dim white") - if total_stats["cost"] > 0: - llm_stats_text.append(" • ", style="dim white") - llm_stats_text.append("💰 Total Cost: $", style="bold cyan") - llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow") + stats_text.append("\n") - return llm_stats_text + stats_text.append("📤 Output: ", style="bold white") + stats_text.append(format_token_count(total_stats["output_tokens"]), style="dim white") + + stats_text.append(" • ", style="dim white") + stats_text.append("💰 Cost: ", style="bold white") + stats_text.append(f"${total_stats['cost']:.4f}", style="dim white") + + return stats_text # Name generation utilities