import csv
import logging
import re
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
from uuid import uuid4

from strix.telemetry import posthog


if TYPE_CHECKING:
    from collections.abc import Callable


logger = logging.getLogger(__name__)

_global_tracer: Optional["Tracer"] = None

# Matches every run of one or more backticks; used to size Markdown code fences.
_BACKTICK_RUN = re.compile(r"`+")


def get_global_tracer() -> Optional["Tracer"]:
    """Return the tracer registered with ``set_global_tracer``, or None if unset."""
    return _global_tracer


def set_global_tracer(tracer: "Tracer") -> None:
    """Register ``tracer`` as the module-wide tracer returned by ``get_global_tracer``."""
    global _global_tracer  # noqa: PLW0603
    _global_tracer = tracer


def _fenced_block(content: str, info: str = "") -> str:
    """Wrap ``content`` in a Markdown code fence followed by a blank line.

    The fence is always at least three backticks and is made one longer than the
    longest backtick run inside ``content``, so embedded ``` sequences (common in
    PoC scripts and diffs) cannot terminate the block early.

    Args:
        content: Text to place inside the fence.
        info: Optional info string (language tag) written after the opening fence.

    Returns:
        The fenced block, e.g. "```diff\\n<content>\\n```\\n\\n".
    """
    longest_run = max((len(run) for run in _BACKTICK_RUN.findall(content)), default=0)
    fence = "`" * max(3, longest_run + 1)
    return f"{fence}{info}\n{content}\n{fence}\n\n"


