# File listing metadata (viewer residue): 338 lines, 12 KiB, Python
import csv
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
from uuid import uuid4


if TYPE_CHECKING:
    from collections.abc import Callable


# Module-level logger shared by every function and method in this file.
logger = logging.getLogger(__name__)

# Process-wide tracer instance. Stays None until set_global_tracer() stores one.
_global_tracer: Optional["Tracer"] = None


def get_global_tracer() -> Optional["Tracer"]:
    """Return the module-wide tracer, or ``None`` if none has been set."""
    return _global_tracer


def set_global_tracer(tracer: "Tracer") -> None:
    """Store ``tracer`` as the module-wide tracer returned by ``get_global_tracer``.

    Any previously stored tracer is replaced.
    """
    global _global_tracer  # noqa: PLW0603
    _global_tracer = tracer


class Tracer:
|
|
def __init__(self, run_name: str | None = None):
|
|
self.run_name = run_name
|
|
self.run_id = run_name or f"run-{uuid4().hex[:8]}"
|
|
self.start_time = datetime.now(UTC).isoformat()
|
|
self.end_time: str | None = None
|
|
|
|
self.agents: dict[str, dict[str, Any]] = {}
|
|
self.tool_executions: dict[int, dict[str, Any]] = {}
|
|
self.chat_messages: list[dict[str, Any]] = []
|
|
|
|
self.vulnerability_reports: list[dict[str, Any]] = []
|
|
self.final_scan_result: str | None = None
|
|
|
|
self.scan_results: dict[str, Any] | None = None
|
|
self.scan_config: dict[str, Any] | None = None
|
|
self.run_metadata: dict[str, Any] = {
|
|
"run_id": self.run_id,
|
|
"run_name": self.run_name,
|
|
"start_time": self.start_time,
|
|
"end_time": None,
|
|
"targets": [],
|
|
"status": "running",
|
|
}
|
|
self._run_dir: Path | None = None
|
|
self._next_execution_id = 1
|
|
self._next_message_id = 1
|
|
self._saved_vuln_ids: set[str] = set()
|
|
|
|
self.vulnerability_found_callback: Callable[[str, str, str, str], None] | None = None
|
|
|
|
def set_run_name(self, run_name: str) -> None:
|
|
self.run_name = run_name
|
|
self.run_id = run_name
|
|
|
|
def get_run_dir(self) -> Path:
|
|
if self._run_dir is None:
|
|
runs_dir = Path.cwd() / "strix_runs"
|
|
runs_dir.mkdir(exist_ok=True)
|
|
|
|
run_dir_name = self.run_name if self.run_name else self.run_id
|
|
self._run_dir = runs_dir / run_dir_name
|
|
self._run_dir.mkdir(exist_ok=True)
|
|
|
|
return self._run_dir
|
|
|
|
def add_vulnerability_report(
|
|
self,
|
|
title: str,
|
|
content: str,
|
|
severity: str,
|
|
) -> str:
|
|
report_id = f"vuln-{len(self.vulnerability_reports) + 1:04d}"
|
|
|
|
report = {
|
|
"id": report_id,
|
|
"title": title.strip(),
|
|
"content": content.strip(),
|
|
"severity": severity.lower().strip(),
|
|
"timestamp": datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S UTC"),
|
|
}
|
|
|
|
self.vulnerability_reports.append(report)
|
|
logger.info(f"Added vulnerability report: {report_id} - {title}")
|
|
|
|
if self.vulnerability_found_callback:
|
|
self.vulnerability_found_callback(
|
|
report_id, title.strip(), content.strip(), severity.lower().strip()
|
|
)
|
|
|
|
self.save_run_data()
|
|
return report_id
|
|
|
|
def set_final_scan_result(
|
|
self,
|
|
content: str,
|
|
success: bool = True,
|
|
) -> None:
|
|
self.final_scan_result = content.strip()
|
|
|
|
self.scan_results = {
|
|
"scan_completed": True,
|
|
"content": content,
|
|
"success": success,
|
|
}
|
|
|
|
logger.info(f"Set final scan result: success={success}")
|
|
self.save_run_data(mark_complete=True)
|
|
|
|
def log_agent_creation(
|
|
self, agent_id: str, name: str, task: str, parent_id: str | None = None
|
|
) -> None:
|
|
agent_data: dict[str, Any] = {
|
|
"id": agent_id,
|
|
"name": name,
|
|
"task": task,
|
|
"status": "running",
|
|
"parent_id": parent_id,
|
|
"created_at": datetime.now(UTC).isoformat(),
|
|
"updated_at": datetime.now(UTC).isoformat(),
|
|
"tool_executions": [],
|
|
}
|
|
|
|
self.agents[agent_id] = agent_data
|
|
|
|
def log_chat_message(
|
|
self,
|
|
content: str,
|
|
role: str,
|
|
agent_id: str | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> int:
|
|
message_id = self._next_message_id
|
|
self._next_message_id += 1
|
|
|
|
message_data = {
|
|
"message_id": message_id,
|
|
"content": content,
|
|
"role": role,
|
|
"agent_id": agent_id,
|
|
"timestamp": datetime.now(UTC).isoformat(),
|
|
"metadata": metadata or {},
|
|
}
|
|
|
|
self.chat_messages.append(message_data)
|
|
return message_id
|
|
|
|
def log_tool_execution_start(self, agent_id: str, tool_name: str, args: dict[str, Any]) -> int:
|
|
execution_id = self._next_execution_id
|
|
self._next_execution_id += 1
|
|
|
|
now = datetime.now(UTC).isoformat()
|
|
execution_data = {
|
|
"execution_id": execution_id,
|
|
"agent_id": agent_id,
|
|
"tool_name": tool_name,
|
|
"args": args,
|
|
"status": "running",
|
|
"result": None,
|
|
"timestamp": now,
|
|
"started_at": now,
|
|
"completed_at": None,
|
|
}
|
|
|
|
self.tool_executions[execution_id] = execution_data
|
|
|
|
if agent_id in self.agents:
|
|
self.agents[agent_id]["tool_executions"].append(execution_id)
|
|
|
|
return execution_id
|
|
|
|
def update_tool_execution(
|
|
self, execution_id: int, status: str, result: Any | None = None
|
|
) -> None:
|
|
if execution_id in self.tool_executions:
|
|
self.tool_executions[execution_id]["status"] = status
|
|
self.tool_executions[execution_id]["result"] = result
|
|
self.tool_executions[execution_id]["completed_at"] = datetime.now(UTC).isoformat()
|
|
|
|
def update_agent_status(
|
|
self, agent_id: str, status: str, error_message: str | None = None
|
|
) -> None:
|
|
if agent_id in self.agents:
|
|
self.agents[agent_id]["status"] = status
|
|
self.agents[agent_id]["updated_at"] = datetime.now(UTC).isoformat()
|
|
if error_message:
|
|
self.agents[agent_id]["error_message"] = error_message
|
|
|
|
def set_scan_config(self, config: dict[str, Any]) -> None:
|
|
self.scan_config = config
|
|
self.run_metadata.update(
|
|
{
|
|
"targets": config.get("targets", []),
|
|
"user_instructions": config.get("user_instructions", ""),
|
|
"max_iterations": config.get("max_iterations", 200),
|
|
}
|
|
)
|
|
self.get_run_dir()
|
|
|
|
def save_run_data(self, mark_complete: bool = False) -> None:
|
|
try:
|
|
run_dir = self.get_run_dir()
|
|
if mark_complete:
|
|
self.end_time = datetime.now(UTC).isoformat()
|
|
|
|
if self.final_scan_result:
|
|
penetration_test_report_file = run_dir / "penetration_test_report.md"
|
|
with penetration_test_report_file.open("w", encoding="utf-8") as f:
|
|
f.write("# Security Penetration Test Report\n\n")
|
|
f.write(
|
|
f"**Generated:** {datetime.now(UTC).strftime('%Y-%m-%d %H:%M:%S UTC')}\n\n"
|
|
)
|
|
f.write(f"{self.final_scan_result}\n")
|
|
logger.info(
|
|
f"Saved final penetration test report to: {penetration_test_report_file}"
|
|
)
|
|
|
|
if self.vulnerability_reports:
|
|
vuln_dir = run_dir / "vulnerabilities"
|
|
vuln_dir.mkdir(exist_ok=True)
|
|
|
|
new_reports = [
|
|
report
|
|
for report in self.vulnerability_reports
|
|
if report["id"] not in self._saved_vuln_ids
|
|
]
|
|
|
|
for report in new_reports:
|
|
vuln_file = vuln_dir / f"{report['id']}.md"
|
|
with vuln_file.open("w", encoding="utf-8") as f:
|
|
f.write(f"# {report['title']}\n\n")
|
|
f.write(f"**ID:** {report['id']}\n")
|
|
f.write(f"**Severity:** {report['severity'].upper()}\n")
|
|
f.write(f"**Found:** {report['timestamp']}\n\n")
|
|
f.write("## Description\n\n")
|
|
f.write(f"{report['content']}\n")
|
|
self._saved_vuln_ids.add(report["id"])
|
|
|
|
if self.vulnerability_reports:
|
|
severity_order = {"critical": 0, "high": 1, "medium": 2, "low": 3, "info": 4}
|
|
sorted_reports = sorted(
|
|
self.vulnerability_reports,
|
|
key=lambda x: (severity_order.get(x["severity"], 5), x["timestamp"]),
|
|
)
|
|
|
|
vuln_csv_file = run_dir / "vulnerabilities.csv"
|
|
with vuln_csv_file.open("w", encoding="utf-8", newline="") as f:
|
|
import csv
|
|
|
|
fieldnames = ["id", "title", "severity", "timestamp", "file"]
|
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
|
writer.writeheader()
|
|
|
|
for report in sorted_reports:
|
|
writer.writerow(
|
|
{
|
|
"id": report["id"],
|
|
"title": report["title"],
|
|
"severity": report["severity"].upper(),
|
|
"timestamp": report["timestamp"],
|
|
"file": f"vulnerabilities/{report['id']}.md",
|
|
}
|
|
)
|
|
|
|
if new_reports:
|
|
logger.info(
|
|
f"Saved {len(new_reports)} new vulnerability report(s) to: {vuln_dir}"
|
|
)
|
|
logger.info(f"Updated vulnerability index: {vuln_csv_file}")
|
|
|
|
logger.info(f"📊 Essential scan data saved to: {run_dir}")
|
|
|
|
except (OSError, RuntimeError):
|
|
logger.exception("Failed to save scan data")
|
|
|
|
def _calculate_duration(self) -> float:
|
|
try:
|
|
start = datetime.fromisoformat(self.start_time.replace("Z", "+00:00"))
|
|
if self.end_time:
|
|
end = datetime.fromisoformat(self.end_time.replace("Z", "+00:00"))
|
|
return (end - start).total_seconds()
|
|
except (ValueError, TypeError):
|
|
pass
|
|
return 0.0
|
|
|
|
def get_agent_tools(self, agent_id: str) -> list[dict[str, Any]]:
|
|
return [
|
|
exec_data
|
|
for exec_data in self.tool_executions.values()
|
|
if exec_data.get("agent_id") == agent_id
|
|
]
|
|
|
|
def get_real_tool_count(self) -> int:
|
|
return sum(
|
|
1
|
|
for exec_data in self.tool_executions.values()
|
|
if exec_data.get("tool_name") not in ["scan_start_info", "subagent_start_info"]
|
|
)
|
|
|
|
def get_total_llm_stats(self) -> dict[str, Any]:
|
|
from strix.tools.agents_graph.agents_graph_actions import _agent_instances
|
|
|
|
total_stats = {
|
|
"input_tokens": 0,
|
|
"output_tokens": 0,
|
|
"cached_tokens": 0,
|
|
"cache_creation_tokens": 0,
|
|
"cost": 0.0,
|
|
"requests": 0,
|
|
"failed_requests": 0,
|
|
}
|
|
|
|
for agent_instance in _agent_instances.values():
|
|
if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"):
|
|
agent_stats = agent_instance.llm._total_stats
|
|
total_stats["input_tokens"] += agent_stats.input_tokens
|
|
total_stats["output_tokens"] += agent_stats.output_tokens
|
|
total_stats["cached_tokens"] += agent_stats.cached_tokens
|
|
total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
|
|
total_stats["cost"] += agent_stats.cost
|
|
total_stats["requests"] += agent_stats.requests
|
|
total_stats["failed_requests"] += agent_stats.failed_requests
|
|
|
|
total_stats["cost"] = round(total_stats["cost"], 4)
|
|
|
|
return {
|
|
"total": total_stats,
|
|
"total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
|
|
}
|
|
|
|
def cleanup(self) -> None:
|
|
self.save_run_data(mark_complete=True)
|