Real-time display panel for agent stats (#134)
Co-authored-by: Ahmed Allam <ahmed39652003@gmail.com>
This commit is contained in:
committed by
GitHub
parent
78d0148d58
commit
c0e547928e
@@ -33,18 +33,32 @@ Screen {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
#sidebar {
|
||||
width: 25%;
|
||||
background: transparent;
|
||||
margin-left: 1;
|
||||
}
|
||||
|
||||
#agents_tree {
|
||||
width: 20%;
|
||||
height: 1fr;
|
||||
background: transparent;
|
||||
border: round #262626;
|
||||
border-title-color: #a8a29e;
|
||||
border-title-style: bold;
|
||||
margin-left: 1;
|
||||
padding: 1;
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
#stats_display {
|
||||
height: auto;
|
||||
max-height: 15;
|
||||
background: transparent;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
#chat_area_container {
|
||||
width: 80%;
|
||||
width: 75%;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import atexit
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
@@ -11,7 +14,7 @@ from strix.agents.StrixAgent import StrixAgent
|
||||
from strix.llm.config import LLMConfig
|
||||
from strix.telemetry.tracer import Tracer, set_global_tracer
|
||||
|
||||
from .utils import get_severity_color
|
||||
from .utils import build_final_stats_text, build_live_stats_text, get_severity_color
|
||||
|
||||
|
||||
async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||
@@ -130,24 +133,80 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||
|
||||
set_global_tracer(tracer)
|
||||
|
||||
def create_live_status() -> Panel:
|
||||
status_text = Text()
|
||||
status_text.append("🦉 ", style="bold white")
|
||||
status_text.append("Running penetration test...", style="bold #22c55e")
|
||||
status_text.append("\n\n")
|
||||
|
||||
stats_text = build_live_stats_text(tracer)
|
||||
if stats_text:
|
||||
status_text.append(stats_text)
|
||||
|
||||
return Panel(
|
||||
status_text,
|
||||
title="[bold #22c55e]🔍 Live Penetration Test Status",
|
||||
title_align="center",
|
||||
border_style="#22c55e",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
try:
|
||||
console.print()
|
||||
with console.status("[bold cyan]Running penetration test...", spinner="dots") as status:
|
||||
agent = StrixAgent(agent_config)
|
||||
result = await agent.execute_scan(scan_config)
|
||||
status.stop()
|
||||
|
||||
if isinstance(result, dict) and not result.get("success", True):
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
console.print()
|
||||
console.print(f"[bold red]❌ Penetration test failed:[/] {error_msg}")
|
||||
console.print()
|
||||
sys.exit(1)
|
||||
with Live(
|
||||
create_live_status(), console=console, refresh_per_second=2, transient=False
|
||||
) as live:
|
||||
stop_updates = threading.Event()
|
||||
|
||||
def update_status() -> None:
|
||||
while not stop_updates.is_set():
|
||||
try:
|
||||
live.update(create_live_status())
|
||||
time.sleep(2)
|
||||
except Exception: # noqa: BLE001
|
||||
break
|
||||
|
||||
update_thread = threading.Thread(target=update_status, daemon=True)
|
||||
update_thread.start()
|
||||
|
||||
try:
|
||||
agent = StrixAgent(agent_config)
|
||||
result = await agent.execute_scan(scan_config)
|
||||
|
||||
if isinstance(result, dict) and not result.get("success", True):
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
console.print()
|
||||
console.print(f"[bold red]❌ Penetration test failed:[/] {error_msg}")
|
||||
console.print()
|
||||
sys.exit(1)
|
||||
finally:
|
||||
stop_updates.set()
|
||||
update_thread.join(timeout=1)
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Error during penetration test:[/] {e}")
|
||||
raise
|
||||
|
||||
console.print()
|
||||
final_stats_text = Text()
|
||||
final_stats_text.append("📊 ", style="bold cyan")
|
||||
final_stats_text.append("PENETRATION TEST COMPLETED", style="bold green")
|
||||
final_stats_text.append("\n\n")
|
||||
|
||||
stats_text = build_final_stats_text(tracer)
|
||||
if stats_text:
|
||||
final_stats_text.append(stats_text)
|
||||
|
||||
final_stats_panel = Panel(
|
||||
final_stats_text,
|
||||
title="[bold green]✅ Final Statistics",
|
||||
title_align="center",
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print(final_stats_panel)
|
||||
|
||||
if tracer.final_scan_result:
|
||||
console.print()
|
||||
|
||||
|
||||
@@ -21,8 +21,7 @@ from strix.interface.cli import run_cli
|
||||
from strix.interface.tui import run_tui
|
||||
from strix.interface.utils import (
|
||||
assign_workspace_subdirs,
|
||||
build_llm_stats_text,
|
||||
build_stats_text,
|
||||
build_final_stats_text,
|
||||
check_docker_connection,
|
||||
clone_repository,
|
||||
collect_local_sources,
|
||||
@@ -370,8 +369,7 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
|
||||
completion_text.append(" • ", style="dim white")
|
||||
completion_text.append("Penetration test interrupted by user", style="white")
|
||||
|
||||
stats_text = build_stats_text(tracer)
|
||||
llm_stats_text = build_llm_stats_text(tracer)
|
||||
stats_text = build_final_stats_text(tracer)
|
||||
|
||||
target_text = Text()
|
||||
if len(args.targets_info) == 1:
|
||||
@@ -391,9 +389,6 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
|
||||
if stats_text.plain:
|
||||
panel_parts.extend(["\n", stats_text])
|
||||
|
||||
if llm_stats_text.plain:
|
||||
panel_parts.extend(["\n", llm_stats_text])
|
||||
|
||||
if scan_completed or has_vulnerabilities:
|
||||
results_text = Text()
|
||||
results_text.append("📊 Results Saved To: ", style="bold cyan")
|
||||
|
||||
@@ -31,6 +31,7 @@ from textual.widgets import Button, Label, Static, TextArea, Tree
|
||||
from textual.widgets.tree import TreeNode
|
||||
|
||||
from strix.agents.StrixAgent import StrixAgent
|
||||
from strix.interface.utils import build_live_stats_text
|
||||
from strix.llm.config import LLMConfig
|
||||
from strix.telemetry.tracer import Tracer, set_global_tracer
|
||||
|
||||
@@ -393,8 +394,12 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
agents_tree.guide_depth = 3
|
||||
agents_tree.guide_style = "dashed"
|
||||
|
||||
stats_display = Static("", id="stats_display")
|
||||
|
||||
sidebar = Vertical(agents_tree, stats_display, id="sidebar")
|
||||
|
||||
content_container.mount(chat_area_container)
|
||||
content_container.mount(agents_tree)
|
||||
content_container.mount(sidebar)
|
||||
|
||||
chat_area_container.mount(chat_history)
|
||||
chat_area_container.mount(agent_status_display)
|
||||
@@ -481,6 +486,8 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
|
||||
self._update_agent_status_display()
|
||||
|
||||
self._update_stats_display()
|
||||
|
||||
def _update_agent_node(self, agent_id: str, agent_data: dict[str, Any]) -> bool:
|
||||
if agent_id not in self.agent_nodes:
|
||||
return False
|
||||
@@ -658,6 +665,33 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
except (KeyError, Exception):
|
||||
self._safe_widget_operation(status_display.add_class, "hidden")
|
||||
|
||||
def _update_stats_display(self) -> None:
|
||||
try:
|
||||
stats_display = self.query_one("#stats_display", Static)
|
||||
except (ValueError, Exception):
|
||||
return
|
||||
|
||||
if not self._is_widget_safe(stats_display):
|
||||
return
|
||||
|
||||
stats_content = Text()
|
||||
|
||||
stats_text = build_live_stats_text(self.tracer)
|
||||
if stats_text:
|
||||
stats_content.append(stats_text)
|
||||
|
||||
from rich.panel import Panel
|
||||
|
||||
stats_panel = Panel(
|
||||
stats_content,
|
||||
title="📊 Live Stats",
|
||||
title_align="left",
|
||||
border_style="#22c55e",
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
self._safe_widget_operation(stats_display.update, stats_panel)
|
||||
|
||||
def _get_agent_verb(self, agent_id: str) -> str:
|
||||
if agent_id not in self._agent_verbs:
|
||||
self._agent_verbs[agent_id] = random.choice(self._action_verbs) # nosec B311 # noqa: S311
|
||||
|
||||
@@ -38,14 +38,9 @@ def get_severity_color(severity: str) -> str:
|
||||
return severity_colors.get(severity, "#6b7280")
|
||||
|
||||
|
||||
def build_stats_text(tracer: Any) -> Text:
|
||||
stats_text = Text()
|
||||
if not tracer:
|
||||
return stats_text
|
||||
|
||||
def _build_vulnerability_stats(stats_text: Text, tracer: Any) -> None:
|
||||
"""Build vulnerability section of stats text."""
|
||||
vuln_count = len(tracer.vulnerability_reports)
|
||||
tool_count = tracer.get_real_tool_count()
|
||||
agent_count = len(tracer.agents)
|
||||
|
||||
if vuln_count > 0:
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
||||
@@ -81,44 +76,124 @@ def build_stats_text(tracer: Any) -> Text:
|
||||
stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
|
||||
stats_text.append("\n")
|
||||
|
||||
|
||||
def _build_llm_stats(stats_text: Text, total_stats: dict[str, Any]) -> None:
|
||||
"""Build LLM usage section of stats text."""
|
||||
if total_stats["requests"] > 0:
|
||||
stats_text.append("\n")
|
||||
stats_text.append("📥 Input Tokens: ", style="bold cyan")
|
||||
stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
|
||||
|
||||
if total_stats["cached_tokens"] > 0:
|
||||
stats_text.append(" • ", style="dim white")
|
||||
stats_text.append("⚡ Cached Tokens: ", style="bold green")
|
||||
stats_text.append(format_token_count(total_stats["cached_tokens"]), style="bold white")
|
||||
|
||||
stats_text.append(" • ", style="dim white")
|
||||
stats_text.append("📤 Output Tokens: ", style="bold cyan")
|
||||
stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")
|
||||
|
||||
if total_stats["cost"] > 0:
|
||||
stats_text.append(" • ", style="dim white")
|
||||
stats_text.append("💰 Total Cost: ", style="bold cyan")
|
||||
stats_text.append(f"${total_stats['cost']:.4f}", style="bold yellow")
|
||||
else:
|
||||
stats_text.append("\n")
|
||||
stats_text.append("💰 Total Cost: ", style="bold cyan")
|
||||
stats_text.append("$0.0000 ", style="bold yellow")
|
||||
stats_text.append("• ", style="bold white")
|
||||
stats_text.append("📊 Tokens: ", style="bold cyan")
|
||||
stats_text.append("0", style="bold white")
|
||||
|
||||
|
||||
def build_final_stats_text(tracer: Any) -> Text:
|
||||
"""Build stats text for final output with detailed messages and LLM usage."""
|
||||
stats_text = Text()
|
||||
if not tracer:
|
||||
return stats_text
|
||||
|
||||
_build_vulnerability_stats(stats_text, tracer)
|
||||
|
||||
tool_count = tracer.get_real_tool_count()
|
||||
agent_count = len(tracer.agents)
|
||||
|
||||
stats_text.append("🤖 Agents Used: ", style="bold cyan")
|
||||
stats_text.append(str(agent_count), style="bold white")
|
||||
stats_text.append(" • ", style="dim white")
|
||||
stats_text.append("🛠️ Tools Called: ", style="bold cyan")
|
||||
stats_text.append(str(tool_count), style="bold white")
|
||||
|
||||
llm_stats = tracer.get_total_llm_stats()
|
||||
_build_llm_stats(stats_text, llm_stats["total"])
|
||||
|
||||
return stats_text
|
||||
|
||||
|
||||
def build_llm_stats_text(tracer: Any) -> Text:
|
||||
llm_stats_text = Text()
|
||||
def build_live_stats_text(tracer: Any) -> Text:
|
||||
stats_text = Text()
|
||||
if not tracer:
|
||||
return llm_stats_text
|
||||
return stats_text
|
||||
|
||||
vuln_count = len(tracer.vulnerability_reports)
|
||||
tool_count = tracer.get_real_tool_count()
|
||||
agent_count = len(tracer.agents)
|
||||
|
||||
stats_text.append("🔍 Vulnerabilities: ", style="bold white")
|
||||
stats_text.append(f"{vuln_count}", style="dim white")
|
||||
stats_text.append("\n")
|
||||
if vuln_count > 0:
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
||||
for report in tracer.vulnerability_reports:
|
||||
severity = report.get("severity", "").lower()
|
||||
if severity in severity_counts:
|
||||
severity_counts[severity] += 1
|
||||
|
||||
severity_parts = []
|
||||
for severity in ["critical", "high", "medium", "low", "info"]:
|
||||
count = severity_counts[severity]
|
||||
if count > 0:
|
||||
severity_color = get_severity_color(severity)
|
||||
severity_text = Text()
|
||||
severity_text.append(f"{severity.upper()}: ", style=severity_color)
|
||||
severity_text.append(str(count), style=f"bold {severity_color}")
|
||||
severity_parts.append(severity_text)
|
||||
|
||||
for i, part in enumerate(severity_parts):
|
||||
stats_text.append(part)
|
||||
if i < len(severity_parts) - 1:
|
||||
stats_text.append(" | ", style="dim white")
|
||||
|
||||
stats_text.append("\n")
|
||||
|
||||
stats_text.append("🤖 Agents: ", style="bold white")
|
||||
stats_text.append(str(agent_count), style="dim white")
|
||||
stats_text.append(" • ", style="dim white")
|
||||
stats_text.append("🛠️ Tools: ", style="bold white")
|
||||
stats_text.append(str(tool_count), style="dim white")
|
||||
|
||||
llm_stats = tracer.get_total_llm_stats()
|
||||
total_stats = llm_stats["total"]
|
||||
|
||||
if total_stats["requests"] > 0:
|
||||
llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
|
||||
llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
|
||||
stats_text.append("\n")
|
||||
|
||||
if total_stats["cached_tokens"] > 0:
|
||||
llm_stats_text.append(" • ", style="dim white")
|
||||
llm_stats_text.append("⚡ Cached: ", style="bold green")
|
||||
llm_stats_text.append(
|
||||
format_token_count(total_stats["cached_tokens"]), style="bold green"
|
||||
)
|
||||
stats_text.append("📥 Input: ", style="bold white")
|
||||
stats_text.append(format_token_count(total_stats["input_tokens"]), style="dim white")
|
||||
|
||||
llm_stats_text.append(" • ", style="dim white")
|
||||
llm_stats_text.append("📤 Output Tokens: ", style="bold cyan")
|
||||
llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")
|
||||
stats_text.append(" • ", style="dim white")
|
||||
stats_text.append("⚡ ", style="bold white")
|
||||
stats_text.append("Cached: ", style="bold white")
|
||||
stats_text.append(format_token_count(total_stats["cached_tokens"]), style="dim white")
|
||||
|
||||
if total_stats["cost"] > 0:
|
||||
llm_stats_text.append(" • ", style="dim white")
|
||||
llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
|
||||
llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")
|
||||
stats_text.append("\n")
|
||||
|
||||
return llm_stats_text
|
||||
stats_text.append("📤 Output: ", style="bold white")
|
||||
stats_text.append(format_token_count(total_stats["output_tokens"]), style="dim white")
|
||||
|
||||
stats_text.append(" • ", style="dim white")
|
||||
stats_text.append("💰 Cost: ", style="bold white")
|
||||
stats_text.append(f"${total_stats['cost']:.4f}", style="dim white")
|
||||
|
||||
return stats_text
|
||||
|
||||
|
||||
# Name generation utilities
|
||||
|
||||
Reference in New Issue
Block a user