1460 lines
49 KiB
Python
1460 lines
49 KiB
Python
import ipaddress
|
|
import json
|
|
import os
|
|
import re
|
|
import secrets
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib.error import HTTPError, URLError
|
|
from urllib.parse import urlparse
|
|
from urllib.request import Request, urlopen
|
|
|
|
import docker
|
|
from docker.errors import DockerException, ImageNotFound
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.text import Text
|
|
|
|
|
|
# Token formatting utilities
|
|
def format_token_count(count: float) -> str:
    """Render a token count as a short human-readable string.

    The value is truncated to an integer first. Counts below one thousand are
    returned verbatim, thousands get a ``K`` suffix and millions an ``M``
    suffix, both with one decimal place (``1500 -> "1.5K"``).

    Args:
        count: Number of tokens; any fractional part is discarded.

    Returns:
        The formatted count.
    """
    count = int(count)
    if count >= 1_000_000:
        return f"{count / 1_000_000:.1f}M"
    if count >= 1_000:
        thousands = f"{count / 1_000:.1f}"
        # Values just below one million round up to "1000.0"; promote them to
        # the next unit so the display never reads "1000.0K".
        if thousands == "1000.0":
            return f"{count / 1_000_000:.1f}M"
        return f"{thousands}K"
    return str(count)
|
|
|
|
|
|
# Display utilities
|
|
def get_severity_color(severity: str) -> str:
    """Look up the hex display colour for a severity label.

    The lookup is an exact match, so callers pass the lowercase label. Any
    label that is not in the palette maps to a neutral grey.
    """
    fallback = "#6b7280"
    palette = {
        "critical": "#dc2626",
        "high": "#ea580c",
        "medium": "#d97706",
        "low": "#65a30d",
        "info": "#0284c7",
    }
    return palette[severity] if severity in palette else fallback
|
|
|
|
|
|
def get_cvss_color(cvss_score: float) -> str:
    """Pick the hex display colour for a CVSS score.

    Bands are checked from the highest floor downwards (9.0, 7.0, 4.0, 0.1);
    a score below every floor gets a neutral grey.
    """
    bands = (
        (9.0, "#dc2626"),
        (7.0, "#ea580c"),
        (4.0, "#d97706"),
        (0.1, "#65a30d"),
    )
    for floor, colour in bands:
        if cvss_score >= floor:
            return colour
    return "#6b7280"
|
|
|
|
|
|
# (abbreviation, report key) pairs for the "CVSS Vector" line, in display order.
_CVSS_VECTOR_COMPONENTS = (
    ("AV", "attack_vector"),
    ("AC", "attack_complexity"),
    ("PR", "privileges_required"),
    ("UI", "user_interaction"),
    ("S", "scope"),
    ("C", "confidentiality"),
    ("I", "integrity"),
    ("A", "availability"),
)


def _append_inline_field(text: Text, label: str, value: str, label_style: str) -> None:
    """Append a blank-line separator followed by ``<label>: <value>`` on one line."""
    text.append("\n\n")
    text.append(f"{label}: ", style=label_style)
    text.append(value)


def _append_titled_section(
    text: Text, heading: str, body: str, heading_style: str, body_style: str | None = None
) -> None:
    """Append a blank-line separator, a styled heading, and the body on the next line."""
    text.append("\n\n")
    text.append(heading, style=heading_style)
    text.append("\n")
    text.append(body, style=body_style)


def _append_code_location(text: Text, number: int, loc: dict[str, Any]) -> None:
    """Append one numbered code location with its optional label, snippet and fix."""
    text.append("\n\n")
    text.append(f" Location {number}: ", style="dim")
    text.append(loc.get("file", "unknown"), style="bold")

    start = loc.get("start_line")
    end = loc.get("end_line")
    if start is not None:
        # Only show a range when the end line is present and differs from the start.
        if end and end != start:
            text.append(f":{start}-{end}")
        else:
            text.append(f":{start}")

    if loc.get("label"):
        text.append(f"\n {loc['label']}", style="italic dim")

    if loc.get("snippet"):
        text.append("\n ")
        text.append(loc["snippet"], style="dim")

    if loc.get("fix_before") or loc.get("fix_after"):
        text.append("\n Fix:")
        if loc.get("fix_before"):
            text.append("\n - ", style="dim")
            text.append(loc["fix_before"], style="dim")
        if loc.get("fix_after"):
            text.append("\n + ", style="dim")
            text.append(loc["fix_after"], style="dim")


def format_vulnerability_report(report: dict[str, Any]) -> Text:
    """Format a vulnerability report for CLI display with all rich fields.

    Fields are emitted in a fixed order (title, severity, CVSS score, target,
    endpoint, method, CVE, CVSS vector, description, impact, technical
    analysis, PoC description, PoC code, code locations, remediation). Every
    field is optional: a missing or empty value is skipped together with its
    separator.

    Args:
        report: Report dictionary; only the keys read below are used.

    Returns:
        A ``Text`` object holding the styled report.
    """
    field_style = "bold #4ade80"
    text = Text()

    title = report.get("title", "")
    if title:
        text.append("Vulnerability Report", style="bold #ea580c")
        text.append("\n\n")
        text.append("Title: ", style=field_style)
        text.append(title)

    severity = report.get("severity", "")
    if severity:
        text.append("\n\n")
        text.append("Severity: ", style=field_style)
        severity_color = get_severity_color(severity.lower())
        text.append(severity.upper(), style=f"bold {severity_color}")

    cvss = report.get("cvss")
    if cvss is not None:
        text.append("\n\n")
        text.append("CVSS Score: ", style=field_style)
        cvss_color = get_cvss_color(cvss)
        text.append(f"{cvss:.1f}", style=f"bold {cvss_color}")

    inline_fields = (
        ("Target", "target"),
        ("Endpoint", "endpoint"),
        ("Method", "method"),
        ("CVE", "cve"),
    )
    for label, key in inline_fields:
        value = report.get(key)
        if value:
            _append_inline_field(text, label, value, field_style)

    cvss_breakdown = report.get("cvss_breakdown") or {}
    cvss_parts = [
        f"{abbreviation}:{cvss_breakdown[key]}"
        for abbreviation, key in _CVSS_VECTOR_COMPONENTS
        if cvss_breakdown.get(key)
    ]
    # The separator is written only when there is a vector to show, so a
    # breakdown without any recognised component leaves no dangling blank line.
    if cvss_parts:
        text.append("\n\n")
        text.append("CVSS Vector: ", style=field_style)
        text.append("/".join(cvss_parts), style="dim")

    sections = (
        ("Description", "description", None),
        ("Impact", "impact", None),
        ("Technical Analysis", "technical_analysis", None),
        ("PoC Description", "poc_description", None),
        ("PoC Code", "poc_script_code", "dim"),
    )
    for heading, key, body_style in sections:
        body = report.get(key)
        if body:
            _append_titled_section(text, heading, body, field_style, body_style)

    code_locations = report.get("code_locations")
    if code_locations:
        text.append("\n\n")
        text.append("Code Locations", style=field_style)
        for number, loc in enumerate(code_locations, start=1):
            _append_code_location(text, number, loc)

    remediation_steps = report.get("remediation_steps")
    if remediation_steps:
        _append_titled_section(text, "Remediation", remediation_steps, field_style)

    return text
