Real-time display panel for agent stats (#134)

Co-authored-by: Ahmed Allam <ahmed39652003@gmail.com>
This commit is contained in:
Alexander De Battista Kvamme
2025-11-25 13:06:20 +01:00
committed by GitHub
parent 78d0148d58
commit c0e547928e
5 changed files with 226 additions and 49 deletions

View File

@@ -33,18 +33,32 @@ Screen {
background: transparent; background: transparent;
} }
#sidebar {
width: 25%;
background: transparent;
margin-left: 1;
}
#agents_tree { #agents_tree {
width: 20%; height: 1fr;
background: transparent; background: transparent;
border: round #262626; border: round #262626;
border-title-color: #a8a29e; border-title-color: #a8a29e;
border-title-style: bold; border-title-style: bold;
margin-left: 1;
padding: 1; padding: 1;
margin-bottom: 0;
}
#stats_display {
height: auto;
max-height: 15;
background: transparent;
padding: 0;
margin: 0;
} }
#chat_area_container { #chat_area_container {
width: 80%; width: 75%;
background: transparent; background: transparent;
} }

View File

@@ -1,9 +1,12 @@
import atexit import atexit
import signal import signal
import sys import sys
import threading
import time
from typing import Any from typing import Any
from rich.console import Console from rich.console import Console
from rich.live import Live
from rich.panel import Panel from rich.panel import Panel
from rich.text import Text from rich.text import Text
@@ -11,7 +14,7 @@ from strix.agents.StrixAgent import StrixAgent
from strix.llm.config import LLMConfig from strix.llm.config import LLMConfig
from strix.telemetry.tracer import Tracer, set_global_tracer from strix.telemetry.tracer import Tracer, set_global_tracer
from .utils import get_severity_color from .utils import build_final_stats_text, build_live_stats_text, get_severity_color
async def run_cli(args: Any) -> None: # noqa: PLR0915 async def run_cli(args: Any) -> None: # noqa: PLR0915
@@ -130,12 +133,46 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
set_global_tracer(tracer) set_global_tracer(tracer)
def create_live_status() -> Panel:
status_text = Text()
status_text.append("🦉 ", style="bold white")
status_text.append("Running penetration test...", style="bold #22c55e")
status_text.append("\n\n")
stats_text = build_live_stats_text(tracer)
if stats_text:
status_text.append(stats_text)
return Panel(
status_text,
title="[bold #22c55e]🔍 Live Penetration Test Status",
title_align="center",
border_style="#22c55e",
padding=(1, 2),
)
try: try:
console.print() console.print()
with console.status("[bold cyan]Running penetration test...", spinner="dots") as status:
with Live(
create_live_status(), console=console, refresh_per_second=2, transient=False
) as live:
stop_updates = threading.Event()
def update_status() -> None:
while not stop_updates.is_set():
try:
live.update(create_live_status())
time.sleep(2)
except Exception: # noqa: BLE001
break
update_thread = threading.Thread(target=update_status, daemon=True)
update_thread.start()
try:
agent = StrixAgent(agent_config) agent = StrixAgent(agent_config)
result = await agent.execute_scan(scan_config) result = await agent.execute_scan(scan_config)
status.stop()
if isinstance(result, dict) and not result.get("success", True): if isinstance(result, dict) and not result.get("success", True):
error_msg = result.get("error", "Unknown error") error_msg = result.get("error", "Unknown error")
@@ -143,11 +180,33 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
console.print(f"[bold red]❌ Penetration test failed:[/] {error_msg}") console.print(f"[bold red]❌ Penetration test failed:[/] {error_msg}")
console.print() console.print()
sys.exit(1) sys.exit(1)
finally:
stop_updates.set()
update_thread.join(timeout=1)
except Exception as e: except Exception as e:
console.print(f"[bold red]Error during penetration test:[/] {e}") console.print(f"[bold red]Error during penetration test:[/] {e}")
raise raise
console.print()
final_stats_text = Text()
final_stats_text.append("📊 ", style="bold cyan")
final_stats_text.append("PENETRATION TEST COMPLETED", style="bold green")
final_stats_text.append("\n\n")
stats_text = build_final_stats_text(tracer)
if stats_text:
final_stats_text.append(stats_text)
final_stats_panel = Panel(
final_stats_text,
title="[bold green]✅ Final Statistics",
title_align="center",
border_style="green",
padding=(1, 2),
)
console.print(final_stats_panel)
if tracer.final_scan_result: if tracer.final_scan_result:
console.print() console.print()

View File

