feat: implement multi-target scanning
This commit is contained in:
@@ -19,55 +19,64 @@ class StrixAgent(BaseAgent):
|
||||
super().__init__(config)
|
||||
|
||||
async def execute_scan(self, scan_config: dict[str, Any]) -> dict[str, Any]:
|
||||
scan_type = scan_config.get("scan_type", "general")
|
||||
target = scan_config.get("target", {})
|
||||
user_instructions = scan_config.get("user_instructions", "")
|
||||
targets = scan_config.get("targets", [])
|
||||
|
||||
repositories = []
|
||||
local_code = []
|
||||
urls = []
|
||||
|
||||
for target in targets:
|
||||
target_type = target["type"]
|
||||
details = target["details"]
|
||||
workspace_subdir = details.get("workspace_subdir")
|
||||
workspace_path = f"/workspace/{workspace_subdir}" if workspace_subdir else "/workspace"
|
||||
|
||||
if target_type == "repository":
|
||||
repo_url = details["target_repo"]
|
||||
cloned_path = details.get("cloned_repo_path")
|
||||
repositories.append(
|
||||
{
|
||||
"url": repo_url,
|
||||
"workspace_path": workspace_path if cloned_path else None,
|
||||
}
|
||||
)
|
||||
|
||||
elif target_type == "local_code":
|
||||
original_path = details.get("target_path", "unknown")
|
||||
local_code.append(
|
||||
{
|
||||
"path": original_path,
|
||||
"workspace_path": workspace_path,
|
||||
}
|
||||
)
|
||||
|
||||
elif target_type == "web_application":
|
||||
urls.append(details["target_url"])
|
||||
|
||||
task_parts = []
|
||||
|
||||
if scan_type == "repository":
|
||||
repo_url = target["target_repo"]
|
||||
cloned_path = target.get("cloned_repo_path")
|
||||
if repositories:
|
||||
task_parts.append("\n\nRepositories:")
|
||||
for repo in repositories:
|
||||
if repo["workspace_path"]:
|
||||
task_parts.append(f"- {repo['url']} (available at: {repo['workspace_path']})")
|
||||
else:
|
||||
task_parts.append(f"- {repo['url']}")
|
||||
|
||||
if cloned_path:
|
||||
workspace_path = "/workspace"
|
||||
task_parts.append(
|
||||
f"Perform a security assessment of the Git repository: {repo_url}. "
|
||||
f"The repository has been cloned from '{repo_url}' to '{cloned_path}' "
|
||||
f"(host path) and then copied to '{workspace_path}' in your environment."
|
||||
f"Analyze the codebase at: {workspace_path}"
|
||||
)
|
||||
else:
|
||||
task_parts.append(
|
||||
f"Perform a security assessment of the Git repository: {repo_url}"
|
||||
)
|
||||
|
||||
elif scan_type == "web_application":
|
||||
task_parts.append(
|
||||
f"Perform a security assessment of the web application: {target['target_url']}"
|
||||
if local_code:
|
||||
task_parts.append("\n\nLocal Codebases:")
|
||||
task_parts.extend(
|
||||
f"- {code['path']} (available at: {code['workspace_path']})" for code in local_code
|
||||
)
|
||||
|
||||
elif scan_type == "local_code":
|
||||
original_path = target.get("target_path", "unknown")
|
||||
workspace_path = "/workspace"
|
||||
task_parts.append(
|
||||
f"Perform a security assessment of the local codebase. "
|
||||
f"The code from '{original_path}' (user host path) has been copied to "
|
||||
f"'{workspace_path}' in your environment. "
|
||||
f"Analyze the codebase at: {workspace_path}"
|
||||
)
|
||||
|
||||
else:
|
||||
task_parts.append(
|
||||
f"Perform a general security assessment of: {next(iter(target.values()))}"
|
||||
)
|
||||
if urls:
|
||||
task_parts.append("\n\nURLs:")
|
||||
task_parts.extend(f"- {url}" for url in urls)
|
||||
|
||||
task_description = " ".join(task_parts)
|
||||
|
||||
if user_instructions:
|
||||
task_description += (
|
||||
f"\n\nSpecial instructions from the system that must be followed: "
|
||||
f"{user_instructions}"
|
||||
)
|
||||
task_description += f"\n\nSpecial instructions: {user_instructions}"
|
||||
|
||||
return await self.agent_loop(task=task_description)
|
||||
|
||||
@@ -54,7 +54,7 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
self.config = config
|
||||
|
||||
self.local_source_path = config.get("local_source_path")
|
||||
self.local_sources = config.get("local_sources", [])
|
||||
self.non_interactive = config.get("non_interactive", False)
|
||||
|
||||
if "max_iterations" in config:
|
||||
@@ -317,7 +317,7 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
|
||||
runtime = get_runtime()
|
||||
sandbox_info = await runtime.create_sandbox(
|
||||
self.state.agent_id, self.state.sandbox_token, self.local_source_path
|
||||
self.state.agent_id, self.state.sandbox_token, self.local_sources
|
||||
)
|
||||
self.state.sandbox_id = sandbox_info["workspace_id"]
|
||||
self.state.sandbox_token = sandbox_info["auth_token"]
|
||||
|
||||
@@ -11,6 +11,8 @@ from strix.agents.StrixAgent import StrixAgent
|
||||
from strix.llm.config import LLMConfig
|
||||
from strix.telemetry.tracer import Tracer, set_global_tracer
|
||||
|
||||
from .utils import get_severity_color
|
||||
|
||||
|
||||
async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||
console = Console()
|
||||
@@ -19,15 +21,18 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||
start_text.append("🦉 ", style="bold white")
|
||||
start_text.append("STRIX CYBERSECURITY AGENT", style="bold green")
|
||||
|
||||
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
|
||||
target_text = Text()
|
||||
target_text.append("🎯 Target: ", style="bold cyan")
|
||||
target_text.append(str(target_value), style="bold white")
|
||||
|
||||
instructions_text = Text()
|
||||
if args.instruction:
|
||||
instructions_text.append("📋 Instructions: ", style="bold cyan")
|
||||
instructions_text.append(args.instruction, style="white")
|
||||
if len(args.targets_info) == 1:
|
||||
target_text.append("🎯 Target: ", style="bold cyan")
|
||||
target_text.append(args.targets_info[0]["original"], style="bold white")
|
||||
else:
|
||||
target_text.append("🎯 Targets: ", style="bold cyan")
|
||||
target_text.append(f"{len(args.targets_info)} targets\n", style="bold white")
|
||||
for i, target_info in enumerate(args.targets_info):
|
||||
target_text.append(" • ", style="dim white")
|
||||
target_text.append(target_info["original"], style="white")
|
||||
if i < len(args.targets_info) - 1:
|
||||
target_text.append("\n")
|
||||
|
||||
results_text = Text()
|
||||
results_text.append("📊 Results will be saved to: ", style="bold cyan")
|
||||
@@ -44,8 +49,6 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||
start_text,
|
||||
"\n\n",
|
||||
target_text,
|
||||
"\n" if args.instruction else "",
|
||||
instructions_text if args.instruction else "",
|
||||
"\n",
|
||||
results_text,
|
||||
note_text,
|
||||
@@ -62,8 +65,7 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||
|
||||
scan_config = {
|
||||
"scan_id": args.run_name,
|
||||
"scan_type": args.target_type,
|
||||
"target": args.target_dict,
|
||||
"targets": args.targets_info,
|
||||
"user_instructions": args.instruction or "",
|
||||
"run_name": args.run_name,
|
||||
}
|
||||
@@ -75,23 +77,14 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||
"non_interactive": True,
|
||||
}
|
||||
|
||||
if args.target_type == "local_code" and "target_path" in args.target_dict:
|
||||
agent_config["local_source_path"] = args.target_dict["target_path"]
|
||||
elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict:
|
||||
agent_config["local_source_path"] = args.target_dict["cloned_repo_path"]
|
||||
if getattr(args, "local_sources", None):
|
||||
agent_config["local_sources"] = args.local_sources
|
||||
|
||||
tracer = Tracer(args.run_name)
|
||||
tracer.set_scan_config(scan_config)
|
||||
|
||||
def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None:
|
||||
severity_colors = {
|
||||
"critical": "#dc2626",
|
||||
"high": "#ea580c",
|
||||
"medium": "#d97706",
|
||||
"low": "#65a30d",
|
||||
"info": "#0284c7",
|
||||
}
|
||||
severity_color = severity_colors.get(severity.lower(), "#6b7280")
|
||||
severity_color = get_severity_color(severity.lower())
|
||||
|
||||
vuln_text = Text()
|
||||
vuln_text.append("🐞 ", style="bold red")
|
||||
|
||||
@@ -7,16 +7,10 @@ import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import docker
|
||||
import litellm
|
||||
from docker.errors import DockerException
|
||||
from rich.console import Console
|
||||
@@ -25,6 +19,19 @@ from rich.text import Text
|
||||
|
||||
from strix.interface.cli import run_cli
|
||||
from strix.interface.tui import run_tui
|
||||
from strix.interface.utils import (
|
||||
assign_workspace_subdirs,
|
||||
build_llm_stats_text,
|
||||
build_stats_text,
|
||||
check_docker_connection,
|
||||
clone_repository,
|
||||
collect_local_sources,
|
||||
generate_run_name,
|
||||
image_exists,
|
||||
infer_target_type,
|
||||
process_pull_line,
|
||||
validate_llm_response,
|
||||
)
|
||||
from strix.runtime.docker_runtime import STRIX_IMAGE
|
||||
from strix.telemetry.tracer import get_global_tracer
|
||||
|
||||
@@ -32,15 +39,6 @@ from strix.telemetry.tracer import get_global_tracer
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
|
||||
def format_token_count(count: float) -> str:
|
||||
count = int(count)
|
||||
if count >= 1_000_000:
|
||||
return f"{count / 1_000_000:.1f}M"
|
||||
if count >= 1_000:
|
||||
return f"{count / 1_000:.1f}K"
|
||||
return str(count)
|
||||
|
||||
|
||||
def validate_environment() -> None: # noqa: PLR0912, PLR0915
|
||||
console = Console()
|
||||
missing_required_vars = []
|
||||
@@ -163,11 +161,6 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _validate_llm_response(response: Any) -> None:
|
||||
if not response or not response.choices or not response.choices[0].message.content:
|
||||
raise RuntimeError("Invalid response from LLM")
|
||||
|
||||
|
||||
def check_docker_installed() -> None:
|
||||
if shutil.which("docker") is None:
|
||||
console = Console()
|
||||
@@ -220,7 +213,7 @@ async def warm_up_llm() -> None:
|
||||
messages=test_messages,
|
||||
)
|
||||
|
||||
_validate_llm_response(response)
|
||||
validate_llm_response(response)
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
error_text = Text()
|
||||
@@ -245,141 +238,6 @@ async def warm_up_llm() -> None:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def generate_run_name() -> str:
|
||||
# fmt: off
|
||||
adjectives = [
|
||||
"stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent",
|
||||
"rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary",
|
||||
"encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous"
|
||||
]
|
||||
nouns = [
|
||||
"exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan",
|
||||
"worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler",
|
||||
"scanner", "sniffer", "honeypot", "firewall", "breach"
|
||||
]
|
||||
# fmt: on
|
||||
adj = secrets.choice(adjectives)
|
||||
noun = secrets.choice(nouns)
|
||||
number = secrets.randbelow(900) + 100
|
||||
return f"{adj}-{noun}-{number}"
|
||||
|
||||
|
||||
def clone_repository(repo_url: str, run_name: str) -> str:
|
||||
console = Console()
|
||||
|
||||
git_executable = shutil.which("git")
|
||||
if git_executable is None:
|
||||
raise FileNotFoundError("Git executable not found in PATH")
|
||||
|
||||
temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name
|
||||
|
||||
clone_path = temp_dir / repo_name
|
||||
|
||||
if clone_path.exists():
|
||||
shutil.rmtree(clone_path)
|
||||
|
||||
try:
|
||||
with console.status(f"[bold cyan]Cloning repository {repo_name}...", spinner="dots"):
|
||||
subprocess.run( # noqa: S603
|
||||
[
|
||||
git_executable,
|
||||
"clone",
|
||||
repo_url,
|
||||
str(clone_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
return str(clone_path.absolute())
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_text = Text()
|
||||
error_text.append("❌ ", style="bold red")
|
||||
error_text.append("REPOSITORY CLONE FAILED", style="bold red")
|
||||
error_text.append("\n\n", style="white")
|
||||
error_text.append(f"Could not clone repository: {repo_url}\n", style="white")
|
||||
error_text.append(
|
||||
f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red"
|
||||
)
|
||||
|
||||
panel = Panel(
|
||||
error_text,
|
||||
title="[bold red]🛡️ STRIX CLONE ERROR",
|
||||
title_align="center",
|
||||
border_style="red",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print("\n")
|
||||
console.print(panel)
|
||||
console.print()
|
||||
sys.exit(1)
|
||||
except FileNotFoundError:
|
||||
error_text = Text()
|
||||
error_text.append("❌ ", style="bold red")
|
||||
error_text.append("GIT NOT FOUND", style="bold red")
|
||||
error_text.append("\n\n", style="white")
|
||||
error_text.append("Git is not installed or not available in PATH.\n", style="white")
|
||||
error_text.append("Please install Git to clone repositories.\n", style="white")
|
||||
|
||||
panel = Panel(
|
||||
error_text,
|
||||
title="[bold red]🛡️ STRIX CLONE ERROR",
|
||||
title_align="center",
|
||||
border_style="red",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print("\n")
|
||||
console.print(panel)
|
||||
console.print()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
|
||||
if not target or not isinstance(target, str):
|
||||
raise ValueError("Target must be a non-empty string")
|
||||
|
||||
target = target.strip()
|
||||
|
||||
parsed = urlparse(target)
|
||||
if parsed.scheme in ("http", "https"):
|
||||
if any(
|
||||
host in parsed.netloc.lower() for host in ["github.com", "gitlab.com", "bitbucket.org"]
|
||||
):
|
||||
return "repository", {"target_repo": target}
|
||||
return "web_application", {"target_url": target}
|
||||
|
||||
path = Path(target)
|
||||
try:
|
||||
if path.exists():
|
||||
if path.is_dir():
|
||||
return "local_code", {"target_path": str(path.absolute())}
|
||||
raise ValueError(f"Path exists but is not a directory: {target}")
|
||||
except (OSError, RuntimeError) as e:
|
||||
raise ValueError(f"Invalid path: {target} - {e!s}") from e
|
||||
|
||||
if target.startswith("git@") or target.endswith(".git"):
|
||||
return "repository", {"target_repo": target}
|
||||
|
||||
if "." in target and "/" not in target and not target.startswith("."):
|
||||
parts = target.split(".")
|
||||
if len(parts) >= 2 and all(p and p.strip() for p in parts):
|
||||
return "web_application", {"target_url": f"https://{target}"}
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid target: {target}\n"
|
||||
"Target must be one of:\n"
|
||||
"- A valid URL (http:// or https://)\n"
|
||||
"- A Git repository URL (https://github.com/... or git@github.com:...)\n"
|
||||
"- A local directory path\n"
|
||||
"- A domain name (e.g., example.com)"
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Strix Multi-Agent Cybersecurity Penetration Testing Tool",
|
||||
@@ -399,16 +257,23 @@ Examples:
|
||||
# Domain penetration test
|
||||
strix --target example.com
|
||||
|
||||
# Multiple targets (e.g., white-box testing with source and deployed app)
|
||||
strix --target https://github.com/user/repo --target https://example.com
|
||||
strix --target ./my-project --target https://staging.example.com --target https://prod.example.com
|
||||
|
||||
# Custom instructions
|
||||
strix --target example.com --instruction "Focus on authentication vulnerabilities"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--target",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Target to test (URL, repository, local directory path, or domain name)",
|
||||
action="append",
|
||||
help="Target to test (URL, repository, local directory path, or domain name). "
|
||||
"Can be specified multiple times for multi-target scans.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instruction",
|
||||
@@ -439,127 +304,53 @@ Examples:
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
args.target_type, args.target_dict = infer_target_type(args.target)
|
||||
except ValueError as e:
|
||||
parser.error(str(e))
|
||||
args.targets_info = []
|
||||
for target in args.target:
|
||||
try:
|
||||
target_type, target_dict = infer_target_type(target)
|
||||
|
||||
if target_type == "local_code":
|
||||
display_target = target_dict.get("target_path", target)
|
||||
else:
|
||||
display_target = target
|
||||
|
||||
args.targets_info.append(
|
||||
{"type": target_type, "details": target_dict, "original": display_target}
|
||||
)
|
||||
except ValueError:
|
||||
parser.error(f"Invalid target '{target}'")
|
||||
|
||||
assign_workspace_subdirs(args.targets_info)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def _get_severity_color(severity: str) -> str:
|
||||
severity_colors = {
|
||||
"critical": "#dc2626",
|
||||
"high": "#ea580c",
|
||||
"medium": "#d97706",
|
||||
"low": "#65a30d",
|
||||
"info": "#0284c7",
|
||||
}
|
||||
return severity_colors.get(severity, "#6b7280")
|
||||
|
||||
|
||||
def _build_stats_text(tracer: Any) -> Text:
|
||||
stats_text = Text()
|
||||
if not tracer:
|
||||
return stats_text
|
||||
|
||||
vuln_count = len(tracer.vulnerability_reports)
|
||||
tool_count = tracer.get_real_tool_count()
|
||||
agent_count = len(tracer.agents)
|
||||
|
||||
if vuln_count > 0:
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
||||
for report in tracer.vulnerability_reports:
|
||||
severity = report.get("severity", "").lower()
|
||||
if severity in severity_counts:
|
||||
severity_counts[severity] += 1
|
||||
|
||||
stats_text.append("🔍 Vulnerabilities Found: ", style="bold red")
|
||||
|
||||
severity_parts = []
|
||||
for severity in ["critical", "high", "medium", "low", "info"]:
|
||||
count = severity_counts[severity]
|
||||
if count > 0:
|
||||
severity_color = _get_severity_color(severity)
|
||||
severity_text = Text()
|
||||
severity_text.append(f"{severity.upper()}: ", style=severity_color)
|
||||
severity_text.append(str(count), style=f"bold {severity_color}")
|
||||
severity_parts.append(severity_text)
|
||||
|
||||
for i, part in enumerate(severity_parts):
|
||||
stats_text.append(part)
|
||||
if i < len(severity_parts) - 1:
|
||||
stats_text.append(" | ", style="dim white")
|
||||
|
||||
stats_text.append(" (Total: ", style="dim white")
|
||||
stats_text.append(str(vuln_count), style="bold yellow")
|
||||
stats_text.append(")", style="dim white")
|
||||
stats_text.append("\n")
|
||||
else:
|
||||
stats_text.append("🔍 Vulnerabilities Found: ", style="bold green")
|
||||
stats_text.append("0", style="bold white")
|
||||
stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
|
||||
stats_text.append("\n")
|
||||
|
||||
stats_text.append("🤖 Agents Used: ", style="bold cyan")
|
||||
stats_text.append(str(agent_count), style="bold white")
|
||||
stats_text.append(" • ", style="dim white")
|
||||
stats_text.append("🛠️ Tools Called: ", style="bold cyan")
|
||||
stats_text.append(str(tool_count), style="bold white")
|
||||
|
||||
return stats_text
|
||||
|
||||
|
||||
def _build_llm_stats_text(tracer: Any) -> Text:
|
||||
llm_stats_text = Text()
|
||||
if not tracer:
|
||||
return llm_stats_text
|
||||
|
||||
llm_stats = tracer.get_total_llm_stats()
|
||||
total_stats = llm_stats["total"]
|
||||
|
||||
if total_stats["requests"] > 0:
|
||||
llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
|
||||
llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
|
||||
|
||||
if total_stats["cached_tokens"] > 0:
|
||||
llm_stats_text.append(" • ", style="dim white")
|
||||
llm_stats_text.append("⚡ Cached: ", style="bold green")
|
||||
llm_stats_text.append(
|
||||
format_token_count(total_stats["cached_tokens"]), style="bold green"
|
||||
)
|
||||
|
||||
llm_stats_text.append(" • ", style="dim white")
|
||||
llm_stats_text.append("📤 Output Tokens: ", style="bold cyan")
|
||||
llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")
|
||||
|
||||
if total_stats["cost"] > 0:
|
||||
llm_stats_text.append(" • ", style="dim white")
|
||||
llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
|
||||
llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")
|
||||
|
||||
return llm_stats_text
|
||||
|
||||
|
||||
def display_completion_message(args: argparse.Namespace, results_path: Path) -> None:
|
||||
console = Console()
|
||||
tracer = get_global_tracer()
|
||||
|
||||
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
|
||||
|
||||
completion_text = Text()
|
||||
completion_text.append("🦉 ", style="bold white")
|
||||
completion_text.append("AGENT FINISHED", style="bold green")
|
||||
completion_text.append(" • ", style="dim white")
|
||||
completion_text.append("Penetration test completed", style="white")
|
||||
|
||||
stats_text = _build_stats_text(tracer)
|
||||
stats_text = build_stats_text(tracer)
|
||||
|
||||
llm_stats_text = _build_llm_stats_text(tracer)
|
||||
llm_stats_text = build_llm_stats_text(tracer)
|
||||
|
||||
target_text = Text()
|
||||
target_text.append("🎯 Target: ", style="bold cyan")
|
||||
target_text.append(str(target_value), style="bold white")
|
||||
if len(args.targets_info) == 1:
|
||||
target_text.append("🎯 Target: ", style="bold cyan")
|
||||
target_text.append(args.targets_info[0]["original"], style="bold white")
|
||||
else:
|
||||
target_text.append("🎯 Targets: ", style="bold cyan")
|
||||
target_text.append(f"{len(args.targets_info)} targets\n", style="bold white")
|
||||
for i, target_info in enumerate(args.targets_info):
|
||||
target_text.append(" • ", style="dim white")
|
||||
target_text.append(target_info["original"], style="white")
|
||||
if i < len(args.targets_info) - 1:
|
||||
target_text.append("\n")
|
||||
|
||||
results_text = Text()
|
||||
results_text.append("📊 Results Saved To: ", style="bold cyan")
|
||||
@@ -575,19 +366,19 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
|
||||
stats_text,
|
||||
"\n",
|
||||
llm_stats_text,
|
||||
"\n",
|
||||
"\n\n",
|
||||
results_text,
|
||||
)
|
||||
else:
|
||||
panel_content = Text.assemble(
|
||||
completion_text, "\n\n", target_text, "\n", stats_text, "\n", results_text
|
||||
completion_text, "\n\n", target_text, "\n", stats_text, "\n\n", results_text
|
||||
)
|
||||
elif llm_stats_text.plain:
|
||||
panel_content = Text.assemble(
|
||||
completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n", results_text
|
||||
completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n\n", results_text
|
||||
)
|
||||
else:
|
||||
panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n", results_text)
|
||||
panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n\n", results_text)
|
||||
|
||||
panel = Panel(
|
||||
panel_content,
|
||||
@@ -602,86 +393,11 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
|
||||
console.print()
|
||||
|
||||
|
||||
def _check_docker_connection() -> Any:
|
||||
try:
|
||||
return docker.from_env()
|
||||
except DockerException:
|
||||
console = Console()
|
||||
error_text = Text()
|
||||
error_text.append("❌ ", style="bold red")
|
||||
error_text.append("DOCKER NOT AVAILABLE", style="bold red")
|
||||
error_text.append("\n\n", style="white")
|
||||
error_text.append("Cannot connect to Docker daemon.\n", style="white")
|
||||
error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
|
||||
error_text.append("Try running: ", style="dim white")
|
||||
error_text.append("sudo systemctl start docker", style="dim cyan")
|
||||
|
||||
panel = Panel(
|
||||
error_text,
|
||||
title="[bold red]🛡️ STRIX STARTUP ERROR",
|
||||
title_align="center",
|
||||
border_style="red",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print("\n", panel, "\n")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _image_exists(client: Any) -> bool:
|
||||
try:
|
||||
client.images.get(STRIX_IMAGE)
|
||||
except docker.errors.ImageNotFound:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None:
|
||||
if "Pull complete" in layer_status or "Already exists" in layer_status:
|
||||
layers_info[layer_id] = "✓"
|
||||
elif "Downloading" in layer_status:
|
||||
layers_info[layer_id] = "↓"
|
||||
elif "Extracting" in layer_status:
|
||||
layers_info[layer_id] = "📦"
|
||||
elif "Waiting" in layer_status:
|
||||
layers_info[layer_id] = "⏳"
|
||||
else:
|
||||
layers_info[layer_id] = "•"
|
||||
|
||||
|
||||
def _process_pull_line(
|
||||
line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str
|
||||
) -> str:
|
||||
if "id" in line and "status" in line:
|
||||
layer_id = line["id"]
|
||||
_update_layer_status(layers_info, layer_id, line["status"])
|
||||
|
||||
completed = sum(1 for v in layers_info.values() if v == "✓")
|
||||
total = len(layers_info)
|
||||
|
||||
if total > 0:
|
||||
update_msg = f"[bold cyan]Progress: {completed}/{total} layers complete"
|
||||
if update_msg != last_update:
|
||||
status.update(update_msg)
|
||||
return update_msg
|
||||
|
||||
elif "status" in line and "id" not in line:
|
||||
global_status = line["status"]
|
||||
if "Pulling from" in global_status:
|
||||
status.update("[bold cyan]Fetching image manifest...")
|
||||
elif "Digest:" in global_status:
|
||||
status.update("[bold cyan]Verifying image...")
|
||||
elif "Status:" in global_status:
|
||||
status.update("[bold cyan]Finalizing...")
|
||||
|
||||
return last_update
|
||||
|
||||
|
||||
def pull_docker_image() -> None:
|
||||
console = Console()
|
||||
client = _check_docker_connection()
|
||||
client = check_docker_connection()
|
||||
|
||||
if _image_exists(client):
|
||||
if image_exists(client, STRIX_IMAGE):
|
||||
return
|
||||
|
||||
console.print()
|
||||
@@ -695,7 +411,7 @@ def pull_docker_image() -> None:
|
||||
last_update = ""
|
||||
|
||||
for line in client.api.pull(STRIX_IMAGE, stream=True, decode=True):
|
||||
last_update = _process_pull_line(line, layers_info, status, last_update)
|
||||
last_update = process_pull_line(line, layers_info, status, last_update)
|
||||
|
||||
except DockerException as e:
|
||||
console.print()
|
||||
@@ -738,11 +454,14 @@ def main() -> None:
|
||||
if not args.run_name:
|
||||
args.run_name = generate_run_name()
|
||||
|
||||
if args.target_type == "repository":
|
||||
repo_url = args.target_dict["target_repo"]
|
||||
cloned_path = clone_repository(repo_url, args.run_name)
|
||||
for target_info in args.targets_info:
|
||||
if target_info["type"] == "repository":
|
||||
repo_url = target_info["details"]["target_repo"]
|
||||
dest_name = target_info["details"].get("workspace_subdir")
|
||||
cloned_path = clone_repository(repo_url, args.run_name, dest_name)
|
||||
target_info["details"]["cloned_repo_path"] = cloned_path
|
||||
|
||||
args.target_dict["cloned_repo_path"] = cloned_path
|
||||
args.local_sources = collect_local_sources(args.targets_info)
|
||||
|
||||
if args.non_interactive:
|
||||
asyncio.run(run_cli(args))
|
||||
|
||||
@@ -16,23 +16,28 @@ class ScanStartInfoRenderer(BaseToolRenderer):
|
||||
args = tool_data.get("args", {})
|
||||
status = tool_data.get("status", "unknown")
|
||||
|
||||
target = args.get("target", {})
|
||||
targets = args.get("targets", [])
|
||||
|
||||
target_display = cls._build_target_display(target)
|
||||
|
||||
content = f"🚀 Starting scan on {target_display}"
|
||||
if len(targets) == 1:
|
||||
target_display = cls._build_single_target_display(targets[0])
|
||||
content = f"🚀 Starting penetration test on {target_display}"
|
||||
elif len(targets) > 1:
|
||||
content = f"🚀 Starting penetration test on {len(targets)} targets"
|
||||
for target_info in targets:
|
||||
target_display = cls._build_single_target_display(target_info)
|
||||
content += f"\n • {target_display}"
|
||||
else:
|
||||
content = "🚀 Starting penetration test"
|
||||
|
||||
css_classes = cls.get_css_classes(status)
|
||||
return Static(content, classes=css_classes)
|
||||
|
||||
@classmethod
|
||||
def _build_target_display(cls, target: dict[str, Any]) -> str:
|
||||
if target_url := target.get("target_url"):
|
||||
return cls.escape_markup(str(target_url))
|
||||
if target_repo := target.get("target_repo"):
|
||||
return cls.escape_markup(str(target_repo))
|
||||
if target_path := target.get("target_path"):
|
||||
return cls.escape_markup(str(target_path))
|
||||
def _build_single_target_display(cls, target_info: dict[str, Any]) -> str:
|
||||
original = target_info.get("original")
|
||||
if original:
|
||||
return cls.escape_markup(str(original))
|
||||
|
||||
return "unknown target"
|
||||
|
||||
|
||||
|
||||
@@ -312,8 +312,7 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]:
|
||||
return {
|
||||
"scan_id": args.run_name,
|
||||
"scan_type": args.target_type,
|
||||
"target": args.target_dict,
|
||||
"targets": args.targets_info,
|
||||
"user_instructions": args.instruction or "",
|
||||
"run_name": args.run_name,
|
||||
}
|
||||
@@ -326,10 +325,8 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
"max_iterations": 300,
|
||||
}
|
||||
|
||||
if args.target_type == "local_code" and "target_path" in args.target_dict:
|
||||
config["local_source_path"] = args.target_dict["target_path"]
|
||||
elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict:
|
||||
config["local_source_path"] = args.target_dict["cloned_repo_path"]
|
||||
if getattr(args, "local_sources", None):
|
||||
config["local_sources"] = args.local_sources
|
||||
|
||||
return config
|
||||
|
||||
|
||||
434
strix/interface/utils.py
Normal file
434
strix/interface/utils.py
Normal file
@@ -0,0 +1,434 @@
|
||||
import re
|
||||
import secrets
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import docker
|
||||
from docker.errors import DockerException, ImageNotFound
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
# Token formatting utilities
|
||||
def format_token_count(count: float) -> str:
|
||||
count = int(count)
|
||||
if count >= 1_000_000:
|
||||
return f"{count / 1_000_000:.1f}M"
|
||||
if count >= 1_000:
|
||||
return f"{count / 1_000:.1f}K"
|
||||
return str(count)
|
||||
|
||||
|
||||
# Display utilities
|
||||
def get_severity_color(severity: str) -> str:
|
||||
severity_colors = {
|
||||
"critical": "#dc2626",
|
||||
"high": "#ea580c",
|
||||
"medium": "#d97706",
|
||||
"low": "#65a30d",
|
||||
"info": "#0284c7",
|
||||
}
|
||||
return severity_colors.get(severity, "#6b7280")
|
||||
|
||||
|
||||
def build_stats_text(tracer: Any) -> Text:
    """Build the rich summary of scan results: vulnerabilities, agents, tools.

    Returns an empty Text when no tracer is available.
    """
    stats_text = Text()
    if not tracer:
        return stats_text

    # Headline numbers pulled from the tracer's accumulated state.
    vuln_count = len(tracer.vulnerability_reports)
    tool_count = tracer.get_real_tool_count()
    agent_count = len(tracer.agents)

    if vuln_count > 0:
        # Tally reports per severity bucket; unrecognized severities are ignored.
        severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
        for report in tracer.vulnerability_reports:
            severity = report.get("severity", "").lower()
            if severity in severity_counts:
                severity_counts[severity] += 1

        stats_text.append("🔍 Vulnerabilities Found: ", style="bold red")

        # One colored "SEVERITY: n" fragment per non-empty bucket,
        # ordered from most to least severe.
        severity_parts = []
        for severity in ["critical", "high", "medium", "low", "info"]:
            count = severity_counts[severity]
            if count > 0:
                severity_color = get_severity_color(severity)
                severity_text = Text()
                severity_text.append(f"{severity.upper()}: ", style=severity_color)
                severity_text.append(str(count), style=f"bold {severity_color}")
                severity_parts.append(severity_text)

        # Join fragments with a dim separator (no trailing separator).
        for i, part in enumerate(severity_parts):
            stats_text.append(part)
            if i < len(severity_parts) - 1:
                stats_text.append(" | ", style="dim white")

        stats_text.append(" (Total: ", style="dim white")
        stats_text.append(str(vuln_count), style="bold yellow")
        stats_text.append(")", style="dim white")
        stats_text.append("\n")
    else:
        # Green all-clear line when nothing was found.
        stats_text.append("🔍 Vulnerabilities Found: ", style="bold green")
        stats_text.append("0", style="bold white")
        stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
        stats_text.append("\n")

    # Second line: agent and tool usage counters.
    stats_text.append("🤖 Agents Used: ", style="bold cyan")
    stats_text.append(str(agent_count), style="bold white")
    stats_text.append(" • ", style="dim white")
    stats_text.append("🛠️ Tools Called: ", style="bold cyan")
    stats_text.append(str(tool_count), style="bold white")

    return stats_text
|
||||
|
||||
|
||||
def build_llm_stats_text(tracer: Any) -> Text:
    """Build the rich one-line LLM usage summary (tokens, cache hits, cost).

    Returns an empty Text when no tracer is available or no requests were made.
    """
    llm_stats_text = Text()
    if not tracer:
        return llm_stats_text

    llm_stats = tracer.get_total_llm_stats()
    total_stats = llm_stats["total"]

    # Only show usage when at least one LLM request was actually made.
    if total_stats["requests"] > 0:
        llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
        llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")

        # Cache segment is optional: omitted entirely when nothing was cached.
        if total_stats["cached_tokens"] > 0:
            llm_stats_text.append(" • ", style="dim white")
            llm_stats_text.append("⚡ Cached: ", style="bold green")
            llm_stats_text.append(
                format_token_count(total_stats["cached_tokens"]), style="bold green"
            )

        llm_stats_text.append(" • ", style="dim white")
        llm_stats_text.append("📤 Output Tokens: ", style="bold cyan")
        llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")

        # Cost segment is optional too (e.g. free/local models report 0).
        if total_stats["cost"] > 0:
            llm_stats_text.append(" • ", style="dim white")
            llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
            llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")

    return llm_stats_text
|
||||
|
||||
|
||||
# Name generation utilities
|
||||
def generate_run_name() -> str:
    """Generate a random hacker-themed run name, e.g. 'stealthy-exploit-123'.

    Uses the ``secrets`` module so names are unpredictable; the numeric
    suffix is always three digits (100-999).
    """
    # fmt: off
    adjectives = [
        "stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent",
        "rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary",
        "encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous"
    ]
    nouns = [
        "exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan",
        "worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler",
        "scanner", "sniffer", "honeypot", "firewall", "breach"
    ]
    # fmt: on
    suffix = 100 + secrets.randbelow(900)
    return "-".join((secrets.choice(adjectives), secrets.choice(nouns), str(suffix)))
|
||||
|
||||
|
||||
# Target processing utilities
|
||||
def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
    """Classify a raw target string as repository, web_application, or local_code.

    Returns a ``(type, details)`` pair where ``details`` holds the normalized
    target value under a type-specific key.

    Raises:
        ValueError: for empty input, paths that are not directories,
            unreadable paths, or strings matching no known target shape.
    """
    if not target or not isinstance(target, str):
        raise ValueError("Target must be a non-empty string")

    candidate = target.strip()
    lowered = candidate.lower()

    # Scheme-less URLs on well-known code hosts -> treat as repository URLs.
    hosts = ("github.com/", "gitlab.com/", "bitbucket.org/")
    if any(lowered.startswith(h) or lowered.startswith(f"www.{h}") for h in hosts):
        return "repository", {"target_repo": f"https://{candidate}"}

    # Full http(s) URLs: code hosts become repositories, anything else a web app.
    parsed = urlparse(candidate)
    if parsed.scheme in ("http", "https"):
        netloc = parsed.netloc.lower()
        if any(host in netloc for host in ["github.com", "gitlab.com", "bitbucket.org"]):
            return "repository", {"target_repo": candidate}
        return "web_application", {"target_url": candidate}

    # Existing local directory -> local_code (resolved to an absolute path).
    path = Path(candidate).expanduser()
    try:
        if path.exists():
            if path.is_dir():
                return "local_code", {"target_path": str(path.resolve())}
            raise ValueError(f"Path exists but is not a directory: {candidate}")
    except (OSError, RuntimeError) as e:
        raise ValueError(f"Invalid path: {candidate} - {e!s}") from e

    # SSH-style or .git-suffixed git URLs.
    if candidate.startswith("git@") or candidate.endswith(".git"):
        return "repository", {"target_repo": candidate}

    # Bare domain names (dots, no slashes, no leading dot) -> https web app.
    if "." in candidate and "/" not in candidate and not candidate.startswith("."):
        labels = candidate.split(".")
        if len(labels) >= 2 and all(label and label.strip() for label in labels):
            return "web_application", {"target_url": f"https://{candidate}"}

    raise ValueError(
        f"Invalid target: {candidate}\n"
        "Target must be one of:\n"
        "- A valid URL (http:// or https://)\n"
        "- A Git repository URL (https://github.com/... or git@github.com:...)\n"
        "- A local directory path\n"
        "- A domain name (e.g., example.com)"
    )
|
||||
|
||||
|
||||
def sanitize_name(name: str) -> str:
    """Make a string filesystem-safe: every character outside
    ``[A-Za-z0-9._-]`` becomes '-'; an empty result falls back to 'target'."""
    cleaned = re.sub(r"[^A-Za-z0-9._-]", "-", name.strip())
    if not cleaned:
        return "target"
    return cleaned
|
||||
|
||||
|
||||
def derive_repo_base_name(repo_url: str) -> str:
    """Extract a filesystem-friendly repository name from an HTTP(S) or
    SSH-style git URL (one trailing slash and a '.git' suffix are stripped)."""
    url = repo_url[:-1] if repo_url.endswith("/") else repo_url

    # SSH form "git@host:owner/repo.git" carries the path after the colon.
    if url.startswith("git@") and ":" in url:
        tail = url.split(":", 1)[1]
    else:
        tail = urlparse(url).path or url

    leaf = tail.rsplit("/", 1)[-1]
    leaf = leaf.removesuffix(".git")

    return sanitize_name(leaf or "repository")
|
||||
|
||||
|
||||
def derive_local_base_name(path_str: str) -> str:
    """Derive a workspace directory name from a local path, preferring the
    resolved (absolute) leaf name and falling back to the raw leaf."""
    raw = Path(path_str)
    try:
        leaf = raw.resolve().name
    except (OSError, RuntimeError):
        leaf = raw.name
    return sanitize_name(leaf or "workspace")
|
||||
|
||||
|
||||
def assign_workspace_subdirs(targets_info: list[dict[str, Any]]) -> None:
    """Assign a unique ``workspace_subdir`` to every repository / local-code
    target, mutating each target's ``details`` in place.

    Duplicate base names get a numeric suffix: 'repo', 'repo-2', 'repo-3', ...
    Web targets are left untouched.
    """
    seen: dict[str, int] = {}

    for entry in targets_info:
        details = entry["details"]
        kind = entry["type"]

        if kind == "repository":
            base = derive_repo_base_name(details["target_repo"])
        elif kind == "local_code":
            base = derive_local_base_name(details.get("target_path", "local"))
        else:
            # No workspace directory for non-code targets.
            continue

        occurrence = seen.get(base, 0) + 1
        seen[base] = occurrence

        details["workspace_subdir"] = base if occurrence == 1 else f"{base}-{occurrence}"
|
||||
|
||||
|
||||
def collect_local_sources(targets_info: list[dict[str, Any]]) -> list[dict[str, str]]:
    """Gather the host paths that must be copied into the sandbox.

    Local-code targets contribute their ``target_path``; repository targets
    contribute their ``cloned_repo_path`` (when the clone happened). Each
    entry carries the target's ``workspace_subdir`` alongside the path.
    """
    # Which details-key holds the copyable host path, per target type.
    path_key_by_type = {"local_code": "target_path", "repository": "cloned_repo_path"}

    sources: list[dict[str, str]] = []
    for entry in targets_info:
        details = entry["details"]
        path_key = path_key_by_type.get(entry["type"])
        if path_key is None or path_key not in details:
            continue
        sources.append(
            {
                "source_path": details[path_key],
                "workspace_subdir": details.get("workspace_subdir"),
            }
        )

    return sources
|
||||
|
||||
|
||||
# Repository utilities
|
||||
def clone_repository(repo_url: str, run_name: str, dest_name: str | None = None) -> str:
    """Clone ``repo_url`` into a per-run temp directory and return the clone path.

    The clone lands at ``<tmpdir>/strix_repos/<run_name>/<name>`` where
    ``name`` is ``dest_name`` when given, otherwise derived from the URL.
    Any pre-existing clone at that path is removed first. Failures are
    rendered as rich error panels and then re-raised.

    Raises:
        FileNotFoundError: if no ``git`` executable is found on PATH.
        subprocess.CalledProcessError: if the ``git clone`` command fails.
    """
    console = Console()

    git_executable = shutil.which("git")
    if git_executable is None:
        raise FileNotFoundError("Git executable not found in PATH")

    # Per-run scratch area; shared by all targets of the same run.
    temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name
    temp_dir.mkdir(parents=True, exist_ok=True)

    if dest_name:
        repo_name = dest_name
    else:
        # "repo.git" -> "repo"; otherwise use the URL's last path component.
        repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name

    clone_path = temp_dir / repo_name

    # Start from a clean slate so stale clones never leak into a new run.
    if clone_path.exists():
        shutil.rmtree(clone_path)

    try:
        with console.status(f"[bold cyan]Cloning repository {repo_url}...", spinner="dots"):
            subprocess.run(  # noqa: S603
                [
                    git_executable,
                    "clone",
                    repo_url,
                    str(clone_path),
                ],
                capture_output=True,
                text=True,
                check=True,
            )

        return str(clone_path.absolute())

    except subprocess.CalledProcessError as e:
        # Clone command ran but failed: show git's stderr in an error panel.
        error_text = Text()
        error_text.append("❌ ", style="bold red")
        error_text.append("REPOSITORY CLONE FAILED", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append(f"Could not clone repository: {repo_url}\n", style="white")
        error_text.append(
            f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red"
        )

        panel = Panel(
            error_text,
            title="[bold red]🛡️ STRIX CLONE ERROR",
            title_align="center",
            border_style="red",
            padding=(1, 2),
        )
        console.print("\n")
        console.print(panel)
        console.print()
        # Re-raise so the caller can abort the scan.
        raise
    except FileNotFoundError:
        # Defensive: git vanished between the which() check and execution.
        error_text = Text()
        error_text.append("❌ ", style="bold red")
        error_text.append("GIT NOT FOUND", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append("Git is not installed or not available in PATH.\n", style="white")
        error_text.append("Please install Git to clone repositories.\n", style="white")

        panel = Panel(
            error_text,
            title="[bold red]🛡️ STRIX CLONE ERROR",
            title_align="center",
            border_style="red",
            padding=(1, 2),
        )
        console.print("\n")
        console.print(panel)
        console.print()
        raise
|
||||
|
||||
|
||||
# Docker utilities
|
||||
def check_docker_connection() -> Any:
    """Return a Docker client from the environment, or fail loudly.

    On connection failure, prints a rich startup-error panel and raises
    RuntimeError (with the DockerException chain suppressed via ``from None``).
    """
    try:
        return docker.from_env()
    except DockerException:
        console = Console()
        error_text = Text()
        error_text.append("❌ ", style="bold red")
        error_text.append("DOCKER NOT AVAILABLE", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append("Cannot connect to Docker daemon.\n", style="white")
        error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
        error_text.append("Try running: ", style="dim white")
        error_text.append("sudo systemctl start docker", style="dim cyan")

        panel = Panel(
            error_text,
            title="[bold red]🛡️ STRIX STARTUP ERROR",
            title_align="center",
            border_style="red",
            padding=(1, 2),
        )
        console.print("\n", panel, "\n")
        # `from None` keeps the user-facing traceback free of docker internals.
        raise RuntimeError("Docker not available") from None
|
||||
|
||||
|
||||
def image_exists(client: Any, image_name: str) -> bool:
    """Return True when the Docker image is already present locally."""
    try:
        client.images.get(image_name)
    except ImageNotFound:
        return False
    return True
|
||||
|
||||
|
||||
def update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None:
    """Record a one-glyph progress marker for a docker pull layer.

    '✓' complete, '↓' downloading, '📦' extracting, '⏳' waiting,
    '•' anything else. First matching rule wins.
    """
    rules = (
        (("Pull complete", "Already exists"), "✓"),
        (("Downloading",), "↓"),
        (("Extracting",), "📦"),
        (("Waiting",), "⏳"),
    )
    marker = "•"
    for needles, glyph in rules:
        if any(needle in layer_status for needle in needles):
            marker = glyph
            break
    layers_info[layer_id] = marker
|
||||
|
||||
|
||||
def process_pull_line(
    line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str
) -> str:
    """Fold one docker-pull JSON event into layer state and the spinner text.

    Layer events (those with an "id") update ``layers_info`` and may refresh
    the progress message; id-less events only retarget the spinner. Returns
    the most recent progress message, to be threaded into the next call.
    """
    has_id = "id" in line
    has_status = "status" in line

    if has_id and has_status:
        update_layer_status(layers_info, line["id"], line["status"])

        finished = sum(1 for mark in layers_info.values() if mark == "✓")
        if layers_info:
            message = f"[bold cyan]Progress: {finished}/{len(layers_info)} layers complete"
            # Only repaint (and report) when the message actually changed.
            if message != last_update:
                status.update(message)
                return message

    elif has_status:
        headline = line["status"]
        if "Pulling from" in headline:
            status.update("[bold cyan]Fetching image manifest...")
        elif "Digest:" in headline:
            status.update("[bold cyan]Verifying image...")
        elif "Status:" in headline:
            status.update("[bold cyan]Finalizing...")

    return last_update
|
||||
|
||||
|
||||
# LLM utilities
|
||||
def validate_llm_response(response: Any) -> None:
    """Raise RuntimeError unless the completion carries a non-empty
    first-choice message content; returns None on success."""
    usable = bool(response) and bool(response.choices) and bool(
        response.choices[0].message.content
    )
    if not usable:
        raise RuntimeError("Invalid response from LLM")
|
||||
@@ -250,7 +250,9 @@ class DockerRuntime(AbstractRuntime):
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
def _copy_local_directory_to_container(self, container: Container, local_path: str) -> None:
|
||||
def _copy_local_directory_to_container(
|
||||
self, container: Container, local_path: str, target_name: str | None = None
|
||||
) -> None:
|
||||
import tarfile
|
||||
from io import BytesIO
|
||||
|
||||
@@ -260,13 +262,20 @@ class DockerRuntime(AbstractRuntime):
|
||||
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
|
||||
return
|
||||
|
||||
logger.info(f"Copying local directory {local_path_obj} to container")
|
||||
if target_name:
|
||||
logger.info(
|
||||
f"Copying local directory {local_path_obj} to container at "
|
||||
f"/workspace/{target_name}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Copying local directory {local_path_obj} to container")
|
||||
|
||||
tar_buffer = BytesIO()
|
||||
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
|
||||
for item in local_path_obj.rglob("*"):
|
||||
if item.is_file():
|
||||
arcname = item.relative_to(local_path_obj)
|
||||
rel_path = item.relative_to(local_path_obj)
|
||||
arcname = Path(target_name) / rel_path if target_name else rel_path
|
||||
tar.add(item, arcname=arcname)
|
||||
|
||||
tar_buffer.seek(0)
|
||||
@@ -283,14 +292,26 @@ class DockerRuntime(AbstractRuntime):
|
||||
logger.exception("Failed to copy local directory to container")
|
||||
|
||||
async def create_sandbox(
|
||||
self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None
|
||||
self,
|
||||
agent_id: str,
|
||||
existing_token: str | None = None,
|
||||
local_sources: list[dict[str, str]] | None = None,
|
||||
) -> SandboxInfo:
|
||||
scan_id = self._get_scan_id(agent_id)
|
||||
container = self._get_or_create_scan_container(scan_id)
|
||||
|
||||
source_copied_key = f"_source_copied_{scan_id}"
|
||||
if local_source_path and not hasattr(self, source_copied_key):
|
||||
self._copy_local_directory_to_container(container, local_source_path)
|
||||
if local_sources and not hasattr(self, source_copied_key):
|
||||
for index, source in enumerate(local_sources, start=1):
|
||||
source_path = source.get("source_path")
|
||||
if not source_path:
|
||||
continue
|
||||
|
||||
target_name = source.get("workspace_subdir")
|
||||
if not target_name:
|
||||
target_name = Path(source_path).name or f"target_{index}"
|
||||
|
||||
self._copy_local_directory_to_container(container, source_path, target_name)
|
||||
setattr(self, source_copied_key, True)
|
||||
|
||||
container_id = container.id
|
||||
|
||||
@@ -13,7 +13,10 @@ class SandboxInfo(TypedDict):
|
||||
class AbstractRuntime(ABC):
|
||||
@abstractmethod
|
||||
async def create_sandbox(
|
||||
self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None
|
||||
self,
|
||||
agent_id: str,
|
||||
existing_token: str | None = None,
|
||||
local_sources: list[dict[str, str]] | None = None,
|
||||
) -> SandboxInfo:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -44,8 +44,7 @@ class Tracer:
|
||||
"run_name": self.run_name,
|
||||
"start_time": self.start_time,
|
||||
"end_time": None,
|
||||
"target": None,
|
||||
"scan_type": None,
|
||||
"targets": [],
|
||||
"status": "running",
|
||||
}
|
||||
self._run_dir: Path | None = None
|
||||
@@ -193,8 +192,7 @@ class Tracer:
|
||||
self.scan_config = config
|
||||
self.run_metadata.update(
|
||||
{
|
||||
"target": config.get("target", {}),
|
||||
"scan_type": config.get("scan_type", "general"),
|
||||
"targets": config.get("targets", []),
|
||||
"user_instructions": config.get("user_instructions", ""),
|
||||
"max_iterations": config.get("max_iterations", 200),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user