|
|
|
|
|
|
def _build_vulnerability_stats(stats_text: Text, tracer: Any) -> None:
    """Append the vulnerability summary line to ``stats_text`` in place.

    With at least one report the line lists every severity that occurred
    (critical first) followed by the overall total; otherwise it states that
    nothing was detected. Reports whose severity is not one of the five known
    labels count towards the total only.
    """
    vuln_count = len(tracer.vulnerability_reports)

    if vuln_count <= 0:
        stats_text.append("Vulnerabilities ", style="bold #22c55e")
        stats_text.append("0", style="bold white")
        stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
        stats_text.append("\n")
        return

    levels = ("critical", "high", "medium", "low", "info")
    tally = dict.fromkeys(levels, 0)
    for item in tracer.vulnerability_reports:
        level = item.get("severity", "").lower()
        if level in tally:
            tally[level] += 1

    stats_text.append("Vulnerabilities ", style="bold red")

    needs_separator = False
    for level in levels:
        hits = tally[level]
        if hits <= 0:
            continue
        if needs_separator:
            stats_text.append(" | ", style="dim white")
        needs_separator = True

        colour = get_severity_color(level)
        badge = Text()
        badge.append(f"{level.upper()}: ", style=colour)
        badge.append(str(hits), style=f"bold {colour}")
        stats_text.append(badge)

    stats_text.append(" (Total: ", style="dim white")
    stats_text.append(str(vuln_count), style="bold yellow")
    stats_text.append(")", style="dim white")
    stats_text.append("\n")
|
|
|
|
|
|
def _build_llm_stats(stats_text: Text, total_stats: dict[str, Any]) -> None:
    """Append the LLM usage line (tokens and cost) to ``stats_text`` in place.

    When no requests were recorded a fixed zero-cost / zero-token line is
    written. Otherwise input and output token counts are always shown, while
    cached tokens and cost appear only when they are greater than zero.
    """
    has_requests = total_stats["requests"] > 0
    stats_text.append("\n")

    if not has_requests:
        stats_text.append("Cost ", style="dim")
        stats_text.append("$0.0000 ", style="#fbbf24")
        stats_text.append("· ", style="dim white")
        stats_text.append("Tokens ", style="dim")
        stats_text.append("0", style="white")
        return

    separator = " · "

    stats_text.append("Input Tokens ", style="dim")
    stats_text.append(format_token_count(total_stats["input_tokens"]), style="white")

    if total_stats["cached_tokens"] > 0:
        stats_text.append(separator, style="dim white")
        stats_text.append("Cached Tokens ", style="dim")
        stats_text.append(format_token_count(total_stats["cached_tokens"]), style="white")

    stats_text.append(separator, style="dim white")
    stats_text.append("Output Tokens ", style="dim")
    stats_text.append(format_token_count(total_stats["output_tokens"]), style="white")

    if total_stats["cost"] > 0:
        stats_text.append(separator, style="dim white")
        stats_text.append("Cost ", style="dim")
        stats_text.append(f"${total_stats['cost']:.4f}", style="bold #fbbf24")
|
|
|
|
|
|
def build_final_stats_text(tracer: Any) -> Text:
    """Build the end-of-run statistics text.

    The result holds the vulnerability summary, the agent and tool counts and
    the LLM usage line. A falsy ``tracer`` produces an empty ``Text``.
    """
    summary = Text()
    if not tracer:
        return summary

    _build_vulnerability_stats(summary, tracer)

    tool_count = tracer.get_real_tool_count()
    agent_count = len(tracer.agents)

    fragments = (
        ("Agents", "dim"),
        (" ", None),
        (str(agent_count), "bold white"),
        (" · ", "dim white"),
        ("Tools", "dim"),
        (" ", None),
        (str(tool_count), "bold white"),
    )
    for fragment, style in fragments:
        summary.append(fragment, style=style)

    _build_llm_stats(summary, tracer.get_total_llm_stats()["total"])
    return summary
|
|
|
|
|
|
def build_live_stats_text(tracer: Any, agent_config: dict[str, Any] | None = None) -> Text:
    """Build the statistics text shown while a run is in progress.

    Lines, in order: the model name (only when ``agent_config`` is given), the
    vulnerability count with a per-severity breakdown when there are reports,
    agent and tool counts, then token usage and cost. A falsy ``tracer``
    produces an empty ``Text``.
    """
    live = Text()
    if not tracer:
        return live

    if agent_config:
        llm_config = agent_config["llm_config"]
        model = getattr(llm_config, "model_name", "Unknown")
        live.append("Model ", style="dim")
        live.append(model, style="white")
        live.append("\n")

    vuln_count = len(tracer.vulnerability_reports)
    tool_count = tracer.get_real_tool_count()
    agent_count = len(tracer.agents)

    live.append("Vulnerabilities ", style="dim")
    live.append(f"{vuln_count}", style="white")
    live.append("\n")

    if vuln_count > 0:
        levels = ("critical", "high", "medium", "low", "info")
        tally = dict.fromkeys(levels, 0)
        for item in tracer.vulnerability_reports:
            level = item.get("severity", "").lower()
            if level in tally:
                tally[level] += 1

        badges = []
        for level in levels:
            if tally[level] > 0:
                colour = get_severity_color(level)
                badge = Text()
                badge.append(f"{level.upper()}: ", style=colour)
                badge.append(str(tally[level]), style=f"bold {colour}")
                badges.append(badge)

        last_position = len(badges) - 1
        for position, badge in enumerate(badges):
            live.append(badge)
            if position < last_position:
                live.append(" | ", style="dim white")

        # The breakdown line is terminated even when no known severity occurred.
        live.append("\n")

    for fragment, style in (
        ("Agents ", "dim"),
        (str(agent_count), "white"),
        (" · ", "dim white"),
        ("Tools ", "dim"),
        (str(tool_count), "white"),
    ):
        live.append(fragment, style=style)

    total_stats = tracer.get_total_llm_stats()["total"]

    live.append("\n")
    live.append("Input Tokens ", style="dim")
    live.append(format_token_count(total_stats["input_tokens"]), style="white")
    live.append(" · ", style="dim white")
    live.append("Cached Tokens ", style="dim")
    live.append(format_token_count(total_stats["cached_tokens"]), style="white")

    live.append("\n")
    live.append("Output Tokens ", style="dim")
    live.append(format_token_count(total_stats["output_tokens"]), style="white")
    live.append(" · ", style="dim white")
    live.append("Cost ", style="dim")
    live.append(f"${total_stats['cost']:.4f}", style="#fbbf24")

    return live
|
|
|
|
|
|
def build_tui_stats_text(tracer: Any, agent_config: dict[str, Any] | None = None) -> Text:
    """Build the compact statistics text for the TUI.

    Shows the model name (when ``agent_config`` is given), the combined input
    and output token count and the cost when each is above zero, and the
    tracer's ``caido_url`` when it has one. A falsy ``tracer`` produces an
    empty ``Text``.
    """
    compact = Text()
    if not tracer:
        return compact

    if agent_config:
        llm_config = agent_config["llm_config"]
        compact.append(getattr(llm_config, "model_name", "Unknown"), style="white")

    total_stats = tracer.get_total_llm_stats()["total"]

    total_tokens = total_stats["input_tokens"] + total_stats["output_tokens"]
    if total_tokens > 0:
        compact.append("\n")
        compact.append(f"{format_token_count(total_tokens)} tokens", style="white")

    if total_stats["cost"] > 0:
        compact.append(" · ", style="white")
        compact.append(f"${total_stats['cost']:.2f}", style="white")

    caido_url = getattr(tracer, "caido_url", None)
    if caido_url:
        compact.append("\n")
        compact.append("Caido: ", style="bold white")
        compact.append(caido_url, style="white")

    return compact
|
|
|
|
|
|
# Name generation utilities
|
|
|
|
|
|
def _slugify_for_run_name(text: str, max_length: int = 32) -> str:
    """Turn arbitrary text into a lowercase, dash-separated slug.

    Every run of characters other than ``a-z`` and ``0-9`` becomes a single
    dash, dashes at either end are dropped, and the result is cut to
    ``max_length`` characters (without a trailing dash). An empty result is
    replaced by ``"pentest"``.
    """
    slug = re.sub(r"[^a-z0-9]+", "-", text.lower().strip()).strip("-")
    if len(slug) > max_length:
        slug = slug[:max_length].rstrip("-")
    return slug if slug else "pentest"