class Tracer:
    """Record agents, tool executions, chat messages and vulnerability reports for a run.

    Artifacts are written under ``<cwd>/strix_runs/<run name or id>/`` by
    ``save_run_data``: a final penetration test report, one markdown file per
    vulnerability and a severity-sorted ``vulnerabilities.csv`` index.
    """

    # Sort rank for the CSV index; unknown severities sort after "info".
    _SEVERITY_ORDER: dict[str, int] = {"critical": 0, "high": 1, "medium": 2, "low": 3, "info": 4}
    _UNKNOWN_SEVERITY_RANK = 5

    # Bookkeeping tool names excluded from ``get_real_tool_count``.
    _INTERNAL_TOOL_NAMES = frozenset({"scan_start_info", "subagent_start_info"})

    def __init__(self, run_name: str | None = None):
        self.run_name = run_name
        # Without an explicit name, generate a short random id such as "run-1a2b3c4d".
        self.run_id = run_name or f"run-{uuid4().hex[:8]}"
        self.start_time = datetime.now(UTC).isoformat()
        self.end_time: str | None = None

        self.agents: dict[str, dict[str, Any]] = {}
        self.tool_executions: dict[int, dict[str, Any]] = {}
        self.chat_messages: list[dict[str, Any]] = []
        self.streaming_content: dict[str, str] = {}
        self.interrupted_content: dict[str, str] = {}

        self.vulnerability_reports: list[dict[str, Any]] = []
        self.final_scan_result: str | None = None
        self.scan_results: dict[str, Any] | None = None
        self.scan_config: dict[str, Any] | None = None
        self.run_metadata: dict[str, Any] = {
            "run_id": self.run_id,
            "run_name": self.run_name,
            "start_time": self.start_time,
            "end_time": None,
            "targets": [],
            "status": "running",
        }

        self._run_dir: Path | None = None
        self._next_execution_id = 1
        self._next_message_id = 1
        # Ids of vulnerability reports whose markdown file has already been written.
        self._saved_vuln_ids: set[str] = set()

        self.vulnerability_found_callback: Callable[[dict[str, Any]], None] | None = None

    def set_run_name(self, run_name: str) -> None:
        """Rename the run; the name is also used as the run id.

        ``run_metadata`` is updated too so it never reports the old id/name.
        NOTE: a run directory that ``get_run_dir`` already created is not moved.
        """
        self.run_name = run_name
        self.run_id = run_name
        self.run_metadata["run_id"] = run_name
        self.run_metadata["run_name"] = run_name

    def get_run_dir(self) -> Path:
        """Return (creating on first call) ``<cwd>/strix_runs/<run_name or run_id>``."""
        if self._run_dir is None:
            runs_dir = Path.cwd() / "strix_runs"
            runs_dir.mkdir(exist_ok=True)

            self._run_dir = runs_dir / (self.run_name or self.run_id)
            self._run_dir.mkdir(exist_ok=True)

        return self._run_dir

    def add_vulnerability_report(
        self,
        title: str,
        severity: str,
        description: str | None = None,
        impact: str | None = None,
        target: str | None = None,
        technical_analysis: str | None = None,
        poc_description: str | None = None,
        poc_script_code: str | None = None,
        remediation_steps: str | None = None,
        cvss: float | None = None,
        cvss_breakdown: dict[str, str] | None = None,
        endpoint: str | None = None,
        method: str | None = None,
        cve: str | None = None,
        code_file: str | None = None,
        code_before: str | None = None,
        code_after: str | None = None,
        code_diff: str | None = None,
    ) -> str:
        """Record a vulnerability finding, notify listeners and persist run data.

        Optional fields are stored only when truthy (``cvss`` only when not None);
        string fields are whitespace-stripped. After appending, the report is sent
        to ``posthog.finding``, passed to ``vulnerability_found_callback`` if set,
        and ``save_run_data`` is called.

        Returns:
            The sequential report id, e.g. "vuln-0001".
        """
        report_id = f"vuln-{len(self.vulnerability_reports) + 1:04d}"

        report: dict[str, Any] = {
            "id": report_id,
            "title": title.strip(),
            "severity": severity.lower().strip(),
            "timestamp": datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S UTC"),
        }

        # Ordered to keep the report's key order stable.
        optional_fields: list[tuple[str, Any]] = [
            ("description", description),
            ("impact", impact),
            ("target", target),
            ("technical_analysis", technical_analysis),
            ("poc_description", poc_description),
            ("poc_script_code", poc_script_code),
            ("remediation_steps", remediation_steps),
            ("cvss", cvss),
            ("cvss_breakdown", cvss_breakdown),
            ("endpoint", endpoint),
            ("method", method),
            ("cve", cve),
            ("code_file", code_file),
            ("code_before", code_before),
            ("code_after", code_after),
            ("code_diff", code_diff),
        ]
        for key, value in optional_fields:
            if key == "cvss":
                # A score of 0.0 is meaningful, so only None means "absent".
                if value is not None:
                    report[key] = value
            elif value:
                report[key] = value.strip() if isinstance(value, str) else value

        self.vulnerability_reports.append(report)

        logger.info(f"Added vulnerability report: {report_id} - {title}")
        posthog.finding(severity)

        if self.vulnerability_found_callback:
            self.vulnerability_found_callback(report)

        self.save_run_data()
        return report_id

    def get_existing_vulnerabilities(self) -> list[dict[str, Any]]:
        """Return a shallow copy of the recorded vulnerability reports."""
        return list(self.vulnerability_reports)

    def update_scan_final_fields(
        self,
        executive_summary: str,
        methodology: str,
        technical_analysis: str,
        recommendations: str,
    ) -> None:
        """Store the final scan sections, write the completed run and report the end.

        Sets ``scan_results`` and the markdown ``final_scan_result``, then saves
        with ``mark_complete=True`` and calls ``posthog.end``.
        """
        summary = executive_summary.strip()
        method_text = methodology.strip()
        analysis = technical_analysis.strip()
        advice = recommendations.strip()

        self.scan_results = {
            "scan_completed": True,
            "executive_summary": summary,
            "methodology": method_text,
            "technical_analysis": analysis,
            "recommendations": advice,
            "success": True,
        }

        sections = (
            ("Executive Summary", summary),
            ("Methodology", method_text),
            ("Technical Analysis", analysis),
            ("Recommendations", advice),
        )
        self.final_scan_result = (
            "\n\n".join(f"# {heading}\n\n{body}" for heading, body in sections) + "\n"
        )

        logger.info("Updated scan final fields")
        self.save_run_data(mark_complete=True)
        posthog.end(self, exit_reason="finished_by_tool")

    def log_agent_creation(
        self, agent_id: str, name: str, task: str, parent_id: str | None = None
    ) -> None:
        """Register an agent in ``self.agents`` with status "running"."""
        # One timestamp so created_at and updated_at start out identical.
        now = datetime.now(UTC).isoformat()
        self.agents[agent_id] = {
            "id": agent_id,
            "name": name,
            "task": task,
            "status": "running",
            "parent_id": parent_id,
            "created_at": now,
            "updated_at": now,
            "tool_executions": [],
        }

    def log_chat_message(
        self,
        content: str,
        role: str,
        agent_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> int:
        """Append a chat message and return its sequential message id."""
        message_id = self._next_message_id
        self._next_message_id += 1

        self.chat_messages.append(
            {
                "message_id": message_id,
                "content": content,
                "role": role,
                "agent_id": agent_id,
                "timestamp": datetime.now(UTC).isoformat(),
                "metadata": metadata or {},
            }
        )
        return message_id

    def log_tool_execution_start(self, agent_id: str, tool_name: str, args: dict[str, Any]) -> int:
        """Record a tool execution in state "running" and return its execution id.

        The id is also appended to the owning agent's ``tool_executions`` list
        when that agent has been registered.
        """
        execution_id = self._next_execution_id
        self._next_execution_id += 1

        now = datetime.now(UTC).isoformat()
        self.tool_executions[execution_id] = {
            "execution_id": execution_id,
            "agent_id": agent_id,
            "tool_name": tool_name,
            "args": args,
            "status": "running",
            "result": None,
            "timestamp": now,
            "started_at": now,
            "completed_at": None,
        }

        if agent_id in self.agents:
            self.agents[agent_id]["tool_executions"].append(execution_id)

        return execution_id

    def update_tool_execution(
        self, execution_id: int, status: str, result: Any | None = None
    ) -> None:
        """Set status/result/completed_at on a known execution; unknown ids are ignored."""
        execution = self.tool_executions.get(execution_id)
        if execution is None:
            return
        execution["status"] = status
        execution["result"] = result
        execution["completed_at"] = datetime.now(UTC).isoformat()

    def update_agent_status(
        self, agent_id: str, status: str, error_message: str | None = None
    ) -> None:
        """Update a known agent's status (and error message if given); unknown ids are ignored."""
        agent = self.agents.get(agent_id)
        if agent is None:
            return
        agent["status"] = status
        agent["updated_at"] = datetime.now(UTC).isoformat()
        if error_message:
            agent["error_message"] = error_message

    def set_scan_config(self, config: dict[str, Any]) -> None:
        """Store the scan config, copy key fields into ``run_metadata`` and create the run dir."""
        self.scan_config = config
        self.run_metadata.update(
            {
                "targets": config.get("targets", []),
                "user_instructions": config.get("user_instructions", ""),
                "max_iterations": config.get("max_iterations", 200),
            }
        )
        self.get_run_dir()

    def save_run_data(self, mark_complete: bool = False) -> None:
        """Write the run's artifacts to ``get_run_dir()``.

        Writes ``penetration_test_report.md`` when a final result exists, and,
        when vulnerabilities exist, markdown files for reports not yet saved plus
        a rewritten ``vulnerabilities.csv`` index.

        Args:
            mark_complete: Stamp ``end_time`` (mirrored into ``run_metadata``)
                before writing.

        ``OSError``/``RuntimeError`` are logged rather than raised.
        """
        try:
            run_dir = self.get_run_dir()
            if mark_complete:
                self.end_time = datetime.now(UTC).isoformat()
                self.run_metadata["end_time"] = self.end_time

            if self.final_scan_result:
                self._write_final_report(run_dir)

            if self.vulnerability_reports:
                self._write_vulnerabilities(run_dir)

            logger.info(f"📊 Essential scan data saved to: {run_dir}")

        except (OSError, RuntimeError):
            logger.exception("Failed to save scan data")

    def _write_final_report(self, run_dir: Path) -> None:
        """Write ``final_scan_result`` to ``penetration_test_report.md`` under a header."""
        report_file = run_dir / "penetration_test_report.md"
        generated = datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S UTC")
        with report_file.open("w", encoding="utf-8") as f:
            f.write("# Security Penetration Test Report\n\n")
            f.write(f"**Generated:** {generated}\n\n")
            f.write(f"{self.final_scan_result}\n")
        logger.info(f"Saved final penetration test report to: {report_file}")

    def _write_vulnerabilities(self, run_dir: Path) -> None:
        """Write markdown for unsaved reports and rewrite the CSV index."""
        vuln_dir = run_dir / "vulnerabilities"
        vuln_dir.mkdir(exist_ok=True)

        new_reports = [
            report
            for report in self.vulnerability_reports
            if report["id"] not in self._saved_vuln_ids
        ]
        for report in new_reports:
            vuln_file = vuln_dir / f"{report['id']}.md"
            vuln_file.write_text(self._render_vulnerability_markdown(report), encoding="utf-8")
            self._saved_vuln_ids.add(report["id"])

        vuln_csv_file = self._write_vulnerability_index(run_dir)

        if new_reports:
            logger.info(f"Saved {len(new_reports)} new vulnerability report(s) to: {vuln_dir}")
        logger.info(f"Updated vulnerability index: {vuln_csv_file}")

    def _write_vulnerability_index(self, run_dir: Path) -> Path:
        """Rewrite ``vulnerabilities.csv`` sorted by severity rank, then timestamp."""
        sorted_reports = sorted(
            self.vulnerability_reports,
            key=lambda r: (
                self._SEVERITY_ORDER.get(r["severity"], self._UNKNOWN_SEVERITY_RANK),
                r["timestamp"],
            ),
        )

        vuln_csv_file = run_dir / "vulnerabilities.csv"
        with vuln_csv_file.open("w", encoding="utf-8", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=["id", "title", "severity", "timestamp", "file"])
            writer.writeheader()
            writer.writerows(
                {
                    "id": report["id"],
                    "title": report["title"],
                    "severity": report["severity"].upper(),
                    "timestamp": report["timestamp"],
                    "file": f"vulnerabilities/{report['id']}.md",
                }
                for report in sorted_reports
            )
        return vuln_csv_file

    @staticmethod
    def _render_vulnerability_markdown(report: dict[str, Any]) -> str:
        """Render one vulnerability report dict as a standalone markdown document."""
        parts: list[str] = [
            f"# {report.get('title', 'Untitled Vulnerability')}\n\n",
            f"**ID:** {report.get('id', 'unknown')}\n",
            f"**Severity:** {report.get('severity', 'unknown').upper()}\n",
            f"**Found:** {report.get('timestamp', 'unknown')}\n",
        ]

        metadata_fields: list[tuple[str, Any]] = [
            ("Target", report.get("target")),
            ("Endpoint", report.get("endpoint")),
            ("Method", report.get("method")),
            ("CVE", report.get("cve")),
            ("CVSS", report.get("cvss")),
        ]
        breakdown = report.get("cvss_breakdown")
        if isinstance(breakdown, dict) and breakdown:
            rendered = ", ".join(f"{key}: {value}" for key, value in breakdown.items())
            metadata_fields.append(("CVSS Breakdown", rendered))
        for label, value in metadata_fields:
            # Explicit None/"" check (not truthiness) so a CVSS of 0.0 is still shown.
            if value is not None and value != "":
                parts.append(f"**{label}:** {value}\n")

        description = report.get("description") or "No description provided."
        parts.append("\n## Description\n\n")
        parts.append(f"{description}\n\n")

        for heading, key in (("Impact", "impact"), ("Technical Analysis", "technical_analysis")):
            if report.get(key):
                parts.append(f"## {heading}\n\n")
                parts.append(f"{report[key]}\n\n")

        if report.get("poc_description") or report.get("poc_script_code"):
            parts.append("## Proof of Concept\n\n")
            if report.get("poc_description"):
                parts.append(f"{report['poc_description']}\n\n")
            if report.get("poc_script_code"):
                parts.append(_fenced_block(report["poc_script_code"]))

        code_keys = ("code_file", "code_before", "code_after", "code_diff")
        if any(report.get(key) for key in code_keys):
            parts.append("## Code Analysis\n\n")
            if report.get("code_file"):
                parts.append(f"**File:** {report['code_file']}\n\n")
            if report.get("code_before"):
                parts.append("**Before:**\n")
                parts.append(_fenced_block(report["code_before"]))
            if report.get("code_after"):
                parts.append("**After:**\n")
                parts.append(_fenced_block(report["code_after"]))
            if report.get("code_diff"):
                parts.append("**Changes:**\n")
                parts.append(_fenced_block(report["code_diff"], "diff"))

        if report.get("remediation_steps"):
            parts.append("## Remediation\n\n")
            parts.append(f"{report['remediation_steps']}\n\n")

        return "".join(parts)

    def _calculate_duration(self) -> float:
        """Return seconds between ``start_time`` and ``end_time``; 0.0 if unset or unparsable."""
        try:
            start = datetime.fromisoformat(self.start_time.replace("Z", "+00:00"))
            if self.end_time:
                end = datetime.fromisoformat(self.end_time.replace("Z", "+00:00"))
                return (end - start).total_seconds()
        except (ValueError, TypeError):
            pass
        return 0.0

    def get_agent_tools(self, agent_id: str) -> list[dict[str, Any]]:
        """Return the tool execution records belonging to ``agent_id``."""
        # Iterate a snapshot of the values, as the original code did.
        return [
            exec_data
            for exec_data in list(self.tool_executions.values())
            if exec_data.get("agent_id") == agent_id
        ]

    def get_real_tool_count(self) -> int:
        """Count tool executions, excluding the internal bookkeeping tool names."""
        return sum(
            1
            for exec_data in list(self.tool_executions.values())
            if exec_data.get("tool_name") not in self._INTERNAL_TOOL_NAMES
        )

    def get_total_llm_stats(self) -> dict[str, Any]:
        """Sum LLM usage over all agent instances that expose ``llm._total_stats``.

        Returns:
            ``{"total": {...counters, cost rounded to 4 dp}, "total_tokens": input + output}``.
        """
        from strix.tools.agents_graph.agents_graph_actions import _agent_instances

        total_stats: dict[str, Any] = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cached_tokens": 0,
            "cache_creation_tokens": 0,
            "cost": 0.0,
            "requests": 0,
            "failed_requests": 0,
        }

        for agent_instance in _agent_instances.values():
            if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"):
                agent_stats = agent_instance.llm._total_stats
                for counter in total_stats:
                    total_stats[counter] += getattr(agent_stats, counter)

        total_stats["cost"] = round(total_stats["cost"], 4)

        return {
            "total": total_stats,
            "total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
        }

    def update_streaming_content(self, agent_id: str, content: str) -> None:
        """Replace the in-progress streamed text for ``agent_id``."""
        self.streaming_content[agent_id] = content

    def clear_streaming_content(self, agent_id: str) -> None:
        """Drop any in-progress streamed text for ``agent_id``."""
        self.streaming_content.pop(agent_id, None)

    def get_streaming_content(self, agent_id: str) -> str | None:
        """Return the in-progress streamed text for ``agent_id``, if any."""
        return self.streaming_content.get(agent_id)

    def finalize_streaming_as_interrupted(self, agent_id: str) -> str | None:
        """Turn pending streamed text into an interrupted assistant message.

        If non-blank streamed text exists, it is stored in ``interrupted_content``,
        logged as a chat message with ``{"interrupted": True}`` metadata, and
        returned. Otherwise any previously stored interrupted text is popped
        and returned (or None).
        """
        content = self.streaming_content.pop(agent_id, None)
        if content and content.strip():
            self.interrupted_content[agent_id] = content
            self.log_chat_message(
                content=content,
                role="assistant",
                agent_id=agent_id,
                metadata={"interrupted": True},
            )
            return content
        return self.interrupted_content.pop(agent_id, None)

    def cleanup(self) -> None:
        """Persist run data and mark the run's end time."""
        self.save_run_data(mark_complete=True)