diff --git a/strix/interface/cli.py b/strix/interface/cli.py index 582f811..b1e5452 100644 --- a/strix/interface/cli.py +++ b/strix/interface/cli.py @@ -14,7 +14,11 @@ from strix.agents.StrixAgent import StrixAgent from strix.llm.config import LLMConfig from strix.telemetry.tracer import Tracer, set_global_tracer -from .utils import build_final_stats_text, build_live_stats_text, get_severity_color +from .utils import ( + build_final_stats_text, + build_live_stats_text, + format_vulnerability_report, +) async def run_cli(args: Any) -> None: # noqa: PLR0915 @@ -88,28 +92,14 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 tracer = Tracer(args.run_name) tracer.set_scan_config(scan_config) - def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None: - severity_color = get_severity_color(severity.lower()) + def display_vulnerability(report: dict[str, Any]) -> None: + report_id = report.get("id", "unknown") - vuln_text = Text() - vuln_text.append("🐞 ", style="bold red") - vuln_text.append("VULNERABILITY FOUND", style="bold red") - vuln_text.append(" • ", style="dim white") - vuln_text.append(title, style="bold white") - - severity_text = Text() - severity_text.append("Severity: ", style="dim white") - severity_text.append(severity.upper(), style=f"bold {severity_color}") + vuln_text = format_vulnerability_report(report) vuln_panel = Panel( - Text.assemble( - vuln_text, - "\n\n", - severity_text, - "\n\n", - content, - ), - title=f"[bold red]🔍 {report_id.upper()}", + vuln_text, + title=f"[bold red]{report_id.upper()}", title_align="left", border_style="red", padding=(1, 2), diff --git a/strix/interface/utils.py b/strix/interface/utils.py index bd128a7..1f40fdf 100644 --- a/strix/interface/utils.py +++ b/strix/interface/utils.py @@ -38,6 +38,165 @@ def get_severity_color(severity: str) -> str: return severity_colors.get(severity, "#6b7280") +def get_cvss_color(cvss_score: float) -> str: + if cvss_score >= 9.0: + return "#dc2626" + if cvss_score >= 7.0: + return "#ea580c" + if cvss_score >= 4.0: + return "#d97706" + if cvss_score >= 0.1: + return "#65a30d" + return "#6b7280" + + +def format_vulnerability_report(report: dict[str, Any]) -> Text: # noqa: PLR0912, PLR0915 + """Format a vulnerability report for CLI display with all rich fields.""" + field_style = "bold #4ade80" + + text = Text() + + title = report.get("title", "") + if title: + text.append("Vulnerability Report", style="bold #ea580c") + text.append("\n\n") + text.append("Title: ", style=field_style) + text.append(title) + + severity = report.get("severity", "") + if severity: + text.append("\n\n") + text.append("Severity: ", style=field_style) + severity_color = get_severity_color(severity.lower()) + text.append(severity.upper(), style=f"bold {severity_color}") + + cvss = report.get("cvss") + if cvss is not None: + text.append("\n\n") + text.append("CVSS Score: ", style=field_style) + cvss_color = get_cvss_color(cvss) + text.append(f"{cvss:.1f}", style=f"bold {cvss_color}") + + target = report.get("target") + if target: + text.append("\n\n") + text.append("Target: ", style=field_style) + text.append(target) + + endpoint = report.get("endpoint") + if endpoint: + text.append("\n\n") + text.append("Endpoint: ", style=field_style) + text.append(endpoint) + + method = report.get("method") + if method: + text.append("\n\n") + text.append("Method: ", style=field_style) + text.append(method) + + cve = report.get("cve") + if cve: + text.append("\n\n") + text.append("CVE: ", style=field_style) + text.append(cve) + + cvss_breakdown = report.get("cvss_breakdown", {}) + if cvss_breakdown: + text.append("\n\n") + cvss_parts = [] + if cvss_breakdown.get("attack_vector"): + cvss_parts.append(f"AV:{cvss_breakdown['attack_vector']}") + if cvss_breakdown.get("attack_complexity"): + cvss_parts.append(f"AC:{cvss_breakdown['attack_complexity']}") + if cvss_breakdown.get("privileges_required"): + cvss_parts.append(f"PR:{cvss_breakdown['privileges_required']}") + if cvss_breakdown.get("user_interaction"): + cvss_parts.append(f"UI:{cvss_breakdown['user_interaction']}") + if cvss_breakdown.get("scope"): + cvss_parts.append(f"S:{cvss_breakdown['scope']}") + if cvss_breakdown.get("confidentiality"): + cvss_parts.append(f"C:{cvss_breakdown['confidentiality']}") + if cvss_breakdown.get("integrity"): + cvss_parts.append(f"I:{cvss_breakdown['integrity']}") + if cvss_breakdown.get("availability"): + cvss_parts.append(f"A:{cvss_breakdown['availability']}") + if cvss_parts: + text.append("CVSS Vector: ", style=field_style) + text.append("/".join(cvss_parts), style="dim") + + description = report.get("description") + if description: + text.append("\n\n") + text.append("Description", style=field_style) + text.append("\n") + text.append(description) + + impact = report.get("impact") + if impact: + text.append("\n\n") + text.append("Impact", style=field_style) + text.append("\n") + text.append(impact) + + technical_analysis = report.get("technical_analysis") + if technical_analysis: + text.append("\n\n") + text.append("Technical Analysis", style=field_style) + text.append("\n") + text.append(technical_analysis) + + poc_description = report.get("poc_description") + if poc_description: + text.append("\n\n") + text.append("PoC Description", style=field_style) + text.append("\n") + text.append(poc_description) + + poc_script_code = report.get("poc_script_code") + if poc_script_code: + text.append("\n\n") + text.append("PoC Code", style=field_style) + text.append("\n") + text.append(poc_script_code, style="dim") + + code_file = report.get("code_file") + if code_file: + text.append("\n\n") + text.append("Code File: ", style=field_style) + text.append(code_file) + + code_before = report.get("code_before") + if code_before: + text.append("\n\n") + text.append("Code Before", style=field_style) + text.append("\n") + text.append(code_before, style="dim") + + code_after = report.get("code_after") + if code_after: + text.append("\n\n") + text.append("Code After", style=field_style) + text.append("\n") + text.append(code_after, style="dim") + + code_diff = report.get("code_diff") + if code_diff: + text.append("\n\n") + text.append("Code Diff", style=field_style) + text.append("\n") + text.append(code_diff, style="dim") + + remediation_steps = report.get("remediation_steps") + if remediation_steps: + text.append("\n\n") + text.append("Remediation", style=field_style) + text.append("\n") + text.append(remediation_steps) + + return text + + def _build_vulnerability_stats(stats_text: Text, tracer: Any) -> None: """Build vulnerability section of stats text.""" vuln_count = len(tracer.vulnerability_reports) diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index 26890ad..e848b9e 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -54,7 +54,7 @@ class Tracer: self._next_message_id = 1 self._saved_vuln_ids: set[str] = set() - self.vulnerability_found_callback: Callable[[str, str, str, str], None] | None = None + self.vulnerability_found_callback: Callable[[dict[str, Any]], None] | None = None def set_run_name(self, run_name: str) -> None: self.run_name = run_name @@ -138,9 +138,7 @@ class Tracer: logger.info(f"Added vulnerability report: {report_id} - {title}") if self.vulnerability_found_callback: - self.vulnerability_found_callback( - report_id, title.strip(), description or "", severity.lower().strip() - ) + self.vulnerability_found_callback(report) self.save_run_data() return report_id