|
|
|
|
|
|
def _derive_target_label_for_run_name(targets_info: list[dict[str, Any]] | None) -> str: # noqa: PLR0911
|
|
if not targets_info:
|
|
return "pentest"
|
|
|
|
first = targets_info[0]
|
|
target_type = first.get("type")
|
|
details = first.get("details", {}) or {}
|
|
original = first.get("original", "") or ""
|
|
|
|
if target_type == "web_application":
|
|
url = details.get("target_url", original)
|
|
try:
|
|
parsed = urlparse(url)
|
|
return str(parsed.netloc or parsed.path or url)
|
|
except Exception: # noqa: BLE001
|
|
return str(url)
|
|
|
|
if target_type == "repository":
|
|
repo = details.get("target_repo", original)
|
|
parsed = urlparse(repo)
|
|
path = parsed.path or repo
|
|
name = path.rstrip("/").split("/")[-1] or path
|
|
if name.endswith(".git"):
|
|
name = name[:-4]
|
|
return str(name)
|
|
|
|
if target_type == "local_code":
|
|
path_str = details.get("target_path", original)
|
|
try:
|
|
return str(Path(path_str).name or path_str)
|
|
except Exception: # noqa: BLE001
|
|
return str(path_str)
|
|
|
|
if target_type == "ip_address":
|
|
return str(details.get("target_ip", original) or original)
|
|
|
|
return str(original or "pentest")
|
|
|
|
|
|
def generate_run_name(targets_info: list[dict[str, Any]] | None = None) -> str:
    """Generate a run name of the form ``<target-slug>_<4 hex chars>``.

    The slug comes from the first target's label; the suffix is drawn from
    ``secrets.token_hex(2)``.
    """
    label = _derive_target_label_for_run_name(targets_info)
    return "_".join((_slugify_for_run_name(label), secrets.token_hex(2)))
|
|
|
|
|
|
# Target processing utilities
|
|
|
|
# Scope modes accepted by resolve_diff_scope_context; any other value is rejected with ValueError.
_SUPPORTED_SCOPE_MODES = {"auto", "diff", "full"}
|
|
# Default cap used by _truncate_file_list: the most paths listed per section before truncation.
_MAX_FILES_PER_SECTION = 120
|
|
|
|
|
|
@dataclass
|
|
class DiffEntry:
|
|
status: str
|
|
path: str
|
|
old_path: str | None = None
|
|
similarity: int | None = None
|
|
|
|
|
|
@dataclass
|
|
class RepoDiffScope:
|
|
source_path: str
|
|
workspace_subdir: str | None
|
|
base_ref: str
|
|
merge_base: str
|
|
added_files: list[str]
|
|
modified_files: list[str]
|
|
renamed_files: list[dict[str, Any]]
|
|
deleted_files: list[str]
|
|
analyzable_files: list[str]
|
|
truncated_sections: dict[str, bool] = field(default_factory=dict)
|
|
|
|
def to_metadata(self) -> dict[str, Any]:
|
|
return {
|
|
"source_path": self.source_path,
|
|
"workspace_subdir": self.workspace_subdir,
|
|
"base_ref": self.base_ref,
|
|
"merge_base": self.merge_base,
|
|
"added_files": self.added_files,
|
|
"modified_files": self.modified_files,
|
|
"renamed_files": self.renamed_files,
|
|
"deleted_files": self.deleted_files,
|
|
"analyzable_files": self.analyzable_files,
|
|
"added_files_count": len(self.added_files),
|
|
"modified_files_count": len(self.modified_files),
|
|
"renamed_files_count": len(self.renamed_files),
|
|
"deleted_files_count": len(self.deleted_files),
|
|
"analyzable_files_count": len(self.analyzable_files),
|
|
"truncated_sections": self.truncated_sections,
|
|
}
|
|
|
|
|
|
@dataclass
class DiffScopeResult:
    """Outcome of resolving the diff-scope for a run."""

    # Whether diff-scope is in effect.
    active: bool
    # Scope mode the result was resolved for.
    mode: str
    # Instruction text for the scoped review; empty by default.
    instruction_block: str = field(default="")
    # Extra details about the resolution; a fresh dict per instance by default.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
def _run_git_command(
    repo_path: Path, args: list[str], check: bool = True
) -> subprocess.CompletedProcess[str]:
    """Run ``git -C <repo_path> <args...>`` and capture its output as text.

    With ``check`` true, ``subprocess.run`` raises ``CalledProcessError`` on a
    non-zero exit status; otherwise the caller inspects ``returncode``.
    """
    return subprocess.run(  # noqa: S603
        ["git", "-C", f"{repo_path}", *args],  # noqa: S607
        check=check,
        text=True,
        capture_output=True,
    )
|
|
|
|
|
|
def _run_git_command_raw(
    repo_path: Path, args: list[str], check: bool = True
) -> subprocess.CompletedProcess[bytes]:
    """Run ``git -C <repo_path> <args...>`` and capture its output as raw bytes.

    Used where the output must not be decoded up front, such as NUL-delimited
    path listings. With ``check`` true, ``subprocess.run`` raises
    ``CalledProcessError`` on a non-zero exit status.
    """
    return subprocess.run(  # noqa: S603
        ["git", "-C", f"{repo_path}", *args],  # noqa: S607
        check=check,
        capture_output=True,
    )
|
|
|
|
|
|
def _is_ci_environment(env: dict[str, str]) -> bool:
    """Report whether any of the CI marker variables below has a non-empty value."""
    markers = (
        "CI",
        "GITHUB_ACTIONS",
        "GITLAB_CI",
        "JENKINS_URL",
        "BUILDKITE",
        "CIRCLECI",
    )
    for name in markers:
        if env.get(name):
            return True
    return False
|
|
|
|
|
|
def _is_pr_environment(env: dict[str, str]) -> bool:
    """Report whether any pull/merge-request variable below has a non-empty value."""
    markers = (
        "GITHUB_BASE_REF",
        "GITHUB_HEAD_REF",
        "CI_MERGE_REQUEST_TARGET_BRANCH_NAME",
        "GITLAB_MERGE_REQUEST_TARGET_BRANCH_NAME",
        "SYSTEM_PULLREQUEST_TARGETBRANCH",
    )
    for name in markers:
        if env.get(name):
            return True
    return False
|
|
|
|
|
|
def _is_git_repo(repo_path: Path) -> bool:
    """Report whether git considers ``repo_path`` to be inside a work tree."""
    probe = _run_git_command(repo_path, ["rev-parse", "--is-inside-work-tree"], check=False)
    if probe.returncode != 0:
        return False
    return probe.stdout.strip().lower() == "true"
|
|
|
|
|
|
def _is_repo_shallow(repo_path: Path) -> bool:
    """Report whether the repository at ``repo_path`` is a shallow clone.

    git is asked first via ``rev-parse --is-shallow-repository``. If it fails
    or its answer is not a plain ``true``/``false``, the check falls back to
    looking for a ``shallow`` file in the git directory, following a
    ``gitdir:`` pointer when ``.git`` is a file.
    """
    probe = _run_git_command(repo_path, ["rev-parse", "--is-shallow-repository"], check=False)
    if probe.returncode == 0:
        answer = probe.stdout.strip().lower()
        if answer == "true":
            return True
        if answer == "false":
            return False

    git_meta = repo_path / ".git"
    if git_meta.is_dir():
        return (git_meta / "shallow").exists()
    if not git_meta.is_file():
        return False

    try:
        pointer = git_meta.read_text(encoding="utf-8").strip()
    except OSError:
        return False
    if not pointer.startswith("gitdir:"):
        return False

    # The pointer target is resolved against the repository path.
    _, _, target = pointer.partition(":")
    git_dir = (repo_path / target.strip()).resolve()
    return (git_dir / "shallow").exists()
