feat: implement multi-target scanning
This commit is contained in:
@@ -19,55 +19,64 @@ class StrixAgent(BaseAgent):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
async def execute_scan(self, scan_config: dict[str, Any]) -> dict[str, Any]:
|
async def execute_scan(self, scan_config: dict[str, Any]) -> dict[str, Any]:
|
||||||
scan_type = scan_config.get("scan_type", "general")
|
|
||||||
target = scan_config.get("target", {})
|
|
||||||
user_instructions = scan_config.get("user_instructions", "")
|
user_instructions = scan_config.get("user_instructions", "")
|
||||||
|
targets = scan_config.get("targets", [])
|
||||||
|
|
||||||
|
repositories = []
|
||||||
|
local_code = []
|
||||||
|
urls = []
|
||||||
|
|
||||||
|
for target in targets:
|
||||||
|
target_type = target["type"]
|
||||||
|
details = target["details"]
|
||||||
|
workspace_subdir = details.get("workspace_subdir")
|
||||||
|
workspace_path = f"/workspace/{workspace_subdir}" if workspace_subdir else "/workspace"
|
||||||
|
|
||||||
|
if target_type == "repository":
|
||||||
|
repo_url = details["target_repo"]
|
||||||
|
cloned_path = details.get("cloned_repo_path")
|
||||||
|
repositories.append(
|
||||||
|
{
|
||||||
|
"url": repo_url,
|
||||||
|
"workspace_path": workspace_path if cloned_path else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
elif target_type == "local_code":
|
||||||
|
original_path = details.get("target_path", "unknown")
|
||||||
|
local_code.append(
|
||||||
|
{
|
||||||
|
"path": original_path,
|
||||||
|
"workspace_path": workspace_path,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
elif target_type == "web_application":
|
||||||
|
urls.append(details["target_url"])
|
||||||
|
|
||||||
task_parts = []
|
task_parts = []
|
||||||
|
|
||||||
if scan_type == "repository":
|
if repositories:
|
||||||
repo_url = target["target_repo"]
|
task_parts.append("\n\nRepositories:")
|
||||||
cloned_path = target.get("cloned_repo_path")
|
for repo in repositories:
|
||||||
|
if repo["workspace_path"]:
|
||||||
if cloned_path:
|
task_parts.append(f"- {repo['url']} (available at: {repo['workspace_path']})")
|
||||||
workspace_path = "/workspace"
|
|
||||||
task_parts.append(
|
|
||||||
f"Perform a security assessment of the Git repository: {repo_url}. "
|
|
||||||
f"The repository has been cloned from '{repo_url}' to '{cloned_path}' "
|
|
||||||
f"(host path) and then copied to '{workspace_path}' in your environment."
|
|
||||||
f"Analyze the codebase at: {workspace_path}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
task_parts.append(
|
task_parts.append(f"- {repo['url']}")
|
||||||
f"Perform a security assessment of the Git repository: {repo_url}"
|
|
||||||
|
if local_code:
|
||||||
|
task_parts.append("\n\nLocal Codebases:")
|
||||||
|
task_parts.extend(
|
||||||
|
f"- {code['path']} (available at: {code['workspace_path']})" for code in local_code
|
||||||
)
|
)
|
||||||
|
|
||||||
elif scan_type == "web_application":
|
if urls:
|
||||||
task_parts.append(
|
task_parts.append("\n\nURLs:")
|
||||||
f"Perform a security assessment of the web application: {target['target_url']}"
|
task_parts.extend(f"- {url}" for url in urls)
|
||||||
)
|
|
||||||
|
|
||||||
elif scan_type == "local_code":
|
|
||||||
original_path = target.get("target_path", "unknown")
|
|
||||||
workspace_path = "/workspace"
|
|
||||||
task_parts.append(
|
|
||||||
f"Perform a security assessment of the local codebase. "
|
|
||||||
f"The code from '{original_path}' (user host path) has been copied to "
|
|
||||||
f"'{workspace_path}' in your environment. "
|
|
||||||
f"Analyze the codebase at: {workspace_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
task_parts.append(
|
|
||||||
f"Perform a general security assessment of: {next(iter(target.values()))}"
|
|
||||||
)
|
|
||||||
|
|
||||||
task_description = " ".join(task_parts)
|
task_description = " ".join(task_parts)
|
||||||
|
|
||||||
if user_instructions:
|
if user_instructions:
|
||||||
task_description += (
|
task_description += f"\n\nSpecial instructions: {user_instructions}"
|
||||||
f"\n\nSpecial instructions from the system that must be followed: "
|
|
||||||
f"{user_instructions}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return await self.agent_loop(task=task_description)
|
return await self.agent_loop(task=task_description)
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
def __init__(self, config: dict[str, Any]):
|
def __init__(self, config: dict[str, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.local_source_path = config.get("local_source_path")
|
self.local_sources = config.get("local_sources", [])
|
||||||
self.non_interactive = config.get("non_interactive", False)
|
self.non_interactive = config.get("non_interactive", False)
|
||||||
|
|
||||||
if "max_iterations" in config:
|
if "max_iterations" in config:
|
||||||
@@ -317,7 +317,7 @@ class BaseAgent(metaclass=AgentMeta):
|
|||||||
|
|
||||||
runtime = get_runtime()
|
runtime = get_runtime()
|
||||||
sandbox_info = await runtime.create_sandbox(
|
sandbox_info = await runtime.create_sandbox(
|
||||||
self.state.agent_id, self.state.sandbox_token, self.local_source_path
|
self.state.agent_id, self.state.sandbox_token, self.local_sources
|
||||||
)
|
)
|
||||||
self.state.sandbox_id = sandbox_info["workspace_id"]
|
self.state.sandbox_id = sandbox_info["workspace_id"]
|
||||||
self.state.sandbox_token = sandbox_info["auth_token"]
|
self.state.sandbox_token = sandbox_info["auth_token"]
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from strix.agents.StrixAgent import StrixAgent
|
|||||||
from strix.llm.config import LLMConfig
|
from strix.llm.config import LLMConfig
|
||||||
from strix.telemetry.tracer import Tracer, set_global_tracer
|
from strix.telemetry.tracer import Tracer, set_global_tracer
|
||||||
|
|
||||||
|
from .utils import get_severity_color
|
||||||
|
|
||||||
|
|
||||||
async def run_cli(args: Any) -> None: # noqa: PLR0915
|
async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||||
console = Console()
|
console = Console()
|
||||||
@@ -19,15 +21,18 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
|||||||
start_text.append("🦉 ", style="bold white")
|
start_text.append("🦉 ", style="bold white")
|
||||||
start_text.append("STRIX CYBERSECURITY AGENT", style="bold green")
|
start_text.append("STRIX CYBERSECURITY AGENT", style="bold green")
|
||||||
|
|
||||||
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
|
|
||||||
target_text = Text()
|
target_text = Text()
|
||||||
|
if len(args.targets_info) == 1:
|
||||||
target_text.append("🎯 Target: ", style="bold cyan")
|
target_text.append("🎯 Target: ", style="bold cyan")
|
||||||
target_text.append(str(target_value), style="bold white")
|
target_text.append(args.targets_info[0]["original"], style="bold white")
|
||||||
|
else:
|
||||||
instructions_text = Text()
|
target_text.append("🎯 Targets: ", style="bold cyan")
|
||||||
if args.instruction:
|
target_text.append(f"{len(args.targets_info)} targets\n", style="bold white")
|
||||||
instructions_text.append("📋 Instructions: ", style="bold cyan")
|
for i, target_info in enumerate(args.targets_info):
|
||||||
instructions_text.append(args.instruction, style="white")
|
target_text.append(" • ", style="dim white")
|
||||||
|
target_text.append(target_info["original"], style="white")
|
||||||
|
if i < len(args.targets_info) - 1:
|
||||||
|
target_text.append("\n")
|
||||||
|
|
||||||
results_text = Text()
|
results_text = Text()
|
||||||
results_text.append("📊 Results will be saved to: ", style="bold cyan")
|
results_text.append("📊 Results will be saved to: ", style="bold cyan")
|
||||||
@@ -44,8 +49,6 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
|||||||
start_text,
|
start_text,
|
||||||
"\n\n",
|
"\n\n",
|
||||||
target_text,
|
target_text,
|
||||||
"\n" if args.instruction else "",
|
|
||||||
instructions_text if args.instruction else "",
|
|
||||||
"\n",
|
"\n",
|
||||||
results_text,
|
results_text,
|
||||||
note_text,
|
note_text,
|
||||||
@@ -62,8 +65,7 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
|||||||
|
|
||||||
scan_config = {
|
scan_config = {
|
||||||
"scan_id": args.run_name,
|
"scan_id": args.run_name,
|
||||||
"scan_type": args.target_type,
|
"targets": args.targets_info,
|
||||||
"target": args.target_dict,
|
|
||||||
"user_instructions": args.instruction or "",
|
"user_instructions": args.instruction or "",
|
||||||
"run_name": args.run_name,
|
"run_name": args.run_name,
|
||||||
}
|
}
|
||||||
@@ -75,23 +77,14 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
|||||||
"non_interactive": True,
|
"non_interactive": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.target_type == "local_code" and "target_path" in args.target_dict:
|
if getattr(args, "local_sources", None):
|
||||||
agent_config["local_source_path"] = args.target_dict["target_path"]
|
agent_config["local_sources"] = args.local_sources
|
||||||
elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict:
|
|
||||||
agent_config["local_source_path"] = args.target_dict["cloned_repo_path"]
|
|
||||||
|
|
||||||
tracer = Tracer(args.run_name)
|
tracer = Tracer(args.run_name)
|
||||||
tracer.set_scan_config(scan_config)
|
tracer.set_scan_config(scan_config)
|
||||||
|
|
||||||
def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None:
|
def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None:
|
||||||
severity_colors = {
|
severity_color = get_severity_color(severity.lower())
|
||||||
"critical": "#dc2626",
|
|
||||||
"high": "#ea580c",
|
|
||||||
"medium": "#d97706",
|
|
||||||
"low": "#65a30d",
|
|
||||||
"info": "#0284c7",
|
|
||||||
}
|
|
||||||
severity_color = severity_colors.get(severity.lower(), "#6b7280")
|
|
||||||
|
|
||||||
vuln_text = Text()
|
vuln_text = Text()
|
||||||
vuln_text.append("🐞 ", style="bold red")
|
vuln_text.append("🐞 ", style="bold red")
|
||||||
|
|||||||
@@ -7,16 +7,10 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import secrets
|
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import docker
|
|
||||||
import litellm
|
import litellm
|
||||||
from docker.errors import DockerException
|
from docker.errors import DockerException
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@@ -25,6 +19,19 @@ from rich.text import Text
|
|||||||
|
|
||||||
from strix.interface.cli import run_cli
|
from strix.interface.cli import run_cli
|
||||||
from strix.interface.tui import run_tui
|
from strix.interface.tui import run_tui
|
||||||
|
from strix.interface.utils import (
|
||||||
|
assign_workspace_subdirs,
|
||||||
|
build_llm_stats_text,
|
||||||
|
build_stats_text,
|
||||||
|
check_docker_connection,
|
||||||
|
clone_repository,
|
||||||
|
collect_local_sources,
|
||||||
|
generate_run_name,
|
||||||
|
image_exists,
|
||||||
|
infer_target_type,
|
||||||
|
process_pull_line,
|
||||||
|
validate_llm_response,
|
||||||
|
)
|
||||||
from strix.runtime.docker_runtime import STRIX_IMAGE
|
from strix.runtime.docker_runtime import STRIX_IMAGE
|
||||||
from strix.telemetry.tracer import get_global_tracer
|
from strix.telemetry.tracer import get_global_tracer
|
||||||
|
|
||||||
@@ -32,15 +39,6 @@ from strix.telemetry.tracer import get_global_tracer
|
|||||||
logging.getLogger().setLevel(logging.ERROR)
|
logging.getLogger().setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
def format_token_count(count: float) -> str:
|
|
||||||
count = int(count)
|
|
||||||
if count >= 1_000_000:
|
|
||||||
return f"{count / 1_000_000:.1f}M"
|
|
||||||
if count >= 1_000:
|
|
||||||
return f"{count / 1_000:.1f}K"
|
|
||||||
return str(count)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_environment() -> None: # noqa: PLR0912, PLR0915
|
def validate_environment() -> None: # noqa: PLR0912, PLR0915
|
||||||
console = Console()
|
console = Console()
|
||||||
missing_required_vars = []
|
missing_required_vars = []
|
||||||
@@ -163,11 +161,6 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def _validate_llm_response(response: Any) -> None:
|
|
||||||
if not response or not response.choices or not response.choices[0].message.content:
|
|
||||||
raise RuntimeError("Invalid response from LLM")
|
|
||||||
|
|
||||||
|
|
||||||
def check_docker_installed() -> None:
|
def check_docker_installed() -> None:
|
||||||
if shutil.which("docker") is None:
|
if shutil.which("docker") is None:
|
||||||
console = Console()
|
console = Console()
|
||||||
@@ -220,7 +213,7 @@ async def warm_up_llm() -> None:
|
|||||||
messages=test_messages,
|
messages=test_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
_validate_llm_response(response)
|
validate_llm_response(response)
|
||||||
|
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e: # noqa: BLE001
|
||||||
error_text = Text()
|
error_text = Text()
|
||||||
@@ -245,141 +238,6 @@ async def warm_up_llm() -> None:
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def generate_run_name() -> str:
|
|
||||||
# fmt: off
|
|
||||||
adjectives = [
|
|
||||||
"stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent",
|
|
||||||
"rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary",
|
|
||||||
"encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous"
|
|
||||||
]
|
|
||||||
nouns = [
|
|
||||||
"exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan",
|
|
||||||
"worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler",
|
|
||||||
"scanner", "sniffer", "honeypot", "firewall", "breach"
|
|
||||||
]
|
|
||||||
# fmt: on
|
|
||||||
adj = secrets.choice(adjectives)
|
|
||||||
noun = secrets.choice(nouns)
|
|
||||||
number = secrets.randbelow(900) + 100
|
|
||||||
return f"{adj}-{noun}-{number}"
|
|
||||||
|
|
||||||
|
|
||||||
def clone_repository(repo_url: str, run_name: str) -> str:
|
|
||||||
console = Console()
|
|
||||||
|
|
||||||
git_executable = shutil.which("git")
|
|
||||||
if git_executable is None:
|
|
||||||
raise FileNotFoundError("Git executable not found in PATH")
|
|
||||||
|
|
||||||
temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name
|
|
||||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name
|
|
||||||
|
|
||||||
clone_path = temp_dir / repo_name
|
|
||||||
|
|
||||||
if clone_path.exists():
|
|
||||||
shutil.rmtree(clone_path)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with console.status(f"[bold cyan]Cloning repository {repo_name}...", spinner="dots"):
|
|
||||||
subprocess.run( # noqa: S603
|
|
||||||
[
|
|
||||||
git_executable,
|
|
||||||
"clone",
|
|
||||||
repo_url,
|
|
||||||
str(clone_path),
|
|
||||||
],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return str(clone_path.absolute())
|
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
error_text = Text()
|
|
||||||
error_text.append("❌ ", style="bold red")
|
|
||||||
error_text.append("REPOSITORY CLONE FAILED", style="bold red")
|
|
||||||
error_text.append("\n\n", style="white")
|
|
||||||
error_text.append(f"Could not clone repository: {repo_url}\n", style="white")
|
|
||||||
error_text.append(
|
|
||||||
f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red"
|
|
||||||
)
|
|
||||||
|
|
||||||
panel = Panel(
|
|
||||||
error_text,
|
|
||||||
title="[bold red]🛡️ STRIX CLONE ERROR",
|
|
||||||
title_align="center",
|
|
||||||
border_style="red",
|
|
||||||
padding=(1, 2),
|
|
||||||
)
|
|
||||||
console.print("\n")
|
|
||||||
console.print(panel)
|
|
||||||
console.print()
|
|
||||||
sys.exit(1)
|
|
||||||
except FileNotFoundError:
|
|
||||||
error_text = Text()
|
|
||||||
error_text.append("❌ ", style="bold red")
|
|
||||||
error_text.append("GIT NOT FOUND", style="bold red")
|
|
||||||
error_text.append("\n\n", style="white")
|
|
||||||
error_text.append("Git is not installed or not available in PATH.\n", style="white")
|
|
||||||
error_text.append("Please install Git to clone repositories.\n", style="white")
|
|
||||||
|
|
||||||
panel = Panel(
|
|
||||||
error_text,
|
|
||||||
title="[bold red]🛡️ STRIX CLONE ERROR",
|
|
||||||
title_align="center",
|
|
||||||
border_style="red",
|
|
||||||
padding=(1, 2),
|
|
||||||
)
|
|
||||||
console.print("\n")
|
|
||||||
console.print(panel)
|
|
||||||
console.print()
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
|
|
||||||
if not target or not isinstance(target, str):
|
|
||||||
raise ValueError("Target must be a non-empty string")
|
|
||||||
|
|
||||||
target = target.strip()
|
|
||||||
|
|
||||||
parsed = urlparse(target)
|
|
||||||
if parsed.scheme in ("http", "https"):
|
|
||||||
if any(
|
|
||||||
host in parsed.netloc.lower() for host in ["github.com", "gitlab.com", "bitbucket.org"]
|
|
||||||
):
|
|
||||||
return "repository", {"target_repo": target}
|
|
||||||
return "web_application", {"target_url": target}
|
|
||||||
|
|
||||||
path = Path(target)
|
|
||||||
try:
|
|
||||||
if path.exists():
|
|
||||||
if path.is_dir():
|
|
||||||
return "local_code", {"target_path": str(path.absolute())}
|
|
||||||
raise ValueError(f"Path exists but is not a directory: {target}")
|
|
||||||
except (OSError, RuntimeError) as e:
|
|
||||||
raise ValueError(f"Invalid path: {target} - {e!s}") from e
|
|
||||||
|
|
||||||
if target.startswith("git@") or target.endswith(".git"):
|
|
||||||
return "repository", {"target_repo": target}
|
|
||||||
|
|
||||||
if "." in target and "/" not in target and not target.startswith("."):
|
|
||||||
parts = target.split(".")
|
|
||||||
if len(parts) >= 2 and all(p and p.strip() for p in parts):
|
|
||||||
return "web_application", {"target_url": f"https://{target}"}
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid target: {target}\n"
|
|
||||||
"Target must be one of:\n"
|
|
||||||
"- A valid URL (http:// or https://)\n"
|
|
||||||
"- A Git repository URL (https://github.com/... or git@github.com:...)\n"
|
|
||||||
"- A local directory path\n"
|
|
||||||
"- A domain name (e.g., example.com)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments() -> argparse.Namespace:
|
def parse_arguments() -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Strix Multi-Agent Cybersecurity Penetration Testing Tool",
|
description="Strix Multi-Agent Cybersecurity Penetration Testing Tool",
|
||||||
@@ -399,16 +257,23 @@ Examples:
|
|||||||
# Domain penetration test
|
# Domain penetration test
|
||||||
strix --target example.com
|
strix --target example.com
|
||||||
|
|
||||||
|
# Multiple targets (e.g., white-box testing with source and deployed app)
|
||||||
|
strix --target https://github.com/user/repo --target https://example.com
|
||||||
|
strix --target ./my-project --target https://staging.example.com --target https://prod.example.com
|
||||||
|
|
||||||
# Custom instructions
|
# Custom instructions
|
||||||
strix --target example.com --instruction "Focus on authentication vulnerabilities"
|
strix --target example.com --instruction "Focus on authentication vulnerabilities"
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
"-t",
|
||||||
"--target",
|
"--target",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Target to test (URL, repository, local directory path, or domain name)",
|
action="append",
|
||||||
|
help="Target to test (URL, repository, local directory path, or domain name). "
|
||||||
|
"Can be specified multiple times for multi-target scans.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--instruction",
|
"--instruction",
|
||||||
@@ -439,127 +304,53 @@ Examples:
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args.targets_info = []
|
||||||
|
for target in args.target:
|
||||||
try:
|
try:
|
||||||
args.target_type, args.target_dict = infer_target_type(args.target)
|
target_type, target_dict = infer_target_type(target)
|
||||||
except ValueError as e:
|
|
||||||
parser.error(str(e))
|
if target_type == "local_code":
|
||||||
|
display_target = target_dict.get("target_path", target)
|
||||||
|
else:
|
||||||
|
display_target = target
|
||||||
|
|
||||||
|
args.targets_info.append(
|
||||||
|
{"type": target_type, "details": target_dict, "original": display_target}
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
parser.error(f"Invalid target '{target}'")
|
||||||
|
|
||||||
|
assign_workspace_subdirs(args.targets_info)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def _get_severity_color(severity: str) -> str:
|
|
||||||
severity_colors = {
|
|
||||||
"critical": "#dc2626",
|
|
||||||
"high": "#ea580c",
|
|
||||||
"medium": "#d97706",
|
|
||||||
"low": "#65a30d",
|
|
||||||
"info": "#0284c7",
|
|
||||||
}
|
|
||||||
return severity_colors.get(severity, "#6b7280")
|
|
||||||
|
|
||||||
|
|
||||||
def _build_stats_text(tracer: Any) -> Text:
|
|
||||||
stats_text = Text()
|
|
||||||
if not tracer:
|
|
||||||
return stats_text
|
|
||||||
|
|
||||||
vuln_count = len(tracer.vulnerability_reports)
|
|
||||||
tool_count = tracer.get_real_tool_count()
|
|
||||||
agent_count = len(tracer.agents)
|
|
||||||
|
|
||||||
if vuln_count > 0:
|
|
||||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
|
||||||
for report in tracer.vulnerability_reports:
|
|
||||||
severity = report.get("severity", "").lower()
|
|
||||||
if severity in severity_counts:
|
|
||||||
severity_counts[severity] += 1
|
|
||||||
|
|
||||||
stats_text.append("🔍 Vulnerabilities Found: ", style="bold red")
|
|
||||||
|
|
||||||
severity_parts = []
|
|
||||||
for severity in ["critical", "high", "medium", "low", "info"]:
|
|
||||||
count = severity_counts[severity]
|
|
||||||
if count > 0:
|
|
||||||
severity_color = _get_severity_color(severity)
|
|
||||||
severity_text = Text()
|
|
||||||
severity_text.append(f"{severity.upper()}: ", style=severity_color)
|
|
||||||
severity_text.append(str(count), style=f"bold {severity_color}")
|
|
||||||
severity_parts.append(severity_text)
|
|
||||||
|
|
||||||
for i, part in enumerate(severity_parts):
|
|
||||||
stats_text.append(part)
|
|
||||||
if i < len(severity_parts) - 1:
|
|
||||||
stats_text.append(" | ", style="dim white")
|
|
||||||
|
|
||||||
stats_text.append(" (Total: ", style="dim white")
|
|
||||||
stats_text.append(str(vuln_count), style="bold yellow")
|
|
||||||
stats_text.append(")", style="dim white")
|
|
||||||
stats_text.append("\n")
|
|
||||||
else:
|
|
||||||
stats_text.append("🔍 Vulnerabilities Found: ", style="bold green")
|
|
||||||
stats_text.append("0", style="bold white")
|
|
||||||
stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
|
|
||||||
stats_text.append("\n")
|
|
||||||
|
|
||||||
stats_text.append("🤖 Agents Used: ", style="bold cyan")
|
|
||||||
stats_text.append(str(agent_count), style="bold white")
|
|
||||||
stats_text.append(" • ", style="dim white")
|
|
||||||
stats_text.append("🛠️ Tools Called: ", style="bold cyan")
|
|
||||||
stats_text.append(str(tool_count), style="bold white")
|
|
||||||
|
|
||||||
return stats_text
|
|
||||||
|
|
||||||
|
|
||||||
def _build_llm_stats_text(tracer: Any) -> Text:
|
|
||||||
llm_stats_text = Text()
|
|
||||||
if not tracer:
|
|
||||||
return llm_stats_text
|
|
||||||
|
|
||||||
llm_stats = tracer.get_total_llm_stats()
|
|
||||||
total_stats = llm_stats["total"]
|
|
||||||
|
|
||||||
if total_stats["requests"] > 0:
|
|
||||||
llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
|
|
||||||
llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
|
|
||||||
|
|
||||||
if total_stats["cached_tokens"] > 0:
|
|
||||||
llm_stats_text.append(" • ", style="dim white")
|
|
||||||
llm_stats_text.append("⚡ Cached: ", style="bold green")
|
|
||||||
llm_stats_text.append(
|
|
||||||
format_token_count(total_stats["cached_tokens"]), style="bold green"
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_stats_text.append(" • ", style="dim white")
|
|
||||||
llm_stats_text.append("📤 Output Tokens: ", style="bold cyan")
|
|
||||||
llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")
|
|
||||||
|
|
||||||
if total_stats["cost"] > 0:
|
|
||||||
llm_stats_text.append(" • ", style="dim white")
|
|
||||||
llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
|
|
||||||
llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")
|
|
||||||
|
|
||||||
return llm_stats_text
|
|
||||||
|
|
||||||
|
|
||||||
def display_completion_message(args: argparse.Namespace, results_path: Path) -> None:
|
def display_completion_message(args: argparse.Namespace, results_path: Path) -> None:
|
||||||
console = Console()
|
console = Console()
|
||||||
tracer = get_global_tracer()
|
tracer = get_global_tracer()
|
||||||
|
|
||||||
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
|
|
||||||
|
|
||||||
completion_text = Text()
|
completion_text = Text()
|
||||||
completion_text.append("🦉 ", style="bold white")
|
completion_text.append("🦉 ", style="bold white")
|
||||||
completion_text.append("AGENT FINISHED", style="bold green")
|
completion_text.append("AGENT FINISHED", style="bold green")
|
||||||
completion_text.append(" • ", style="dim white")
|
completion_text.append(" • ", style="dim white")
|
||||||
completion_text.append("Penetration test completed", style="white")
|
completion_text.append("Penetration test completed", style="white")
|
||||||
|
|
||||||
stats_text = _build_stats_text(tracer)
|
stats_text = build_stats_text(tracer)
|
||||||
|
|
||||||
llm_stats_text = _build_llm_stats_text(tracer)
|
llm_stats_text = build_llm_stats_text(tracer)
|
||||||
|
|
||||||
target_text = Text()
|
target_text = Text()
|
||||||
|
if len(args.targets_info) == 1:
|
||||||
target_text.append("🎯 Target: ", style="bold cyan")
|
target_text.append("🎯 Target: ", style="bold cyan")
|
||||||
target_text.append(str(target_value), style="bold white")
|
target_text.append(args.targets_info[0]["original"], style="bold white")
|
||||||
|
else:
|
||||||
|
target_text.append("🎯 Targets: ", style="bold cyan")
|
||||||
|
target_text.append(f"{len(args.targets_info)} targets\n", style="bold white")
|
||||||
|
for i, target_info in enumerate(args.targets_info):
|
||||||
|
target_text.append(" • ", style="dim white")
|
||||||
|
target_text.append(target_info["original"], style="white")
|
||||||
|
if i < len(args.targets_info) - 1:
|
||||||
|
target_text.append("\n")
|
||||||
|
|
||||||
results_text = Text()
|
results_text = Text()
|
||||||
results_text.append("📊 Results Saved To: ", style="bold cyan")
|
results_text.append("📊 Results Saved To: ", style="bold cyan")
|
||||||
@@ -575,19 +366,19 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
|
|||||||
stats_text,
|
stats_text,
|
||||||
"\n",
|
"\n",
|
||||||
llm_stats_text,
|
llm_stats_text,
|
||||||
"\n",
|
"\n\n",
|
||||||
results_text,
|
results_text,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
panel_content = Text.assemble(
|
panel_content = Text.assemble(
|
||||||
completion_text, "\n\n", target_text, "\n", stats_text, "\n", results_text
|
completion_text, "\n\n", target_text, "\n", stats_text, "\n\n", results_text
|
||||||
)
|
)
|
||||||
elif llm_stats_text.plain:
|
elif llm_stats_text.plain:
|
||||||
panel_content = Text.assemble(
|
panel_content = Text.assemble(
|
||||||
completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n", results_text
|
completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n\n", results_text
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n", results_text)
|
panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n\n", results_text)
|
||||||
|
|
||||||
panel = Panel(
|
panel = Panel(
|
||||||
panel_content,
|
panel_content,
|
||||||
@@ -602,86 +393,11 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
|
|||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
|
|
||||||
def _check_docker_connection() -> Any:
|
|
||||||
try:
|
|
||||||
return docker.from_env()
|
|
||||||
except DockerException:
|
|
||||||
console = Console()
|
|
||||||
error_text = Text()
|
|
||||||
error_text.append("❌ ", style="bold red")
|
|
||||||
error_text.append("DOCKER NOT AVAILABLE", style="bold red")
|
|
||||||
error_text.append("\n\n", style="white")
|
|
||||||
error_text.append("Cannot connect to Docker daemon.\n", style="white")
|
|
||||||
error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
|
|
||||||
error_text.append("Try running: ", style="dim white")
|
|
||||||
error_text.append("sudo systemctl start docker", style="dim cyan")
|
|
||||||
|
|
||||||
panel = Panel(
|
|
||||||
error_text,
|
|
||||||
title="[bold red]🛡️ STRIX STARTUP ERROR",
|
|
||||||
title_align="center",
|
|
||||||
border_style="red",
|
|
||||||
padding=(1, 2),
|
|
||||||
)
|
|
||||||
console.print("\n", panel, "\n")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def _image_exists(client: Any) -> bool:
|
|
||||||
try:
|
|
||||||
client.images.get(STRIX_IMAGE)
|
|
||||||
except docker.errors.ImageNotFound:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None:
|
|
||||||
if "Pull complete" in layer_status or "Already exists" in layer_status:
|
|
||||||
layers_info[layer_id] = "✓"
|
|
||||||
elif "Downloading" in layer_status:
|
|
||||||
layers_info[layer_id] = "↓"
|
|
||||||
elif "Extracting" in layer_status:
|
|
||||||
layers_info[layer_id] = "📦"
|
|
||||||
elif "Waiting" in layer_status:
|
|
||||||
layers_info[layer_id] = "⏳"
|
|
||||||
else:
|
|
||||||
layers_info[layer_id] = "•"
|
|
||||||
|
|
||||||
|
|
||||||
def _process_pull_line(
|
|
||||||
line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str
|
|
||||||
) -> str:
|
|
||||||
if "id" in line and "status" in line:
|
|
||||||
layer_id = line["id"]
|
|
||||||
_update_layer_status(layers_info, layer_id, line["status"])
|
|
||||||
|
|
||||||
completed = sum(1 for v in layers_info.values() if v == "✓")
|
|
||||||
total = len(layers_info)
|
|
||||||
|
|
||||||
if total > 0:
|
|
||||||
update_msg = f"[bold cyan]Progress: {completed}/{total} layers complete"
|
|
||||||
if update_msg != last_update:
|
|
||||||
status.update(update_msg)
|
|
||||||
return update_msg
|
|
||||||
|
|
||||||
elif "status" in line and "id" not in line:
|
|
||||||
global_status = line["status"]
|
|
||||||
if "Pulling from" in global_status:
|
|
||||||
status.update("[bold cyan]Fetching image manifest...")
|
|
||||||
elif "Digest:" in global_status:
|
|
||||||
status.update("[bold cyan]Verifying image...")
|
|
||||||
elif "Status:" in global_status:
|
|
||||||
status.update("[bold cyan]Finalizing...")
|
|
||||||
|
|
||||||
return last_update
|
|
||||||
|
|
||||||
|
|
||||||
def pull_docker_image() -> None:
|
def pull_docker_image() -> None:
|
||||||
console = Console()
|
console = Console()
|
||||||
client = _check_docker_connection()
|
client = check_docker_connection()
|
||||||
|
|
||||||
if _image_exists(client):
|
if image_exists(client, STRIX_IMAGE):
|
||||||
return
|
return
|
||||||
|
|
||||||
console.print()
|
console.print()
|
||||||
@@ -695,7 +411,7 @@ def pull_docker_image() -> None:
|
|||||||
last_update = ""
|
last_update = ""
|
||||||
|
|
||||||
for line in client.api.pull(STRIX_IMAGE, stream=True, decode=True):
|
for line in client.api.pull(STRIX_IMAGE, stream=True, decode=True):
|
||||||
last_update = _process_pull_line(line, layers_info, status, last_update)
|
last_update = process_pull_line(line, layers_info, status, last_update)
|
||||||
|
|
||||||
except DockerException as e:
|
except DockerException as e:
|
||||||
console.print()
|
console.print()
|
||||||
@@ -738,11 +454,14 @@ def main() -> None:
|
|||||||
if not args.run_name:
|
if not args.run_name:
|
||||||
args.run_name = generate_run_name()
|
args.run_name = generate_run_name()
|
||||||
|
|
||||||
if args.target_type == "repository":
|
for target_info in args.targets_info:
|
||||||
repo_url = args.target_dict["target_repo"]
|
if target_info["type"] == "repository":
|
||||||
cloned_path = clone_repository(repo_url, args.run_name)
|
repo_url = target_info["details"]["target_repo"]
|
||||||
|
dest_name = target_info["details"].get("workspace_subdir")
|
||||||
|
cloned_path = clone_repository(repo_url, args.run_name, dest_name)
|
||||||
|
target_info["details"]["cloned_repo_path"] = cloned_path
|
||||||
|
|
||||||
args.target_dict["cloned_repo_path"] = cloned_path
|
args.local_sources = collect_local_sources(args.targets_info)
|
||||||
|
|
||||||
if args.non_interactive:
|
if args.non_interactive:
|
||||||
asyncio.run(run_cli(args))
|
asyncio.run(run_cli(args))
|
||||||
|
|||||||
@@ -16,23 +16,28 @@ class ScanStartInfoRenderer(BaseToolRenderer):
|
|||||||
args = tool_data.get("args", {})
|
args = tool_data.get("args", {})
|
||||||
status = tool_data.get("status", "unknown")
|
status = tool_data.get("status", "unknown")
|
||||||
|
|
||||||
target = args.get("target", {})
|
targets = args.get("targets", [])
|
||||||
|
|
||||||
target_display = cls._build_target_display(target)
|
if len(targets) == 1:
|
||||||
|
target_display = cls._build_single_target_display(targets[0])
|
||||||
content = f"🚀 Starting scan on {target_display}"
|
content = f"🚀 Starting penetration test on {target_display}"
|
||||||
|
elif len(targets) > 1:
|
||||||
|
content = f"🚀 Starting penetration test on {len(targets)} targets"
|
||||||
|
for target_info in targets:
|
||||||
|
target_display = cls._build_single_target_display(target_info)
|
||||||
|
content += f"\n • {target_display}"
|
||||||
|
else:
|
||||||
|
content = "🚀 Starting penetration test"
|
||||||
|
|
||||||
css_classes = cls.get_css_classes(status)
|
css_classes = cls.get_css_classes(status)
|
||||||
return Static(content, classes=css_classes)
|
return Static(content, classes=css_classes)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_target_display(cls, target: dict[str, Any]) -> str:
|
def _build_single_target_display(cls, target_info: dict[str, Any]) -> str:
|
||||||
if target_url := target.get("target_url"):
|
original = target_info.get("original")
|
||||||
return cls.escape_markup(str(target_url))
|
if original:
|
||||||
if target_repo := target.get("target_repo"):
|
return cls.escape_markup(str(original))
|
||||||
return cls.escape_markup(str(target_repo))
|
|
||||||
if target_path := target.get("target_path"):
|
|
||||||
return cls.escape_markup(str(target_path))
|
|
||||||
return "unknown target"
|
return "unknown target"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -312,8 +312,7 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]:
|
def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"scan_id": args.run_name,
|
"scan_id": args.run_name,
|
||||||
"scan_type": args.target_type,
|
"targets": args.targets_info,
|
||||||
"target": args.target_dict,
|
|
||||||
"user_instructions": args.instruction or "",
|
"user_instructions": args.instruction or "",
|
||||||
"run_name": args.run_name,
|
"run_name": args.run_name,
|
||||||
}
|
}
|
||||||
@@ -326,10 +325,8 @@ class StrixTUIApp(App): # type: ignore[misc]
|
|||||||
"max_iterations": 300,
|
"max_iterations": 300,
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.target_type == "local_code" and "target_path" in args.target_dict:
|
if getattr(args, "local_sources", None):
|
||||||
config["local_source_path"] = args.target_dict["target_path"]
|
config["local_sources"] = args.local_sources
|
||||||
elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict:
|
|
||||||
config["local_source_path"] = args.target_dict["cloned_repo_path"]
|
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|||||||
434
strix/interface/utils.py
Normal file
434
strix/interface/utils.py
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
import re
|
||||||
|
import secrets
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import docker
|
||||||
|
from docker.errors import DockerException, ImageNotFound
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
|
||||||
|
# Token formatting utilities
|
||||||
|
def format_token_count(count: float) -> str:
    """Format a token count compactly: 950 -> "950", 1500 -> "1.5K", 2500000 -> "2.5M"."""
    whole = int(count)
    # Check the largest unit first so millions are not rendered as thousands.
    for threshold, suffix in ((1_000_000, "M"), (1_000, "K")):
        if whole >= threshold:
            return f"{whole / threshold:.1f}{suffix}"
    return str(whole)
|
||||||
|
|
||||||
|
|
||||||
|
# Display utilities
|
||||||
|
def get_severity_color(severity: str) -> str:
    """Return the hex color used to render *severity*; neutral gray for unknown levels.

    The lookup is case-insensitive so callers do not have to pre-normalize
    ("HIGH" and "high" map to the same color). Empty or unrecognized
    severities fall back to gray (#6b7280).
    """
    severity_colors = {
        "critical": "#dc2626",
        "high": "#ea580c",
        "medium": "#d97706",
        "low": "#65a30d",
        "info": "#0284c7",
    }
    # .lower() generalizes the previous exact-match lookup; lowercase inputs
    # (the existing callers) behave identically.
    return severity_colors.get(severity.lower(), "#6b7280")
|
||||||
|
|
||||||
|
|
||||||
|
def build_stats_text(tracer: Any) -> Text:
    """Build the post-scan summary text: vulnerability counts, agents, and tools.

    Returns an empty ``Text`` when *tracer* is falsy (e.g. no scan ran).
    """
    stats_text = Text()
    if not tracer:
        return stats_text

    # Headline numbers pulled from the tracer's collected state.
    vuln_count = len(tracer.vulnerability_reports)
    tool_count = tracer.get_real_tool_count()
    agent_count = len(tracer.agents)

    if vuln_count > 0:
        # Bucket reports by severity; severities outside this set are not
        # bucketed but still count toward the overall total below.
        severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
        for report in tracer.vulnerability_reports:
            severity = report.get("severity", "").lower()
            if severity in severity_counts:
                severity_counts[severity] += 1

        stats_text.append("🔍 Vulnerabilities Found: ", style="bold red")

        # One colored "SEVERITY: n" fragment per non-empty bucket,
        # ordered most-severe first.
        severity_parts = []
        for severity in ["critical", "high", "medium", "low", "info"]:
            count = severity_counts[severity]
            if count > 0:
                severity_color = get_severity_color(severity)
                severity_text = Text()
                severity_text.append(f"{severity.upper()}: ", style=severity_color)
                severity_text.append(str(count), style=f"bold {severity_color}")
                severity_parts.append(severity_text)

        # Join the fragments with dim separators (no trailing separator).
        for i, part in enumerate(severity_parts):
            stats_text.append(part)
            if i < len(severity_parts) - 1:
                stats_text.append(" | ", style="dim white")

        stats_text.append(" (Total: ", style="dim white")
        stats_text.append(str(vuln_count), style="bold yellow")
        stats_text.append(")", style="dim white")
        stats_text.append("\n")
    else:
        # Green all-clear line when nothing was found.
        stats_text.append("🔍 Vulnerabilities Found: ", style="bold green")
        stats_text.append("0", style="bold white")
        stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
        stats_text.append("\n")

    # Second line: agent and tool usage counters.
    stats_text.append("🤖 Agents Used: ", style="bold cyan")
    stats_text.append(str(agent_count), style="bold white")
    stats_text.append(" • ", style="dim white")
    stats_text.append("🛠️ Tools Called: ", style="bold cyan")
    stats_text.append(str(tool_count), style="bold white")

    return stats_text
|
||||||
|
|
||||||
|
|
||||||
|
def build_llm_stats_text(tracer: Any) -> Text:
    """Build the LLM usage summary line: token counts, cache hits, and cost.

    Returns an empty ``Text`` when *tracer* is falsy or no LLM requests were
    recorded during the run.
    """
    llm_stats_text = Text()
    if not tracer:
        return llm_stats_text

    # "total" is the aggregate bucket reported by the tracer.
    llm_stats = tracer.get_total_llm_stats()
    total_stats = llm_stats["total"]

    if total_stats["requests"] > 0:
        llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
        llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")

        # Cache hits are only shown when present.
        if total_stats["cached_tokens"] > 0:
            llm_stats_text.append(" • ", style="dim white")
            llm_stats_text.append("⚡ Cached: ", style="bold green")
            llm_stats_text.append(
                format_token_count(total_stats["cached_tokens"]), style="bold green"
            )

        llm_stats_text.append(" • ", style="dim white")
        llm_stats_text.append("📤 Output Tokens: ", style="bold cyan")
        llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")

        # Cost can legitimately be zero (e.g. local models); hide it then.
        if total_stats["cost"] > 0:
            llm_stats_text.append(" • ", style="dim white")
            llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
            llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")

    return llm_stats_text
|
||||||
|
|
||||||
|
|
||||||
|
# Name generation utilities
|
||||||
|
def generate_run_name() -> str:
    """Generate a random hacker-themed run name like ``stealthy-exploit-451``.

    Uses the ``secrets`` RNG; the trailing number is always three digits
    (100-999).
    """
    # fmt: off
    adjectives = [
        "stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent",
        "rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary",
        "encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous"
    ]
    nouns = [
        "exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan",
        "worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler",
        "scanner", "sniffer", "honeypot", "firewall", "breach"
    ]
    # fmt: on
    suffix = 100 + secrets.randbelow(900)
    return "-".join((secrets.choice(adjectives), secrets.choice(nouns), str(suffix)))
|
||||||
|
|
||||||
|
|
||||||
|
# Target processing utilities
|
||||||
|
def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
    """Classify *target* as a repository, web application, or local code directory.

    Returns ``(target_type, details)`` where *details* carries the single key
    the rest of the pipeline expects ("target_repo", "target_url", or
    "target_path"). Raises ``ValueError`` for anything unclassifiable.
    """
    if not target or not isinstance(target, str):
        raise ValueError("Target must be a non-empty string")

    target = target.strip()
    normalized = target.lower()

    # Scheme-less links to well-known forges count as repositories.
    forge_prefixes = (
        "github.com/",
        "www.github.com/",
        "gitlab.com/",
        "www.gitlab.com/",
        "bitbucket.org/",
        "www.bitbucket.org/",
    )
    for prefix in forge_prefixes:
        if normalized.startswith(prefix):
            return "repository", {"target_repo": f"https://{target}"}

    parsed = urlparse(target)
    if parsed.scheme in ("http", "https"):
        host = parsed.netloc.lower()
        is_forge = any(
            forge in host for forge in ["github.com", "gitlab.com", "bitbucket.org"]
        )
        if is_forge:
            return "repository", {"target_repo": target}
        return "web_application", {"target_url": target}

    candidate_path = Path(target).expanduser()
    try:
        if candidate_path.exists():
            if not candidate_path.is_dir():
                # Files are not scannable targets, only directories.
                raise ValueError(f"Path exists but is not a directory: {target}")
            return "local_code", {"target_path": str(candidate_path.resolve())}
    except (OSError, RuntimeError) as e:
        raise ValueError(f"Invalid path: {target} - {e!s}") from e

    if target.startswith("git@") or target.endswith(".git"):
        return "repository", {"target_repo": target}

    # A bare domain such as "example.com": dots, no slashes, non-empty labels.
    if "." in target and "/" not in target and not target.startswith("."):
        labels = target.split(".")
        if len(labels) >= 2 and all(label and label.strip() for label in labels):
            return "web_application", {"target_url": f"https://{target}"}

    raise ValueError(
        f"Invalid target: {target}\n"
        "Target must be one of:\n"
        "- A valid URL (http:// or https://)\n"
        "- A Git repository URL (https://github.com/... or git@github.com:...)\n"
        "- A local directory path\n"
        "- A domain name (e.g., example.com)"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_name(name: str) -> str:
    """Replace filesystem-unfriendly characters with '-'; never return an empty string."""
    cleaned = re.sub(r"[^A-Za-z0-9._-]", "-", name.strip())
    if cleaned:
        return cleaned
    # Fall back to a generic name when nothing usable remains.
    return "target"
|
||||||
|
|
||||||
|
|
||||||
|
def derive_repo_base_name(repo_url: str) -> str:
    """Derive a sanitized workspace name from an HTTP(S) or scp-style git URL."""
    # Drop at most one trailing slash so the last path segment is non-empty.
    url = repo_url[:-1] if repo_url.endswith("/") else repo_url

    # scp-style URLs (git@host:owner/repo.git) keep the part after the colon;
    # everything else goes through urlparse.
    if url.startswith("git@") and ":" in url:
        repo_path = url.split(":", 1)[1]
    else:
        repo_path = urlparse(url).path or url

    name = repo_path.rsplit("/", 1)[-1]
    name = name.removesuffix(".git")

    return sanitize_name(name or "repository")
|
||||||
|
|
||||||
|
|
||||||
|
def derive_local_base_name(path_str: str) -> str:
    """Derive a sanitized workspace name from a local directory path."""
    try:
        # Resolve so "." and relative paths yield a meaningful final component.
        base = Path(path_str).resolve().name
    except (OSError, RuntimeError):
        base = Path(path_str).name
    if not base:
        base = "workspace"
    return sanitize_name(base)
|
||||||
|
|
||||||
|
|
||||||
|
def assign_workspace_subdirs(targets_info: list[dict[str, Any]]) -> None:
    """Give each repository/local_code target a unique ``details["workspace_subdir"]``.

    Duplicate base names get a numeric suffix ("app", "app-2", ...).
    Web-application targets are left untouched. Mutates *targets_info* in place.
    """
    seen: dict[str, int] = {}

    for entry in targets_info:
        details = entry["details"]

        if entry["type"] == "repository":
            base = derive_repo_base_name(details["target_repo"])
        elif entry["type"] == "local_code":
            base = derive_local_base_name(details.get("target_path", "local"))
        else:
            # Only filesystem-backed targets get a workspace directory.
            continue

        occurrence = seen.get(base, 0) + 1
        seen[base] = occurrence

        details["workspace_subdir"] = base if occurrence == 1 else f"{base}-{occurrence}"
|
||||||
|
|
||||||
|
|
||||||
|
def collect_local_sources(targets_info: list[dict[str, Any]]) -> list[dict[str, str]]:
    """Collect every on-disk source directory that must be copied into the sandbox.

    Local-code targets contribute their original path; repository targets
    contribute their clone location (only once cloned). Each entry pairs the
    source path with the workspace subdirectory it should land in.
    """
    # Which details key holds the on-disk path, per target type.
    path_keys = {"local_code": "target_path", "repository": "cloned_repo_path"}
    sources: list[dict[str, str]] = []

    for entry in targets_info:
        details = entry["details"]
        key = path_keys.get(entry["type"])
        if key and key in details:
            sources.append(
                {
                    "source_path": details[key],
                    "workspace_subdir": details.get("workspace_subdir"),
                }
            )

    return sources
|
||||||
|
|
||||||
|
|
||||||
|
# Repository utilities
|
||||||
|
def clone_repository(repo_url: str, run_name: str, dest_name: str | None = None) -> str:
    """Clone *repo_url* into a per-run temp directory and return the absolute clone path.

    The clone lands in ``<tmp>/strix_repos/<run_name>/<dest_name or repo name>``;
    any pre-existing directory at that path is removed first.

    Raises:
        FileNotFoundError: git is not on PATH (re-raised after printing a panel).
        subprocess.CalledProcessError: the clone itself failed (re-raised after
            printing a panel with git's stderr).
    """
    console = Console()

    git_executable = shutil.which("git")
    if git_executable is None:
        raise FileNotFoundError("Git executable not found in PATH")

    # Run-scoped staging area; safe to recreate across invocations.
    temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name
    temp_dir.mkdir(parents=True, exist_ok=True)

    if dest_name:
        repo_name = dest_name
    else:
        # Derive the directory name from the URL's last component,
        # stripping a ".git" suffix when present.
        repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name

    clone_path = temp_dir / repo_name

    # Remove stale clones so git does not refuse a non-empty destination.
    if clone_path.exists():
        shutil.rmtree(clone_path)

    try:
        with console.status(f"[bold cyan]Cloning repository {repo_url}...", spinner="dots"):
            subprocess.run(  # noqa: S603
                [
                    git_executable,
                    "clone",
                    repo_url,
                    str(clone_path),
                ],
                capture_output=True,
                text=True,
                check=True,
            )

        return str(clone_path.absolute())

    except subprocess.CalledProcessError as e:
        # Clone failed: show a styled error panel, then propagate.
        error_text = Text()
        error_text.append("❌ ", style="bold red")
        error_text.append("REPOSITORY CLONE FAILED", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append(f"Could not clone repository: {repo_url}\n", style="white")
        error_text.append(
            f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red"
        )

        panel = Panel(
            error_text,
            title="[bold red]🛡️ STRIX CLONE ERROR",
            title_align="center",
            border_style="red",
            padding=(1, 2),
        )
        console.print("\n")
        console.print(panel)
        console.print()
        raise
    except FileNotFoundError:
        # Defensive: reached if git vanishes between the which() check and run().
        error_text = Text()
        error_text.append("❌ ", style="bold red")
        error_text.append("GIT NOT FOUND", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append("Git is not installed or not available in PATH.\n", style="white")
        error_text.append("Please install Git to clone repositories.\n", style="white")

        panel = Panel(
            error_text,
            title="[bold red]🛡️ STRIX CLONE ERROR",
            title_align="center",
            border_style="red",
            padding=(1, 2),
        )
        console.print("\n")
        console.print(panel)
        console.print()
        raise
|
||||||
|
|
||||||
|
|
||||||
|
# Docker utilities
|
||||||
|
def check_docker_connection() -> Any:
    """Return a Docker client built from the environment, or abort with a styled error.

    Raises:
        RuntimeError: the Docker daemon cannot be reached (raised after
            printing a startup-error panel).
    """
    try:
        return docker.from_env()
    except DockerException:
        console = Console()
        error_text = Text()
        error_text.append("❌ ", style="bold red")
        error_text.append("DOCKER NOT AVAILABLE", style="bold red")
        error_text.append("\n\n", style="white")
        error_text.append("Cannot connect to Docker daemon.\n", style="white")
        error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
        error_text.append("Try running: ", style="dim white")
        error_text.append("sudo systemctl start docker", style="dim cyan")

        panel = Panel(
            error_text,
            title="[bold red]🛡️ STRIX STARTUP ERROR",
            title_align="center",
            border_style="red",
            padding=(1, 2),
        )
        console.print("\n", panel, "\n")
        # `from None` hides the noisy DockerException traceback from end users.
        raise RuntimeError("Docker not available") from None
|
||||||
|
|
||||||
|
|
||||||
|
def image_exists(client: Any, image_name: str) -> bool:
    """Return True when *image_name* is present in the local Docker image cache."""
    try:
        client.images.get(image_name)
        # Lookup succeeded without raising: the image is cached locally.
        return True
    except ImageNotFound:
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None:
    """Map a docker-pull layer status string to a one-character marker in *layers_info*."""
    if "Pull complete" in layer_status or "Already exists" in layer_status:
        marker = "✓"  # layer finished
    elif "Downloading" in layer_status:
        marker = "↓"
    elif "Extracting" in layer_status:
        marker = "📦"
    elif "Waiting" in layer_status:
        marker = "⏳"
    else:
        marker = "•"  # unrecognized status text
    layers_info[layer_id] = marker
|
||||||
|
|
||||||
|
|
||||||
|
def process_pull_line(
    line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str
) -> str:
    """Consume one decoded ``docker pull`` progress line; return the new last-update message.

    Lines carrying a layer id refresh that layer's marker plus the aggregate
    counter; id-less status lines only switch the spinner caption. The return
    value should be fed back in as *last_update* on the next call.
    """
    if "id" in line and "status" in line:
        update_layer_status(layers_info, line["id"], line["status"])

        finished = sum(marker == "✓" for marker in layers_info.values())
        total = len(layers_info)

        if total > 0:
            progress = f"[bold cyan]Progress: {finished}/{total} layers complete"
            # Only touch the spinner when the message actually changed.
            if progress != last_update:
                status.update(progress)
            return progress

    elif "status" in line and "id" not in line:
        # Global (non-layer) phases map to fixed spinner captions.
        caption_by_keyword = (
            ("Pulling from", "[bold cyan]Fetching image manifest..."),
            ("Digest:", "[bold cyan]Verifying image..."),
            ("Status:", "[bold cyan]Finalizing..."),
        )
        for keyword, caption in caption_by_keyword:
            if keyword in line["status"]:
                status.update(caption)
                break

    return last_update
|
||||||
|
|
||||||
|
|
||||||
|
# LLM utilities
|
||||||
|
def validate_llm_response(response: Any) -> None:
    """Raise RuntimeError unless *response* has a non-empty first-choice message content."""
    valid = bool(
        response
        and response.choices
        and response.choices[0].message.content
    )
    if not valid:
        raise RuntimeError("Invalid response from LLM")
|
||||||
@@ -250,7 +250,9 @@ class DockerRuntime(AbstractRuntime):
|
|||||||
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
def _copy_local_directory_to_container(self, container: Container, local_path: str) -> None:
|
def _copy_local_directory_to_container(
|
||||||
|
self, container: Container, local_path: str, target_name: str | None = None
|
||||||
|
) -> None:
|
||||||
import tarfile
|
import tarfile
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
@@ -260,13 +262,20 @@ class DockerRuntime(AbstractRuntime):
|
|||||||
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
|
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if target_name:
|
||||||
|
logger.info(
|
||||||
|
f"Copying local directory {local_path_obj} to container at "
|
||||||
|
f"/workspace/{target_name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
logger.info(f"Copying local directory {local_path_obj} to container")
|
logger.info(f"Copying local directory {local_path_obj} to container")
|
||||||
|
|
||||||
tar_buffer = BytesIO()
|
tar_buffer = BytesIO()
|
||||||
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
|
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
|
||||||
for item in local_path_obj.rglob("*"):
|
for item in local_path_obj.rglob("*"):
|
||||||
if item.is_file():
|
if item.is_file():
|
||||||
arcname = item.relative_to(local_path_obj)
|
rel_path = item.relative_to(local_path_obj)
|
||||||
|
arcname = Path(target_name) / rel_path if target_name else rel_path
|
||||||
tar.add(item, arcname=arcname)
|
tar.add(item, arcname=arcname)
|
||||||
|
|
||||||
tar_buffer.seek(0)
|
tar_buffer.seek(0)
|
||||||
@@ -283,14 +292,26 @@ class DockerRuntime(AbstractRuntime):
|
|||||||
logger.exception("Failed to copy local directory to container")
|
logger.exception("Failed to copy local directory to container")
|
||||||
|
|
||||||
async def create_sandbox(
|
async def create_sandbox(
|
||||||
self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
existing_token: str | None = None,
|
||||||
|
local_sources: list[dict[str, str]] | None = None,
|
||||||
) -> SandboxInfo:
|
) -> SandboxInfo:
|
||||||
scan_id = self._get_scan_id(agent_id)
|
scan_id = self._get_scan_id(agent_id)
|
||||||
container = self._get_or_create_scan_container(scan_id)
|
container = self._get_or_create_scan_container(scan_id)
|
||||||
|
|
||||||
source_copied_key = f"_source_copied_{scan_id}"
|
source_copied_key = f"_source_copied_{scan_id}"
|
||||||
if local_source_path and not hasattr(self, source_copied_key):
|
if local_sources and not hasattr(self, source_copied_key):
|
||||||
self._copy_local_directory_to_container(container, local_source_path)
|
for index, source in enumerate(local_sources, start=1):
|
||||||
|
source_path = source.get("source_path")
|
||||||
|
if not source_path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_name = source.get("workspace_subdir")
|
||||||
|
if not target_name:
|
||||||
|
target_name = Path(source_path).name or f"target_{index}"
|
||||||
|
|
||||||
|
self._copy_local_directory_to_container(container, source_path, target_name)
|
||||||
setattr(self, source_copied_key, True)
|
setattr(self, source_copied_key, True)
|
||||||
|
|
||||||
container_id = container.id
|
container_id = container.id
|
||||||
|
|||||||
@@ -13,7 +13,10 @@ class SandboxInfo(TypedDict):
|
|||||||
class AbstractRuntime(ABC):
|
class AbstractRuntime(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def create_sandbox(
|
async def create_sandbox(
|
||||||
self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
existing_token: str | None = None,
|
||||||
|
local_sources: list[dict[str, str]] | None = None,
|
||||||
) -> SandboxInfo:
|
) -> SandboxInfo:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -44,8 +44,7 @@ class Tracer:
|
|||||||
"run_name": self.run_name,
|
"run_name": self.run_name,
|
||||||
"start_time": self.start_time,
|
"start_time": self.start_time,
|
||||||
"end_time": None,
|
"end_time": None,
|
||||||
"target": None,
|
"targets": [],
|
||||||
"scan_type": None,
|
|
||||||
"status": "running",
|
"status": "running",
|
||||||
}
|
}
|
||||||
self._run_dir: Path | None = None
|
self._run_dir: Path | None = None
|
||||||
@@ -193,8 +192,7 @@ class Tracer:
|
|||||||
self.scan_config = config
|
self.scan_config = config
|
||||||
self.run_metadata.update(
|
self.run_metadata.update(
|
||||||
{
|
{
|
||||||
"target": config.get("target", {}),
|
"targets": config.get("targets", []),
|
||||||
"scan_type": config.get("scan_type", "general"),
|
|
||||||
"user_instructions": config.get("user_instructions", ""),
|
"user_instructions": config.get("user_instructions", ""),
|
||||||
"max_iterations": config.get("max_iterations", 200),
|
"max_iterations": config.get("max_iterations", 200),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user