import atexit
import signal
import sys
import threading
import time
from typing import Any

from rich.console import Console
from rich.live import Live
from rich.markup import escape
from rich.panel import Panel
from rich.text import Text

from strix.agents.StrixAgent import StrixAgent
from strix.llm.config import LLMConfig
from strix.telemetry.tracer import Tracer, set_global_tracer

from .utils import (
    build_live_stats_text,
    format_vulnerability_report,
)


async def run_cli(args: Any) -> None:  # noqa: PLR0915
    """Run a scan in non-interactive CLI mode with a live Rich status panel.

    Prints a startup panel, builds the scan/agent configuration from ``args``,
    installs a tracer plus exit/signal cleanup hooks, runs ``StrixAgent`` while
    a background thread refreshes the live status panel, and finally prints the
    tracer's final scan result if one was recorded.

    Args:
        args: Parsed CLI arguments. Reads ``targets_info`` (sequence of mappings
            with an ``"original"`` key), ``run_name`` and ``instruction``;
            optionally ``scan_mode`` (defaults to ``"deep"``), ``diff_scope``
            (defaults to ``{"active": False}``) and ``local_sources``.

    Raises:
        SystemExit: With code 1 when the scan returns a dict whose
            ``"success"`` entry is falsy.
        Exception: Any exception raised while running the scan is reported on
            the console and re-raised.
    """
    console = Console()

    # --- Startup banner -----------------------------------------------------
    start_text = Text()
    start_text.append("Penetration test initiated", style="bold #22c55e")

    target_text = Text()
    target_text.append("Target", style="dim")
    target_text.append(" ")
    if len(args.targets_info) == 1:
        target_text.append(args.targets_info[0]["original"], style="bold white")
    else:
        target_text.append(f"{len(args.targets_info)} targets", style="bold white")
        for target_info in args.targets_info:
            target_text.append("\n ")
            target_text.append(target_info["original"], style="white")

    results_text = Text()
    results_text.append("Output", style="dim")
    results_text.append(" ")
    results_text.append(f"strix_runs/{args.run_name}", style="#60a5fa")

    note_text = Text()
    note_text.append("\n\n", style="dim")
    note_text.append("Vulnerabilities will be displayed in real-time.", style="dim")

    startup_panel = Panel(
        Text.assemble(
            start_text,
            "\n\n",
            target_text,
            "\n",
            results_text,
            note_text,
        ),
        title="[bold white]STRIX",
        title_align="left",
        border_style="#22c55e",
        padding=(1, 2),
    )

    console.print("\n")
    console.print(startup_panel)
    console.print()

    # --- Scan / agent configuration ------------------------------------------
    scan_mode = getattr(args, "scan_mode", "deep")

    scan_config = {
        "scan_id": args.run_name,
        "targets": args.targets_info,
        "user_instructions": args.instruction or "",
        "run_name": args.run_name,
        "diff_scope": getattr(args, "diff_scope", {"active": False}),
    }

    llm_config = LLMConfig(
        scan_mode=scan_mode,
        is_whitebox=bool(getattr(args, "local_sources", [])),
    )
    agent_config: dict[str, Any] = {
        "llm_config": llm_config,
        "max_iterations": 300,
    }

    if getattr(args, "local_sources", None):
        agent_config["local_sources"] = args.local_sources

    # --- Tracer and cleanup hooks --------------------------------------------
    tracer = Tracer(args.run_name)
    tracer.set_scan_config(scan_config)

    def display_vulnerability(report: dict[str, Any]) -> None:
        """Print one vulnerability report as a red panel titled with its id."""
        # Coerce to str so a None / non-string id cannot break `.upper()`.
        report_id = str(report.get("id") or "unknown")

        vuln_text = format_vulnerability_report(report)

        vuln_panel = Panel(
            vuln_text,
            # The title is parsed as Rich markup; escape the dynamic id so any
            # brackets in it are shown literally.
            title=f"[bold red]{escape(report_id.upper())}",
            title_align="left",
            border_style="red",
            padding=(1, 2),
        )

        console.print(vuln_panel)
        console.print()

    tracer.vulnerability_found_callback = display_vulnerability

    def cleanup_on_exit() -> None:
        """Clean up the tracer and the runtime at interpreter exit."""
        from strix.runtime import cleanup_runtime

        tracer.cleanup()
        cleanup_runtime()

    def signal_handler(_signum: int, _frame: Any) -> None:
        """Clean up the tracer and exit with status 1 on a termination signal."""
        tracer.cleanup()
        sys.exit(1)

    atexit.register(cleanup_on_exit)
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    # SIGHUP is not defined on every platform, so register it only if present.
    if hasattr(signal, "SIGHUP"):
        signal.signal(signal.SIGHUP, signal_handler)

    set_global_tracer(tracer)

    def create_live_status() -> Panel:
        """Build the "in progress" panel from the current tracer statistics."""
        status_text = Text()
        status_text.append("Penetration test in progress", style="bold #22c55e")
        status_text.append("\n\n")

        stats_text = build_live_stats_text(tracer, agent_config)
        if stats_text:
            status_text.append(stats_text)

        return Panel(
            status_text,
            title="[bold white]STRIX",
            title_align="left",
            border_style="#22c55e",
            padding=(1, 2),
        )

    # --- Run the scan under a live status display ----------------------------
    try:
        console.print()

        with Live(
            create_live_status(), console=console, refresh_per_second=2, transient=False
        ) as live:
            stop_updates = threading.Event()

            def update_status() -> None:
                """Refresh the live panel every two seconds until told to stop."""
                while not stop_updates.is_set():
                    try:
                        live.update(create_live_status())
                        time.sleep(2)
                    except Exception:  # noqa: BLE001
                        # Best-effort display: a failed refresh just ends the
                        # updater thread and never interrupts the scan.
                        break

            update_thread = threading.Thread(target=update_status, daemon=True)
            update_thread.start()

            try:
                agent = StrixAgent(agent_config)
                result = await agent.execute_scan(scan_config)

                if isinstance(result, dict) and not result.get("success", True):
                    error_msg = result.get("error", "Unknown error")
                    error_details = result.get("details")
                    console.print()
                    # Escape dynamic text: these strings are rendered as Rich
                    # markup, so unescaped brackets could be dropped from the
                    # output or raise a MarkupError.
                    console.print(
                        f"[bold red]Penetration test failed:[/] {escape(str(error_msg))}"
                    )
                    if error_details:
                        console.print(f"[dim]{escape(str(error_details))}[/]")
                    console.print()
                    sys.exit(1)
            finally:
                # Always stop the updater, including on failure or exit.
                stop_updates.set()
                update_thread.join(timeout=1)

    except Exception as e:
        # Escape the message so reporting the error cannot itself fail on
        # markup parsing and mask the original exception.
        console.print(f"[bold red]Error during penetration test:[/] {escape(str(e))}")
        raise

    # --- Final summary -------------------------------------------------------
    if tracer.final_scan_result:
        console.print()

        final_report_text = Text()
        final_report_text.append("Penetration test summary", style="bold #60a5fa")

        final_report_panel = Panel(
            Text.assemble(
                final_report_text,
                "\n\n",
                tracer.final_scan_result,
            ),
            title="[bold white]STRIX",
            title_align="left",
            border_style="#60a5fa",
            padding=(1, 2),
        )

        console.print(final_report_panel)
        console.print()