|
|
|
|
|
|
def _git_ref_exists(repo_path: Path, ref: str) -> bool:
    """Report whether ``git rev-parse --verify --quiet`` accepts ``ref``."""
    verify_args = ["rev-parse", "--verify", "--quiet", ref]
    outcome = _run_git_command(repo_path, verify_args, check=False)
    return outcome.returncode == 0
|
|
|
|
|
|
def _resolve_origin_head_ref(repo_path: Path) -> str | None:
    """Return the ref that ``refs/remotes/origin/HEAD`` points to, if any.

    ``None`` is returned when the symbolic ref cannot be read or is empty.
    """
    lookup = _run_git_command(
        repo_path, ["symbolic-ref", "--quiet", "refs/remotes/origin/HEAD"], check=False
    )
    if lookup.returncode == 0:
        target = lookup.stdout.strip()
        if target:
            return target
    return None
|
|
|
|
|
|
def _extract_branch_name(ref: str | None) -> str | None:
|
|
if not ref:
|
|
return None
|
|
value = ref.strip()
|
|
if not value:
|
|
return None
|
|
return value.split("/")[-1]
|
|
|
|
|
|
def _extract_github_base_sha(env: dict[str, str]) -> str | None:
|
|
event_path = env.get("GITHUB_EVENT_PATH", "").strip()
|
|
if not event_path:
|
|
return None
|
|
|
|
path = Path(event_path)
|
|
if not path.exists():
|
|
return None
|
|
|
|
try:
|
|
payload = json.loads(path.read_text(encoding="utf-8"))
|
|
except (json.JSONDecodeError, OSError):
|
|
return None
|
|
|
|
base_sha = payload.get("pull_request", {}).get("base", {}).get("sha")
|
|
if isinstance(base_sha, str) and base_sha.strip():
|
|
return base_sha.strip()
|
|
return None
|
|
|
|
|
|
def _resolve_default_branch_name(repo_path: Path, env: dict[str, str]) -> str | None:
    """Work out the name of the repository's default branch.

    Sources are tried in order: the ``GITHUB_BASE_REF`` variable, the branch
    that ``origin/HEAD`` points to, then ``main`` and ``master`` if the
    matching remote-tracking ref exists. ``None`` means none of them applied.
    """
    from_env = env.get("GITHUB_BASE_REF", "").strip()
    if from_env:
        return from_env

    origin_head = _resolve_origin_head_ref(repo_path)
    branch = _extract_branch_name(origin_head) if origin_head else None
    if branch:
        return branch

    for candidate in ("main", "master"):
        if _git_ref_exists(repo_path, f"refs/remotes/origin/{candidate}"):
            return candidate

    return None
|
|
|
|
|
|
def _resolve_base_ref(repo_path: Path, diff_base: str | None, env: dict[str, str]) -> str:
    """Choose the ref that the diff-scope is computed against.

    Candidates are tried in order and the first usable one wins:

    1. ``diff_base`` when it is not blank (returned without verification).
    2. ``refs/remotes/origin/<GITHUB_BASE_REF>`` when that ref exists.
    3. The pull request base SHA from the GitHub event file, when it exists.
    4. The ref that ``origin/HEAD`` points to, when it exists.
    5. ``refs/remotes/origin/main``, then ``refs/remotes/origin/master``.

    Raises:
        ValueError: If none of the candidates can be used.
    """
    explicit = diff_base.strip() if diff_base else ""
    if explicit:
        return explicit

    pr_branch = env.get("GITHUB_BASE_REF", "").strip()
    if pr_branch:
        pr_ref = f"refs/remotes/origin/{pr_branch}"
        if _git_ref_exists(repo_path, pr_ref):
            return pr_ref

    pr_sha = _extract_github_base_sha(env)
    if pr_sha and _git_ref_exists(repo_path, pr_sha):
        return pr_sha

    origin_head = _resolve_origin_head_ref(repo_path)
    if origin_head and _git_ref_exists(repo_path, origin_head):
        return origin_head

    for fallback in ("refs/remotes/origin/main", "refs/remotes/origin/master"):
        if _git_ref_exists(repo_path, fallback):
            return fallback

    raise ValueError(
        "Unable to resolve a base ref for diff-scope. Pass --diff-base explicitly "
        "(for example: --diff-base origin/main)."
    )
|
|
|
|
|
|
def _get_current_branch_name(repo_path: Path) -> str | None:
    """Return the short name of the checked-out branch.

    ``None`` is returned when git fails, prints nothing, or prints the literal
    ``HEAD`` instead of a branch name.
    """
    lookup = _run_git_command(repo_path, ["rev-parse", "--abbrev-ref", "HEAD"], check=False)
    if lookup.returncode != 0:
        return None
    name = lookup.stdout.strip()
    return None if name in ("", "HEAD") else name
|
|
|
|
|
|
def _parse_name_status_z(raw_output: bytes) -> list[DiffEntry]:
    """Parse the output of ``git diff --name-status -z`` into diff entries.

    The expected NUL-delimited layout is:

    * non-rename/copy: ``<status>\\0<path>\\0``
    * rename/copy: ``<statusN>\\0<old_path>\\0<new_path>\\0``

    A token in status position that contains a tab is treated as
    tab-delimited text instead (one record per line, fields separated by
    tabs), so unexpected non-``-z`` output is still parsed record by record.

    Args:
        raw_output: Raw bytes captured from git; decoded as UTF-8 with
            undecodable bytes replaced.

    Returns:
        The parsed entries in output order. A trailing status token that has
        no path is dropped.
    """
    if not raw_output:
        return []

    def split_status(raw: str) -> tuple[str, int | None]:
        """Split a status token such as ``R100`` into its letter and number."""
        score = raw[1:]
        return raw[:1], int(score) if score.isdigit() else None

    def parse_tab_delimited(record: str) -> list[DiffEntry]:
        """Parse tab-delimited ``<status>\\t<path>[\\t<new_path>]`` lines."""
        parsed: list[DiffEntry] = []
        for line in record.split("\n"):
            fields = line.split("\t")
            if len(fields) < 2:
                continue
            code, score = split_status(fields[0])
            if code in {"R", "C"} and len(fields) >= 3:
                parsed.append(
                    DiffEntry(status=code, path=fields[2], old_path=fields[1], similarity=score)
                )
            else:
                parsed.append(DiffEntry(status=code, path=fields[1], similarity=score))
        return parsed

    tokens = [
        chunk.decode("utf-8", errors="replace") for chunk in raw_output.split(b"\x00") if chunk
    ]
    entries: list[DiffEntry] = []
    total = len(tokens)
    index = 0

    while index < total:
        token = tokens[index]

        # Only status-position tokens are inspected here; tokens consumed as
        # paths are never searched for tabs.
        if "\t" in token:
            entries.extend(parse_tab_delimited(token))
            index += 1
            continue

        status_code, similarity = split_status(token)

        if status_code in {"R", "C"} and index + 2 < total:
            entries.append(
                DiffEntry(
                    status=status_code,
                    path=tokens[index + 2],
                    old_path=tokens[index + 1],
                    similarity=similarity,
                )
            )
            index += 3
            continue

        if index + 1 < total:
            entries.append(
                DiffEntry(status=status_code, path=tokens[index + 1], similarity=similarity)
            )
            index += 2
            continue

        # A final status token with no path after it cannot form an entry.
        break

    return entries
|
|
|
|
|
|
def _append_unique(container: list[str], seen: set[str], path: str) -> None:
    """Append ``path`` to ``container`` unless it is empty or already in ``seen``.

    Both ``container`` and ``seen`` are updated in place.
    """
    if not path or path in seen:
        return
    seen.add(path)
    container.append(path)