@@ -21,8 +21,7 @@ from strix.interface.cli import run_cli
from strix.interface.tui import run_tui from strix.interface.tui import run_tui
from strix.interface.utils import ( from strix.interface.utils import (
assign_workspace_subdirs, assign_workspace_subdirs,
build_llm_stats_text, build_final_stats_text,
build_stats_text,
check_docker_connection, check_docker_connection,
clone_repository, clone_repository,
collect_local_sources, collect_local_sources,
@@ -370,8 +369,7 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
completion_text.append("", style="dim white") completion_text.append("", style="dim white")
completion_text.append("Penetration test interrupted by user", style="white") completion_text.append("Penetration test interrupted by user", style="white")
stats_text = build_stats_text(tracer) stats_text = build_final_stats_text(tracer)
llm_stats_text = build_llm_stats_text(tracer)
target_text = Text() target_text = Text()
if len(args.targets_info) == 1: if len(args.targets_info) == 1:
@@ -391,9 +389,6 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
if stats_text.plain: if stats_text.plain:
panel_parts.extend(["\n", stats_text]) panel_parts.extend(["\n", stats_text])
if llm_stats_text.plain:
panel_parts.extend(["\n", llm_stats_text])
if scan_completed or has_vulnerabilities: if scan_completed or has_vulnerabilities:
results_text = Text() results_text = Text()
results_text.append("📊 Results Saved To: ", style="bold cyan") results_text.append("📊 Results Saved To: ", style="bold cyan")

View File

@@ -31,6 +31,7 @@ from textual.widgets import Button, Label, Static, TextArea, Tree
from textual.widgets.tree import TreeNode from textual.widgets.tree import TreeNode
from strix.agents.StrixAgent import StrixAgent from strix.agents.StrixAgent import StrixAgent
from strix.interface.utils import build_live_stats_text
from strix.llm.config import LLMConfig from strix.llm.config import LLMConfig
from strix.telemetry.tracer import Tracer, set_global_tracer from strix.telemetry.tracer import Tracer, set_global_tracer
@@ -393,8 +394,12 @@ class StrixTUIApp(App): # type: ignore[misc]
agents_tree.guide_depth = 3 agents_tree.guide_depth = 3
agents_tree.guide_style = "dashed" agents_tree.guide_style = "dashed"
stats_display = Static("", id="stats_display")
sidebar = Vertical(agents_tree, stats_display, id="sidebar")
content_container.mount(chat_area_container) content_container.mount(chat_area_container)
content_container.mount(agents_tree) content_container.mount(sidebar)
chat_area_container.mount(chat_history) chat_area_container.mount(chat_history)
chat_area_container.mount(agent_status_display) chat_area_container.mount(agent_status_display)
@@ -481,6 +486,8 @@ class StrixTUIApp(App): # type: ignore[misc]
self._update_agent_status_display() self._update_agent_status_display()
self._update_stats_display()
def _update_agent_node(self, agent_id: str, agent_data: dict[str, Any]) -> bool: def _update_agent_node(self, agent_id: str, agent_data: dict[str, Any]) -> bool:
if agent_id not in self.agent_nodes: if agent_id not in self.agent_nodes:
return False return False
@@ -658,6 +665,33 @@ class StrixTUIApp(App): # type: ignore[misc]
except (KeyError, Exception): except (KeyError, Exception):
self._safe_widget_operation(status_display.add_class, "hidden") self._safe_widget_operation(status_display.add_class, "hidden")
def _update_stats_display(self) -> None:
try:
stats_display = self.query_one("#stats_display", Static)
except (ValueError, Exception):
return
if not self._is_widget_safe(stats_display):
return
stats_content = Text()
stats_text = build_live_stats_text(self.tracer)
if stats_text:
stats_content.append(stats_text)
from rich.panel import Panel
stats_panel = Panel(
stats_content,
title="📊 Live Stats",
title_align="left",
border_style="#22c55e",
padding=(0, 1),
)
self._safe_widget_operation(stats_display.update, stats_panel)
def _get_agent_verb(self, agent_id: str) -> str: def _get_agent_verb(self, agent_id: str) -> str:
if agent_id not in self._agent_verbs: if agent_id not in self._agent_verbs:
self._agent_verbs[agent_id] = random.choice(self._action_verbs) # nosec B311 # noqa: S311 self._agent_verbs[agent_id] = random.choice(self._action_verbs) # nosec B311 # noqa: S311

View File

@@ -38,14 +38,9 @@ def get_severity_color(severity: str) -> str:
return severity_colors.get(severity, "#6b7280") return severity_colors.get(severity, "#6b7280")
def build_stats_text(tracer: Any) -> Text: def _build_vulnerability_stats(stats_text: Text, tracer: Any) -> None:
stats_text = Text() """Build vulnerability section of stats text."""
if not tracer:
return stats_text
vuln_count = len(tracer.vulnerability_reports) vuln_count = len(tracer.vulnerability_reports)
tool_count = tracer.get_real_tool_count()
agent_count = len(tracer.agents)
if vuln_count > 0: if vuln_count > 0:
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
@@ -81,44 +76,124 @@ def build_stats_text(tracer: Any) -> Text:
stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green") stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
stats_text.append("\n") stats_text.append("\n")
def _build_llm_stats(stats_text: Text, total_stats: dict[str, Any]) -> None:
"""Build LLM usage section of stats text."""
if total_stats["requests"] > 0:
stats_text.append("\n")
stats_text.append("📥 Input Tokens: ", style="bold cyan")
stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
if total_stats["cached_tokens"] > 0:
stats_text.append("", style="dim white")
stats_text.append("⚡ Cached Tokens: ", style="bold green")
stats_text.append(format_token_count(total_stats["cached_tokens"]), style="bold white")
stats_text.append("", style="dim white")
stats_text.append("📤 Output Tokens: ", style="bold cyan")
stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")
if total_stats["cost"] > 0:
stats_text.append("", style="dim white")
stats_text.append("💰 Total Cost: ", style="bold cyan")
stats_text.append(f"${total_stats['cost']:.4f}", style="bold yellow")
else:
stats_text.append("\n")
stats_text.append("💰 Total Cost: ", style="bold cyan")
stats_text.append("$0.0000 ", style="bold yellow")
stats_text.append("", style="bold white")
stats_text.append("📊 Tokens: ", style="bold cyan")
stats_text.append("0", style="bold white")
def build_final_stats_text(tracer: Any) -> Text:
"""Build stats text for final output with detailed messages and LLM usage."""
stats_text = Text()
if not tracer:
return stats_text
_build_vulnerability_stats(stats_text, tracer)
tool_count = tracer.get_real_tool_count()
agent_count = len(tracer.agents)
stats_text.append("🤖 Agents Used: ", style="bold cyan") stats_text.append("🤖 Agents Used: ", style="bold cyan")
stats_text.append(str(agent_count), style="bold white") stats_text.append(str(agent_count), style="bold white")
stats_text.append("", style="dim white") stats_text.append("", style="dim white")
stats_text.append("🛠️ Tools Called: ", style="bold cyan") stats_text.append("🛠️ Tools Called: ", style="bold cyan")
stats_text.append(str(tool_count), style="bold white") stats_text.append(str(tool_count), style="bold white")
llm_stats = tracer.get_total_llm_stats()
_build_llm_stats(stats_text, llm_stats["total"])
return stats_text return stats_text
def build_llm_stats_text(tracer: Any) -> Text: def build_live_stats_text(tracer: Any) -> Text:
llm_stats_text = Text() stats_text = Text()
if not tracer: if not tracer:
return llm_stats_text return stats_text
vuln_count = len(tracer.vulnerability_reports)
tool_count = tracer.get_real_tool_count()
agent_count = len(tracer.agents)
stats_text.append("🔍 Vulnerabilities: ", style="bold white")
stats_text.append(f"{vuln_count}", style="dim white")
stats_text.append("\n")
if vuln_count > 0:
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
for report in tracer.vulnerability_reports:
severity = report.get("severity", "").lower()
if severity in severity_counts:
severity_counts[severity] += 1
severity_parts = []
for severity in ["critical", "high", "medium", "low", "info"]:
count = severity_counts[severity]
if count > 0:
severity_color = get_severity_color(severity)
severity_text = Text()
severity_text.append(f"{severity.upper()}: ", style=severity_color)
severity_text.append(str(count), style=f"bold {severity_color}")
severity_parts.append(severity_text)
for i, part in enumerate(severity_parts):
stats_text.append(part)
if i < len(severity_parts) - 1:
stats_text.append(" | ", style="dim white")
stats_text.append("\n")
stats_text.append("🤖 Agents: ", style="bold white")
stats_text.append(str(agent_count), style="dim white")
stats_text.append("", style="dim white")
stats_text.append("🛠️ Tools: ", style="bold white")
stats_text.append(str(tool_count), style="dim white")
llm_stats = tracer.get_total_llm_stats() llm_stats = tracer.get_total_llm_stats()
total_stats = llm_stats["total"] total_stats = llm_stats["total"]
if total_stats["requests"] > 0: stats_text.append("\n")
llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
if total_stats["cached_tokens"] > 0: stats_text.append("📥 Input: ", style="bold white")
llm_stats_text.append("", style="dim white") stats_text.append(format_token_count(total_stats["input_tokens"]), style="dim white")
llm_stats_text.append("⚡ Cached: ", style="bold green")
llm_stats_text.append(
format_token_count(total_stats["cached_tokens"]), style="bold green"
)
llm_stats_text.append("", style="dim white") stats_text.append("", style="dim white")
llm_stats_text.append("📤 Output Tokens: ", style="bold cyan") stats_text.append(" ", style="bold white")
llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white") stats_text.append("Cached: ", style="bold white")
stats_text.append(format_token_count(total_stats["cached_tokens"]), style="dim white")
if total_stats["cost"] > 0: stats_text.append("\n")
llm_stats_text.append("", style="dim white")
llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")
return llm_stats_text stats_text.append("📤 Output: ", style="bold white")
stats_text.append(format_token_count(total_stats["output_tokens"]), style="dim white")
stats_text.append("", style="dim white")
stats_text.append("💰 Cost: ", style="bold white")
stats_text.append(f"${total_stats['cost']:.4f}", style="dim white")
return stats_text
# Name generation utilities # Name generation utilities