|
|
|
|
|
|
def _classify_diff_entries(entries: list[DiffEntry]) -> dict[str, Any]:
    """Group diff entries into added, modified, deleted, renamed and analyzable lists.

    Entries without a path are ignored. Deleted paths are recorded only under
    ``deleted_files``; every other path is analyzable. A rename is also listed
    as modified unless its similarity is exactly 100. Copies and any status
    not handled explicitly are treated as modifications. The modified and
    analyzable lists contain no duplicates.

    Returns:
        A dict with the keys ``added_files``, ``modified_files``,
        ``deleted_files``, ``renamed_files`` and ``analyzable_files``.
    """
    added: list[str] = []
    modified: list[str] = []
    deleted: list[str] = []
    renamed: list[dict[str, Any]] = []
    analyzable: list[str] = []
    modified_seen: set[str] = set()
    analyzable_seen: set[str] = set()

    for entry in entries:
        path = entry.path
        if not path:
            continue
        status = entry.status

        if status == "D":
            deleted.append(path)
        elif status == "A":
            added.append(path)
            _append_unique(analyzable, analyzable_seen, path)
        elif status == "R":
            renamed.append(
                {
                    "old_path": entry.old_path,
                    "new_path": path,
                    "similarity": entry.similarity,
                }
            )
            _append_unique(analyzable, analyzable_seen, path)
            # A 100% similar rename has no content change to review as a modification.
            if entry.similarity is None or entry.similarity < 100:
                _append_unique(modified, modified_seen, path)
        else:
            # "M", "C" and every other status are handled as modifications.
            _append_unique(modified, modified_seen, path)
            _append_unique(analyzable, analyzable_seen, path)

    return {
        "added_files": added,
        "modified_files": modified,
        "deleted_files": deleted,
        "renamed_files": renamed,
        "analyzable_files": analyzable,
    }
|
|
|
|
|
|
def _truncate_file_list(
    files: list[str], max_files: int = _MAX_FILES_PER_SECTION
) -> tuple[list[str], bool]:
    """Cap ``files`` at ``max_files`` entries.

    Returns:
        ``(files, False)`` with the original list object when it already fits,
        otherwise ``(first max_files entries, True)``.
    """
    was_truncated = len(files) > max_files
    return (files[:max_files], True) if was_truncated else (files, False)
|
|
|
|
|
|
def _append_capped_file_section(lines: list[str], heading: str, files: list[str]) -> bool:
    """Append ``heading`` and a capped bullet list of ``files`` to ``lines``.

    Nothing is appended when there is nothing to list. Returns whether the
    list had to be truncated.
    """
    shown, truncated = _truncate_file_list(files)
    if shown:
        lines.append(heading)
        lines.extend(f"- {path}" for path in shown)
        if truncated:
            lines.append(f"- ... ({len(files) - len(shown)} more files)")
    return truncated


def build_diff_scope_instruction(scopes: list[RepoDiffScope]) -> str:
    """Build the instruction text that tells the reviewer which files changed.

    After a fixed preamble, each scope contributes its repository name, base
    reference and merge base, followed by the analyzable, added, modified,
    renamed and deleted file sections. Every section, including renames, is
    capped at ``_MAX_FILES_PER_SECTION`` entries with a "more files" line for
    the remainder.

    Side effect: for each scope, ``truncated_sections`` is updated with one
    boolean per section saying whether it was cut short.

    Args:
        scopes: Per-repository diff scopes to describe.

    Returns:
        The instruction text, with lines joined by newlines.
    """
    lines = [
        "The user is requesting a review of a Pull Request.",
        "Instruction: Direct your analysis primarily at the changes in the listed files. "
        "You may reference other files in the repository for context (imports, definitions, "
        "usage), but report findings only if they relate to the listed changes.",
        "For Added files, review the entire file content.",
        "For Modified files, focus primarily on the changed areas.",
    ]

    for scope in scopes:
        repo_name = scope.workspace_subdir or Path(scope.source_path).name or "repository"
        lines.append("")
        lines.append(f"Repository Scope: {repo_name}")
        lines.append(f"Base reference: {scope.base_ref}")
        lines.append(f"Merge base: {scope.merge_base}")

        # The primary-focus section is the only one with a "nothing found" line.
        focus_files, focus_truncated = _truncate_file_list(scope.analyzable_files)
        scope.truncated_sections["analyzable_files"] = focus_truncated
        if focus_files:
            lines.append("Primary Focus (changed files to analyze):")
            lines.extend(f"- {path}" for path in focus_files)
            if focus_truncated:
                lines.append(f"- ... ({len(scope.analyzable_files) - len(focus_files)} more files)")
        else:
            lines.append("Primary Focus: No analyzable changed files detected.")

        scope.truncated_sections["added_files"] = _append_capped_file_section(
            lines, "Added files (review entire file):", scope.added_files
        )
        scope.truncated_sections["modified_files"] = _append_capped_file_section(
            lines, "Modified files (focus on changes):", scope.modified_files
        )

        # Renames are dicts rather than plain paths, so they are capped here.
        shown_renames = scope.renamed_files[:_MAX_FILES_PER_SECTION]
        hidden_renames = len(scope.renamed_files) - len(shown_renames)
        scope.truncated_sections["renamed_files"] = hidden_renames > 0
        if shown_renames:
            lines.append("Renamed files:")
            for rename in shown_renames:
                old_path = rename.get("old_path") or "unknown"
                new_path = rename.get("new_path") or "unknown"
                similarity = rename.get("similarity")
                if isinstance(similarity, int):
                    lines.append(f"- {old_path} -> {new_path} (similarity {similarity}%)")
                else:
                    lines.append(f"- {old_path} -> {new_path}")
            if hidden_renames > 0:
                lines.append(f"- ... ({hidden_renames} more files)")

        scope.truncated_sections["deleted_files"] = _append_capped_file_section(
            lines,
            "Note: These files were deleted (context only, not analyzable):",
            scope.deleted_files,
        )

    return "\n".join(lines).strip()
|
|
|
|
|
|
def _should_activate_auto_scope(
    local_sources: list[dict[str, str]], non_interactive: bool, env: dict[str, str]
) -> bool:
    """Decide whether ``auto`` scope mode should switch diff-scope on.

    Diff-scope is only considered for non-interactive runs in a CI environment
    that have at least one local source. It is then enabled when the
    environment looks like a pull/merge request, or when some git source is
    checked out on a branch other than its default branch.
    """
    if not local_sources or not non_interactive:
        return False
    if not _is_ci_environment(env):
        return False
    if _is_pr_environment(env):
        return True

    for source in local_sources:
        source_path = source.get("source_path")
        if source_path:
            repo_path = Path(source_path)
            if _is_git_repo(repo_path):
                current_branch = _get_current_branch_name(repo_path)
                default_branch = _resolve_default_branch_name(repo_path, env)
                # Both names must be known before a mismatch counts.
                if current_branch and default_branch and current_branch != default_branch:
                    return True

    return False
|
|
|
|
|
|
def _resolve_repo_diff_scope(
    source: dict[str, str], diff_base: str | None, env: dict[str, str]
) -> RepoDiffScope:
    """Compute the changed-file scope of one local git repository.

    The base ref is resolved, its merge base with ``HEAD`` is computed, and
    the name/status diff from that merge base to ``HEAD`` is parsed and
    classified.

    Args:
        source: Source descriptor; ``source_path`` and ``workspace_subdir`` are read.
        diff_base: Explicit base ref, if the caller supplied one.
        env: Environment mapping used to resolve the base ref.

    Raises:
        ValueError: If the source is not a git repository, the repository is
            shallow, no base ref can be resolved, or git cannot compute the
            merge base or the diff.
    """
    source_path = source.get("source_path", "")
    workspace_subdir = source.get("workspace_subdir")
    repo_path = Path(source_path)

    if not _is_git_repo(repo_path):
        raise ValueError(f"Source is not a git repository: {source_path}")

    if _is_repo_shallow(repo_path):
        raise ValueError(
            "Strix requires full git history for diff-scope. Please set fetch-depth: 0 "
            "in your CI config."
        )

    base_ref = _resolve_base_ref(repo_path, diff_base, env)

    merge_base_prefix = f"Unable to compute merge-base against '{base_ref}' for '{source_path}'. "
    merge_base_hint = "Ensure the base branch history is fetched and reachable."

    merge_base_run = _run_git_command(repo_path, ["merge-base", base_ref, "HEAD"], check=False)
    if merge_base_run.returncode != 0:
        # Prefer git's own explanation when it printed one.
        raise ValueError(merge_base_prefix + (merge_base_run.stderr.strip() or merge_base_hint))

    merge_base = merge_base_run.stdout.strip()
    if not merge_base:
        raise ValueError(merge_base_prefix + merge_base_hint)

    diff_args = [
        "diff",
        "--name-status",
        "-z",
        "--find-renames",
        "--find-copies",
        f"{merge_base}...HEAD",
    ]
    diff_run = _run_git_command_raw(repo_path, diff_args, check=False)
    if diff_run.returncode != 0:
        detail = diff_run.stderr.decode("utf-8", errors="replace").strip()
        history_hint = "Ensure the repository has enough history for diff-scope."
        raise ValueError(
            f"Unable to resolve changed files for '{source_path}'. {detail or history_hint}"
        )

    classified = _classify_diff_entries(_parse_name_status_z(diff_run.stdout))

    return RepoDiffScope(
        source_path=source_path,
        workspace_subdir=workspace_subdir,
        base_ref=base_ref,
        merge_base=merge_base,
        added_files=classified["added_files"],
        modified_files=classified["modified_files"],
        renamed_files=classified["renamed_files"],
        deleted_files=classified["deleted_files"],
        analyzable_files=classified["analyzable_files"],
    )
|
|
|
|
|
|
def resolve_diff_scope_context(
    local_sources: list[dict[str, str]],
    scope_mode: str,
    diff_base: str | None,
    non_interactive: bool,
    env: dict[str, str] | None = None,
) -> DiffScopeResult:
    """Decide whether diff-scope applies and collect per-repository diff data.

    Args:
        local_sources: Entries read through their ``source_path`` key; entries
            with a missing or empty ``source_path`` are ignored.
        scope_mode: Must be a member of ``_SUPPORTED_SCOPE_MODES``. ``"full"``
            always yields an inactive result. ``"auto"`` yields an inactive
            result unless ``_should_activate_auto_scope`` returns true, and
            records per-repository failures instead of raising them.
        diff_base: Passed through to ``_resolve_repo_diff_scope``.
        non_interactive: Passed through to ``_should_activate_auto_scope``.
        env: Environment mapping; a copy of ``os.environ`` is used when ``None``.

    Returns:
        An inactive ``DiffScopeResult`` (metadata only) or an active one that
        also carries the instruction block built from the resolved repositories.

    Raises:
        ValueError: For an unsupported mode, an empty ``local_sources`` while
            diff-scope is active, a repository that cannot be resolved outside
            ``"auto"`` mode, or when no repository could be resolved outside
            ``"auto"`` mode.
    """
    if scope_mode not in _SUPPORTED_SCOPE_MODES:
        raise ValueError(f"Unsupported scope mode: {scope_mode}")

    env_map = dict(os.environ if env is None else env)
    auto_mode = scope_mode == "auto"

    # "full" never diffs; "auto" only diffs when the activation check says so.
    if scope_mode == "full" or (
        auto_mode and not _should_activate_auto_scope(local_sources, non_interactive, env_map)
    ):
        return DiffScopeResult(
            active=False,
            mode=scope_mode,
            metadata={"active": False, "mode": scope_mode},
        )

    if not local_sources:
        raise ValueError("Diff-scope is active, but no local repository targets were provided.")

    resolved: list[RepoDiffScope] = []
    non_git_paths: list[str] = []
    failed_paths: list[str] = []
    for source in local_sources:
        source_path = source.get("source_path")
        if not source_path:
            continue
        if not _is_git_repo(Path(source_path)):
            non_git_paths.append(source_path)
            continue
        try:
            repo_scope = _resolve_repo_diff_scope(source, diff_base, env_map)
        except ValueError as exc:
            # Only "auto" mode downgrades a per-repository failure to a skip.
            if not auto_mode:
                raise
            failed_paths.append(f"{source_path} (diff-scope skipped: {exc})")
        else:
            resolved.append(repo_scope)

    # Skip bookkeeping is reported identically by the active and inactive results.
    skipped_meta: dict[str, Any] = {}
    if non_git_paths:
        skipped_meta["skipped_non_git_sources"] = non_git_paths
    if failed_paths:
        skipped_meta["skipped_diff_scope_sources"] = failed_paths

    if not resolved:
        if not auto_mode:
            raise ValueError(
                "Diff-scope is active, but no Git repositories were found. "
                "Use --scope-mode full to disable diff-scope for this run."
            )
        return DiffScopeResult(
            active=False,
            mode=scope_mode,
            metadata={"active": False, "mode": scope_mode, **skipped_meta},
        )

    instruction_block = build_diff_scope_instruction(resolved)
    active_metadata: dict[str, Any] = {
        "active": True,
        "mode": scope_mode,
        "repos": [scope.to_metadata() for scope in resolved],
        "total_repositories": len(resolved),
        "total_analyzable_files": sum(len(scope.analyzable_files) for scope in resolved),
        "total_deleted_files": sum(len(scope.deleted_files) for scope in resolved),
        **skipped_meta,
    }

    return DiffScopeResult(
        active=True,
        mode=scope_mode,
        instruction_block=instruction_block,
        metadata=active_metadata,
    )
|
|
|
|
|
|
def _is_http_git_repo(url: str) -> bool:
|
|
check_url = f"{url.rstrip('/')}/info/refs?service=git-upload-pack"
|
|
try:
|
|
req = Request(check_url, headers={"User-Agent": "git/strix"}) # noqa: S310
|
|
with urlopen(req, timeout=10) as resp: # noqa: S310 # nosec B310
|
|
return "x-git-upload-pack-advertisement" in resp.headers.get("Content-Type", "")
|
|
except HTTPError as e:
|
|
return e.code == 401
|
|
except (URLError, OSError, ValueError):
|
|
return False
|
|
|
|
|
|
def infer_target_type(target: str) -> tuple[str, dict[str, str]]:  # noqa: PLR0911, PLR0912
    """Classify a user-supplied target string.

    Args:
        target: Raw target text; surrounding whitespace is ignored.

    Returns:
        A ``(type, details)`` pair where ``type`` is one of ``"repository"``
        (details key ``target_repo``), ``"web_application"`` (``target_url``),
        ``"ip_address"`` (``target_ip``) or ``"local_code"`` (``target_path``,
        resolved to an absolute path).

    Raises:
        ValueError: If *target* is not a string, is empty or whitespace-only,
            names an existing path that is not a directory, cannot be
            inspected as a path, or matches none of the supported forms.
    """
    if not isinstance(target, str):
        raise ValueError("Target must be a non-empty string")

    # Strip before the emptiness check: a whitespace-only target would
    # otherwise become Path(""), i.e. the current directory, and be accepted
    # as local code.
    target = target.strip()
    if not target:
        raise ValueError("Target must be a non-empty string")

    if target.startswith("git@"):
        return "repository", {"target_repo": target}

    if target.startswith("git://"):
        return "repository", {"target_repo": target}

    parsed = urlparse(target)
    if parsed.scheme in ("http", "https"):
        if parsed.username or parsed.password:
            return "repository", {"target_repo": target}
        if parsed.path.rstrip("/").endswith(".git"):
            return "repository", {"target_repo": target}
        if parsed.query or parsed.fragment:
            return "web_application", {"target_url": target}
        path_segments = [s for s in parsed.path.split("/") if s]
        # Only probe the network when the URL has at least two path segments.
        if len(path_segments) >= 2 and _is_http_git_repo(target):
            return "repository", {"target_repo": target}
        return "web_application", {"target_url": target}

    try:
        ip_obj = ipaddress.ip_address(target)
    except ValueError:
        pass
    else:
        return "ip_address", {"target_ip": str(ip_obj)}

    try:
        # expanduser() raises RuntimeError for an unknown "~user"; keeping it
        # inside the try reports that as ValueError like other path failures.
        path = Path(target).expanduser()
        if path.exists():
            if path.is_dir():
                return "local_code", {"target_path": str(path.resolve())}
            raise ValueError(f"Path exists but is not a directory: {target}")
    except (OSError, RuntimeError) as e:
        raise ValueError(f"Invalid path: {target} - {e!s}") from e

    if target.endswith(".git"):
        return "repository", {"target_repo": target}

    if "/" in target:
        host_part, _, path_part = target.partition("/")
        if "." in host_part and not host_part.startswith(".") and path_part:
            full_url = f"https://{target}"
            if _is_http_git_repo(full_url):
                return "repository", {"target_repo": full_url}
            return "web_application", {"target_url": full_url}

    if "." in target and "/" not in target and not target.startswith("."):
        parts = target.split(".")
        if len(parts) >= 2 and all(p and p.strip() for p in parts):
            return "web_application", {"target_url": f"https://{target}"}

    raise ValueError(
        f"Invalid target: {target}\n"
        "Target must be one of:\n"
        "- A valid URL (http:// or https://)\n"
        "- A Git repository URL (https://host/org/repo or git@host:org/repo.git)\n"
        "- A local directory path\n"
        "- A domain name (e.g., example.com)\n"
        "- An IP address (e.g., 192.168.1.10)"
    )
|
|
|
|
|
|
def sanitize_name(name: str) -> str:
    """Return *name* trimmed, with every character outside ``A-Za-z0-9._-`` replaced by ``-``.

    Falls back to ``"target"`` when nothing is left after trimming.
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]", "-", name.strip())
    if not cleaned:
        return "target"
    return cleaned
|
|
|
|
|
|
def derive_repo_base_name(repo_url: str) -> str:
    """Derive a sanitized directory name from the last path segment of *repo_url*.

    One trailing ``/`` and a trailing ``.git`` are dropped. For ``git@`` URLs
    the path is the text after the first ``:``; otherwise it is the URL path,
    or the whole string when the parsed path is empty. An empty segment falls
    back to ``"repository"`` before being passed to ``sanitize_name``.
    """
    url = repo_url[:-1] if repo_url.endswith("/") else repo_url

    if url.startswith("git@") and ":" in url:
        _, _, repo_path = url.partition(":")
    else:
        repo_path = urlparse(url).path or url

    leaf = repo_path.rsplit("/", 1)[-1].removesuffix(".git")
    return sanitize_name(leaf or "repository")
|
|
|
|
|
|
def derive_local_base_name(path_str: str) -> str:
    """Derive a sanitized directory name from the final component of *path_str*.

    The resolved path is preferred; if resolving raises ``OSError`` or
    ``RuntimeError`` the unresolved path's name is used. An empty name falls
    back to ``"workspace"`` before being passed to ``sanitize_name``.
    """
    raw_path = Path(path_str)
    try:
        leaf = raw_path.resolve().name
    except (OSError, RuntimeError):
        leaf = raw_path.name
    return sanitize_name(leaf if leaf else "workspace")
|
|
|
|
|
|
def assign_workspace_subdirs(targets_info: list[dict[str, Any]]) -> None:
    """Give every repository / local-code target a unique ``workspace_subdir``.

    The name is derived from the target (``derive_repo_base_name`` for
    ``"repository"``, ``derive_local_base_name`` for ``"local_code"``) and
    written into the target's ``details`` dict in place. The first use of a
    base name keeps it as-is; repeats get a ``-<n>`` suffix. Targets of any
    other type are left untouched.

    Args:
        targets_info: Target dicts with ``"type"`` and ``"details"`` keys.
    """
    name_counts: dict[str, int] = {}
    # Names already handed out. Needed because a generated "foo-2" could
    # otherwise clash with a later target whose own base name is "foo-2".
    used_names: set[str] = set()

    for target in targets_info:
        target_type = target["type"]
        details = target["details"]

        if target_type == "repository":
            base_name = derive_repo_base_name(details["target_repo"])
        elif target_type == "local_code":
            base_name = derive_local_base_name(details.get("target_path", "local"))
        else:
            continue

        count = name_counts.get(base_name, 0) + 1
        workspace_subdir = base_name if count == 1 else f"{base_name}-{count}"
        # Advance the counter until the candidate is unused.
        while workspace_subdir in used_names:
            count += 1
            workspace_subdir = f"{base_name}-{count}"

        name_counts[base_name] = count
        used_names.add(workspace_subdir)
        details["workspace_subdir"] = workspace_subdir
|
|
|
|
|
|
def collect_local_sources(targets_info: list[dict[str, Any]]) -> list[dict[str, str]]:
    """List the on-disk source locations among *targets_info*.

    A ``"local_code"`` target contributes its ``target_path``; a
    ``"repository"`` target contributes its ``cloned_repo_path``. Targets
    lacking the relevant key, and targets of any other type, are skipped.
    Each result entry has ``source_path`` and ``workspace_subdir`` keys; the
    latter is ``None`` when the target's details do not define it.
    """
    collected: list[dict[str, str]] = []

    for info in targets_info:
        details = info["details"]
        subdir = details.get("workspace_subdir")

        kind = info["type"]
        if kind == "local_code":
            path_key = "target_path"
        elif kind == "repository":
            path_key = "cloned_repo_path"
        else:
            continue

        if path_key in details:
            collected.append({"source_path": details[path_key], "workspace_subdir": subdir})

    return collected
|
|
|
|
|
|
def _is_localhost_host(host: str) -> bool:
|
|
host_lower = host.lower().strip("[]")
|
|
|
|
if host_lower in ("localhost", "0.0.0.0", "::1"): # nosec B104
|
|
return True
|
|
|
|
try:
|
|
ip = ipaddress.ip_address(host_lower)
|
|
if isinstance(ip, ipaddress.IPv4Address):
|
|
return ip.is_loopback # 127.0.0.0/8
|
|
if isinstance(ip, ipaddress.IPv6Address):
|
|
return ip.is_loopback # ::1
|
|
except ValueError:
|
|
pass
|
|
|
|
return False
|
|
|
|
|
|
def rewrite_localhost_targets(targets_info: list[dict[str, Any]], host_gateway: str) -> None:
    """Point localhost targets at *host_gateway*, editing the details in place.

    For ``"web_application"`` targets the host of ``target_url`` is swapped
    when ``_is_localhost_host`` accepts it; URLs that fail to parse are left
    alone. For ``"ip_address"`` targets a local ``target_ip`` is replaced by
    *host_gateway*. Other target types are ignored.
    """
    from yarl import URL  # type: ignore[import-not-found]

    for info in targets_info:
        kind = info.get("type")
        details = info.get("details", {})

        if kind == "ip_address":
            ip_value = details.get("target_ip", "")
            if ip_value and _is_localhost_host(ip_value):
                details["target_ip"] = host_gateway
            continue

        if kind != "web_application":
            continue

        raw_url = details.get("target_url", "")
        try:
            parsed_url = URL(raw_url)
        except (ValueError, TypeError):
            continue

        if parsed_url.host and _is_localhost_host(parsed_url.host):
            details["target_url"] = str(parsed_url.with_host(host_gateway))
|
|
|
|
|
|
# Repository utilities
|
|
def clone_repository(repo_url: str, run_name: str, dest_name: str | None = None) -> str:
    """Clone *repo_url* into a per-run temporary directory and return its path.

    The clone is placed at ``<system temp dir>/strix_repos/<run_name>/<name>``,
    where ``<name>`` is *dest_name* when given and otherwise the last component
    of *repo_url* (minus a trailing ``.git``). Anything already at that
    location is deleted before cloning.

    Args:
        repo_url: Repository location handed to ``git clone``.
        run_name: Name of the per-run directory under ``strix_repos``.
        dest_name: Optional directory name for the clone.

    Returns:
        Absolute path of the cloned working tree.

    Raises:
        FileNotFoundError: If no ``git`` executable is found on ``PATH``.
        SystemExit: With status 1, after printing an error panel, when
            ``git clone`` fails or the executable cannot be started.
    """
    console = Console()

    git_executable = shutil.which("git")
    if git_executable is None:
        raise FileNotFoundError("Git executable not found in PATH")

    temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name
    temp_dir.mkdir(parents=True, exist_ok=True)

    if dest_name:
        repo_name = dest_name
    else:
        repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name

    clone_path = temp_dir / repo_name

    if clone_path.exists():
        shutil.rmtree(clone_path)

    try:
        with console.status(f"[bold cyan]Cloning repository {repo_url}...", spinner="dots"):
            subprocess.run(  # noqa: S603
                [
                    git_executable,
                    "clone",
                    # End of options: a repo_url starting with "-" must be
                    # treated as a repository, never as a git option.
                    "--",
                    repo_url,
                    str(clone_path),
                ],
                capture_output=True,
                text=True,
                check=True,
            )

        return str(clone_path.absolute())

    except subprocess.CalledProcessError as e:
        error_text = Text()
        error_text.append("REPOSITORY CLONE FAILED", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append(f"Could not clone repository: {repo_url}\n", style="white")
        error_text.append(f"Error: {e.stderr or str(e)}", style="dim red")
        _print_clone_error_panel(console, error_text)
        sys.exit(1)
    except FileNotFoundError:
        error_text = Text()
        error_text.append("GIT NOT FOUND", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append("Git is not installed or not available in PATH.\n", style="white")
        error_text.append("Please install Git to clone repositories.\n", style="white")
        _print_clone_error_panel(console, error_text)
        sys.exit(1)


def _print_clone_error_panel(console: Any, error_text: Any) -> None:
    """Print *error_text* inside the red STRIX panel, padded by blank lines."""
    panel = Panel(
        error_text,
        title="[bold white]STRIX",
        title_align="left",
        border_style="red",
        padding=(1, 2),
    )
    console.print("\n")
    console.print(panel)
    console.print()
|
|
|
|
|
|
# Docker utilities
|
|
def check_docker_connection() -> Any:
    """Return a Docker client built from the environment.

    When ``docker.from_env()`` raises ``DockerException`` an explanatory panel
    is printed and ``RuntimeError("Docker not available")`` is raised with the
    original exception context suppressed.
    """
    try:
        client = docker.from_env()
    except DockerException:
        message = Text()
        for fragment, fragment_style in (
            ("DOCKER NOT AVAILABLE", "bold red"),
            ("\n\n", "white"),
            ("Cannot connect to Docker daemon.\n", "white"),
            (
                "Please ensure Docker Desktop is installed and running, and try running strix again.\n",
                "white",
            ),
        ):
            message.append(fragment, style=fragment_style)

        error_panel = Panel(
            message,
            title="[bold white]STRIX",
            title_align="left",
            border_style="red",
            padding=(1, 2),
        )
        Console().print("\n", error_panel, "\n")
        raise RuntimeError("Docker not available") from None
    return client
|
|
|
|
|
|
def image_exists(client: Any, image_name: str) -> bool:
    """Return ``True`` if ``client.images.get(image_name)`` succeeds.

    ``ImageNotFound`` is translated to ``False``; any other exception
    propagates to the caller.
    """
    try:
        client.images.get(image_name)
    except ImageNotFound:
        return False
    return True
|
|
|
|
|
|
def update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None:
    """Store a one-character progress marker for *layer_id* in *layers_info*.

    The marker is chosen from the first matching phrase in *layer_status*:
    "Pull complete"/"Already exists" -> "✓", "Downloading" -> "↓",
    "Extracting" -> "📦", "Waiting" -> "⏳"; anything else -> "•".
    """
    # Ordered: earlier rows win when a status contains several phrases.
    marker_rules = (
        (("Pull complete", "Already exists"), "✓"),
        (("Downloading",), "↓"),
        (("Extracting",), "📦"),
        (("Waiting",), "⏳"),
    )
    marker = next(
        (
            glyph
            for phrases, glyph in marker_rules
            if any(phrase in layer_status for phrase in phrases)
        ),
        "•",
    )
    layers_info[layer_id] = marker
|
|
|
|
|
|
def process_pull_line(
    line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str
) -> str:
    """Apply one image-pull progress event to the status display.

    Events carrying both ``id`` and ``status`` update that layer's marker via
    ``update_layer_status`` and refresh the "N/M layers complete" message when
    it changed. Events with only ``status`` switch the display to a phase
    message for "Pulling from", "Digest:" or "Status:" lines.

    Returns:
        The new progress message when one was displayed, otherwise
        *last_update* unchanged.
    """
    if "status" not in line:
        return last_update

    if "id" in line:
        update_layer_status(layers_info, line["id"], line["status"])

        total = len(layers_info)
        finished = sum(marker == "✓" for marker in layers_info.values())

        if total > 0:
            progress_msg = f"[bold cyan]Progress: {finished}/{total} layers complete"
            if progress_msg != last_update:
                status.update(progress_msg)
                return progress_msg
        return last_update

    headline = line["status"]
    if "Pulling from" in headline:
        status.update("[bold cyan]Fetching image manifest...")
    elif "Digest:" in headline:
        status.update("[bold cyan]Verifying image...")
    elif "Status:" in headline:
        status.update("[bold cyan]Finalizing...")

    return last_update
|
|
|
|
|
|
# LLM utilities
|
|
def validate_llm_response(response: Any) -> None:
    """Raise ``RuntimeError`` unless *response* carries non-empty content.

    The response must be truthy, have a non-empty ``choices`` sequence, and its
    first choice's ``message.content`` must be truthy.
    """
    has_content = (
        bool(response)
        and bool(response.choices)
        and bool(response.choices[0].message.content)
    )
    if not has_content:
        raise RuntimeError("Invalid response from LLM")
|
|
|
|
|
|
def validate_config_file(config_path: str) -> Path:
    """Check that *config_path* is a usable JSON config file and return it as a ``Path``.

    The file must exist, have a ``.json`` suffix, be readable as UTF-8, parse
    as JSON, and contain a top-level object with an ``env`` object inside it.

    Args:
        config_path: Location of the config file.

    Returns:
        ``Path(config_path)`` when every check passes.

    Raises:
        SystemExit: With status 1, after printing an error message, when any
            check fails.
    """
    console = Console()
    path = Path(config_path)

    if not path.exists():
        console.print(f"[bold red]Error:[/] Config file not found: {config_path}")
        sys.exit(1)

    if path.suffix != ".json":
        console.print("[bold red]Error:[/] Config file must be a .json file")
        sys.exit(1)

    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        console.print(f"[bold red]Error:[/] Invalid JSON in config file: {e}")
        sys.exit(1)
    except (OSError, UnicodeDecodeError) as e:
        # The path may exist yet be unreadable (a directory, no permission) or
        # not be UTF-8 text; report it like the other failures instead of
        # letting a traceback escape.
        console.print(f"[bold red]Error:[/] Unable to read config file: {e}")
        sys.exit(1)

    if not isinstance(data, dict):
        console.print("[bold red]Error:[/] Config file must contain a JSON object")
        sys.exit(1)

    # A missing "env" key yields None from .get(), which also fails the dict check.
    if not isinstance(data.get("env"), dict):
        console.print("[bold red]Error:[/] Config file must have an 'env' object")
        sys.exit(1)

    return path
|