feat: implement multi-target scanning

This commit is contained in:
Ahmed Allam
2025-11-01 01:25:07 +02:00
committed by Ahmed Allam
parent deee85d547
commit 738fdc2d49
10 changed files with 619 additions and 440 deletions

View File

@@ -19,55 +19,64 @@ class StrixAgent(BaseAgent):
super().__init__(config) super().__init__(config)
async def execute_scan(self, scan_config: dict[str, Any]) -> dict[str, Any]: async def execute_scan(self, scan_config: dict[str, Any]) -> dict[str, Any]:
scan_type = scan_config.get("scan_type", "general")
target = scan_config.get("target", {})
user_instructions = scan_config.get("user_instructions", "") user_instructions = scan_config.get("user_instructions", "")
targets = scan_config.get("targets", [])
repositories = []
local_code = []
urls = []
for target in targets:
target_type = target["type"]
details = target["details"]
workspace_subdir = details.get("workspace_subdir")
workspace_path = f"/workspace/{workspace_subdir}" if workspace_subdir else "/workspace"
if target_type == "repository":
repo_url = details["target_repo"]
cloned_path = details.get("cloned_repo_path")
repositories.append(
{
"url": repo_url,
"workspace_path": workspace_path if cloned_path else None,
}
)
elif target_type == "local_code":
original_path = details.get("target_path", "unknown")
local_code.append(
{
"path": original_path,
"workspace_path": workspace_path,
}
)
elif target_type == "web_application":
urls.append(details["target_url"])
task_parts = [] task_parts = []
if scan_type == "repository": if repositories:
repo_url = target["target_repo"] task_parts.append("\n\nRepositories:")
cloned_path = target.get("cloned_repo_path") for repo in repositories:
if repo["workspace_path"]:
if cloned_path: task_parts.append(f"- {repo['url']} (available at: {repo['workspace_path']})")
workspace_path = "/workspace"
task_parts.append(
f"Perform a security assessment of the Git repository: {repo_url}. "
f"The repository has been cloned from '{repo_url}' to '{cloned_path}' "
f"(host path) and then copied to '{workspace_path}' in your environment."
f"Analyze the codebase at: {workspace_path}"
)
else: else:
task_parts.append( task_parts.append(f"- {repo['url']}")
f"Perform a security assessment of the Git repository: {repo_url}"
if local_code:
task_parts.append("\n\nLocal Codebases:")
task_parts.extend(
f"- {code['path']} (available at: {code['workspace_path']})" for code in local_code
) )
elif scan_type == "web_application": if urls:
task_parts.append( task_parts.append("\n\nURLs:")
f"Perform a security assessment of the web application: {target['target_url']}" task_parts.extend(f"- {url}" for url in urls)
)
elif scan_type == "local_code":
original_path = target.get("target_path", "unknown")
workspace_path = "/workspace"
task_parts.append(
f"Perform a security assessment of the local codebase. "
f"The code from '{original_path}' (user host path) has been copied to "
f"'{workspace_path}' in your environment. "
f"Analyze the codebase at: {workspace_path}"
)
else:
task_parts.append(
f"Perform a general security assessment of: {next(iter(target.values()))}"
)
task_description = " ".join(task_parts) task_description = " ".join(task_parts)
if user_instructions: if user_instructions:
task_description += ( task_description += f"\n\nSpecial instructions: {user_instructions}"
f"\n\nSpecial instructions from the system that must be followed: "
f"{user_instructions}"
)
return await self.agent_loop(task=task_description) return await self.agent_loop(task=task_description)

View File

@@ -54,7 +54,7 @@ class BaseAgent(metaclass=AgentMeta):
def __init__(self, config: dict[str, Any]): def __init__(self, config: dict[str, Any]):
self.config = config self.config = config
self.local_source_path = config.get("local_source_path") self.local_sources = config.get("local_sources", [])
self.non_interactive = config.get("non_interactive", False) self.non_interactive = config.get("non_interactive", False)
if "max_iterations" in config: if "max_iterations" in config:
@@ -317,7 +317,7 @@ class BaseAgent(metaclass=AgentMeta):
runtime = get_runtime() runtime = get_runtime()
sandbox_info = await runtime.create_sandbox( sandbox_info = await runtime.create_sandbox(
self.state.agent_id, self.state.sandbox_token, self.local_source_path self.state.agent_id, self.state.sandbox_token, self.local_sources
) )
self.state.sandbox_id = sandbox_info["workspace_id"] self.state.sandbox_id = sandbox_info["workspace_id"]
self.state.sandbox_token = sandbox_info["auth_token"] self.state.sandbox_token = sandbox_info["auth_token"]

View File

@@ -11,6 +11,8 @@ from strix.agents.StrixAgent import StrixAgent
from strix.llm.config import LLMConfig from strix.llm.config import LLMConfig
from strix.telemetry.tracer import Tracer, set_global_tracer from strix.telemetry.tracer import Tracer, set_global_tracer
from .utils import get_severity_color
async def run_cli(args: Any) -> None: # noqa: PLR0915 async def run_cli(args: Any) -> None: # noqa: PLR0915
console = Console() console = Console()
@@ -19,15 +21,18 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
start_text.append("🦉 ", style="bold white") start_text.append("🦉 ", style="bold white")
start_text.append("STRIX CYBERSECURITY AGENT", style="bold green") start_text.append("STRIX CYBERSECURITY AGENT", style="bold green")
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
target_text = Text() target_text = Text()
if len(args.targets_info) == 1:
target_text.append("🎯 Target: ", style="bold cyan") target_text.append("🎯 Target: ", style="bold cyan")
target_text.append(str(target_value), style="bold white") target_text.append(args.targets_info[0]["original"], style="bold white")
else:
instructions_text = Text() target_text.append("🎯 Targets: ", style="bold cyan")
if args.instruction: target_text.append(f"{len(args.targets_info)} targets\n", style="bold white")
instructions_text.append("📋 Instructions: ", style="bold cyan") for i, target_info in enumerate(args.targets_info):
instructions_text.append(args.instruction, style="white") target_text.append("", style="dim white")
target_text.append(target_info["original"], style="white")
if i < len(args.targets_info) - 1:
target_text.append("\n")
results_text = Text() results_text = Text()
results_text.append("📊 Results will be saved to: ", style="bold cyan") results_text.append("📊 Results will be saved to: ", style="bold cyan")
@@ -44,8 +49,6 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
start_text, start_text,
"\n\n", "\n\n",
target_text, target_text,
"\n" if args.instruction else "",
instructions_text if args.instruction else "",
"\n", "\n",
results_text, results_text,
note_text, note_text,
@@ -62,8 +65,7 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
scan_config = { scan_config = {
"scan_id": args.run_name, "scan_id": args.run_name,
"scan_type": args.target_type, "targets": args.targets_info,
"target": args.target_dict,
"user_instructions": args.instruction or "", "user_instructions": args.instruction or "",
"run_name": args.run_name, "run_name": args.run_name,
} }
@@ -75,23 +77,14 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
"non_interactive": True, "non_interactive": True,
} }
if args.target_type == "local_code" and "target_path" in args.target_dict: if getattr(args, "local_sources", None):
agent_config["local_source_path"] = args.target_dict["target_path"] agent_config["local_sources"] = args.local_sources
elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict:
agent_config["local_source_path"] = args.target_dict["cloned_repo_path"]
tracer = Tracer(args.run_name) tracer = Tracer(args.run_name)
tracer.set_scan_config(scan_config) tracer.set_scan_config(scan_config)
def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None: def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None:
severity_colors = { severity_color = get_severity_color(severity.lower())
"critical": "#dc2626",
"high": "#ea580c",
"medium": "#d97706",
"low": "#65a30d",
"info": "#0284c7",
}
severity_color = severity_colors.get(severity.lower(), "#6b7280")
vuln_text = Text() vuln_text = Text()
vuln_text.append("🐞 ", style="bold red") vuln_text.append("🐞 ", style="bold red")

View File

@@ -7,16 +7,10 @@ import argparse
import asyncio import asyncio
import logging import logging
import os import os
import secrets
import shutil import shutil
import subprocess
import sys import sys
import tempfile
from pathlib import Path from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import docker
import litellm import litellm
from docker.errors import DockerException from docker.errors import DockerException
from rich.console import Console from rich.console import Console
@@ -25,6 +19,19 @@ from rich.text import Text
from strix.interface.cli import run_cli from strix.interface.cli import run_cli
from strix.interface.tui import run_tui from strix.interface.tui import run_tui
from strix.interface.utils import (
assign_workspace_subdirs,
build_llm_stats_text,
build_stats_text,
check_docker_connection,
clone_repository,
collect_local_sources,
generate_run_name,
image_exists,
infer_target_type,
process_pull_line,
validate_llm_response,
)
from strix.runtime.docker_runtime import STRIX_IMAGE from strix.runtime.docker_runtime import STRIX_IMAGE
from strix.telemetry.tracer import get_global_tracer from strix.telemetry.tracer import get_global_tracer
@@ -32,15 +39,6 @@ from strix.telemetry.tracer import get_global_tracer
logging.getLogger().setLevel(logging.ERROR) logging.getLogger().setLevel(logging.ERROR)
def format_token_count(count: float) -> str:
count = int(count)
if count >= 1_000_000:
return f"{count / 1_000_000:.1f}M"
if count >= 1_000:
return f"{count / 1_000:.1f}K"
return str(count)
def validate_environment() -> None: # noqa: PLR0912, PLR0915 def validate_environment() -> None: # noqa: PLR0912, PLR0915
console = Console() console = Console()
missing_required_vars = [] missing_required_vars = []
@@ -163,11 +161,6 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915
sys.exit(1) sys.exit(1)
def _validate_llm_response(response: Any) -> None:
if not response or not response.choices or not response.choices[0].message.content:
raise RuntimeError("Invalid response from LLM")
def check_docker_installed() -> None: def check_docker_installed() -> None:
if shutil.which("docker") is None: if shutil.which("docker") is None:
console = Console() console = Console()
@@ -220,7 +213,7 @@ async def warm_up_llm() -> None:
messages=test_messages, messages=test_messages,
) )
_validate_llm_response(response) validate_llm_response(response)
except Exception as e: # noqa: BLE001 except Exception as e: # noqa: BLE001
error_text = Text() error_text = Text()
@@ -245,141 +238,6 @@ async def warm_up_llm() -> None:
sys.exit(1) sys.exit(1)
def generate_run_name() -> str:
# fmt: off
adjectives = [
"stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent",
"rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary",
"encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous"
]
nouns = [
"exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan",
"worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler",
"scanner", "sniffer", "honeypot", "firewall", "breach"
]
# fmt: on
adj = secrets.choice(adjectives)
noun = secrets.choice(nouns)
number = secrets.randbelow(900) + 100
return f"{adj}-{noun}-{number}"
def clone_repository(repo_url: str, run_name: str) -> str:
console = Console()
git_executable = shutil.which("git")
if git_executable is None:
raise FileNotFoundError("Git executable not found in PATH")
temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name
temp_dir.mkdir(parents=True, exist_ok=True)
repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name
clone_path = temp_dir / repo_name
if clone_path.exists():
shutil.rmtree(clone_path)
try:
with console.status(f"[bold cyan]Cloning repository {repo_name}...", spinner="dots"):
subprocess.run( # noqa: S603
[
git_executable,
"clone",
repo_url,
str(clone_path),
],
capture_output=True,
text=True,
check=True,
)
return str(clone_path.absolute())
except subprocess.CalledProcessError as e:
error_text = Text()
error_text.append("", style="bold red")
error_text.append("REPOSITORY CLONE FAILED", style="bold red")
error_text.append("\n\n", style="white")
error_text.append(f"Could not clone repository: {repo_url}\n", style="white")
error_text.append(
f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red"
)
panel = Panel(
error_text,
title="[bold red]🛡️ STRIX CLONE ERROR",
title_align="center",
border_style="red",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
console.print()
sys.exit(1)
except FileNotFoundError:
error_text = Text()
error_text.append("", style="bold red")
error_text.append("GIT NOT FOUND", style="bold red")
error_text.append("\n\n", style="white")
error_text.append("Git is not installed or not available in PATH.\n", style="white")
error_text.append("Please install Git to clone repositories.\n", style="white")
panel = Panel(
error_text,
title="[bold red]🛡️ STRIX CLONE ERROR",
title_align="center",
border_style="red",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
console.print()
sys.exit(1)
def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
if not target or not isinstance(target, str):
raise ValueError("Target must be a non-empty string")
target = target.strip()
parsed = urlparse(target)
if parsed.scheme in ("http", "https"):
if any(
host in parsed.netloc.lower() for host in ["github.com", "gitlab.com", "bitbucket.org"]
):
return "repository", {"target_repo": target}
return "web_application", {"target_url": target}
path = Path(target)
try:
if path.exists():
if path.is_dir():
return "local_code", {"target_path": str(path.absolute())}
raise ValueError(f"Path exists but is not a directory: {target}")
except (OSError, RuntimeError) as e:
raise ValueError(f"Invalid path: {target} - {e!s}") from e
if target.startswith("git@") or target.endswith(".git"):
return "repository", {"target_repo": target}
if "." in target and "/" not in target and not target.startswith("."):
parts = target.split(".")
if len(parts) >= 2 and all(p and p.strip() for p in parts):
return "web_application", {"target_url": f"https://{target}"}
raise ValueError(
f"Invalid target: {target}\n"
"Target must be one of:\n"
"- A valid URL (http:// or https://)\n"
"- A Git repository URL (https://github.com/... or git@github.com:...)\n"
"- A local directory path\n"
"- A domain name (e.g., example.com)"
)
def parse_arguments() -> argparse.Namespace: def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Strix Multi-Agent Cybersecurity Penetration Testing Tool", description="Strix Multi-Agent Cybersecurity Penetration Testing Tool",
@@ -399,16 +257,23 @@ Examples:
# Domain penetration test # Domain penetration test
strix --target example.com strix --target example.com
# Multiple targets (e.g., white-box testing with source and deployed app)
strix --target https://github.com/user/repo --target https://example.com
strix --target ./my-project --target https://staging.example.com --target https://prod.example.com
# Custom instructions # Custom instructions
strix --target example.com --instruction "Focus on authentication vulnerabilities" strix --target example.com --instruction "Focus on authentication vulnerabilities"
""", """,
) )
parser.add_argument( parser.add_argument(
"-t",
"--target", "--target",
type=str, type=str,
required=True, required=True,
help="Target to test (URL, repository, local directory path, or domain name)", action="append",
help="Target to test (URL, repository, local directory path, or domain name). "
"Can be specified multiple times for multi-target scans.",
) )
parser.add_argument( parser.add_argument(
"--instruction", "--instruction",
@@ -439,127 +304,53 @@ Examples:
args = parser.parse_args() args = parser.parse_args()
args.targets_info = []
for target in args.target:
try: try:
args.target_type, args.target_dict = infer_target_type(args.target) target_type, target_dict = infer_target_type(target)
except ValueError as e:
parser.error(str(e)) if target_type == "local_code":
display_target = target_dict.get("target_path", target)
else:
display_target = target
args.targets_info.append(
{"type": target_type, "details": target_dict, "original": display_target}
)
except ValueError:
parser.error(f"Invalid target '{target}'")
assign_workspace_subdirs(args.targets_info)
return args return args
def _get_severity_color(severity: str) -> str:
severity_colors = {
"critical": "#dc2626",
"high": "#ea580c",
"medium": "#d97706",
"low": "#65a30d",
"info": "#0284c7",
}
return severity_colors.get(severity, "#6b7280")
def _build_stats_text(tracer: Any) -> Text:
stats_text = Text()
if not tracer:
return stats_text
vuln_count = len(tracer.vulnerability_reports)
tool_count = tracer.get_real_tool_count()
agent_count = len(tracer.agents)
if vuln_count > 0:
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
for report in tracer.vulnerability_reports:
severity = report.get("severity", "").lower()
if severity in severity_counts:
severity_counts[severity] += 1
stats_text.append("🔍 Vulnerabilities Found: ", style="bold red")
severity_parts = []
for severity in ["critical", "high", "medium", "low", "info"]:
count = severity_counts[severity]
if count > 0:
severity_color = _get_severity_color(severity)
severity_text = Text()
severity_text.append(f"{severity.upper()}: ", style=severity_color)
severity_text.append(str(count), style=f"bold {severity_color}")
severity_parts.append(severity_text)
for i, part in enumerate(severity_parts):
stats_text.append(part)
if i < len(severity_parts) - 1:
stats_text.append(" | ", style="dim white")
stats_text.append(" (Total: ", style="dim white")
stats_text.append(str(vuln_count), style="bold yellow")
stats_text.append(")", style="dim white")
stats_text.append("\n")
else:
stats_text.append("🔍 Vulnerabilities Found: ", style="bold green")
stats_text.append("0", style="bold white")
stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
stats_text.append("\n")
stats_text.append("🤖 Agents Used: ", style="bold cyan")
stats_text.append(str(agent_count), style="bold white")
stats_text.append("", style="dim white")
stats_text.append("🛠️ Tools Called: ", style="bold cyan")
stats_text.append(str(tool_count), style="bold white")
return stats_text
def _build_llm_stats_text(tracer: Any) -> Text:
llm_stats_text = Text()
if not tracer:
return llm_stats_text
llm_stats = tracer.get_total_llm_stats()
total_stats = llm_stats["total"]
if total_stats["requests"] > 0:
llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
if total_stats["cached_tokens"] > 0:
llm_stats_text.append("", style="dim white")
llm_stats_text.append("⚡ Cached: ", style="bold green")
llm_stats_text.append(
format_token_count(total_stats["cached_tokens"]), style="bold green"
)
llm_stats_text.append("", style="dim white")
llm_stats_text.append("📤 Output Tokens: ", style="bold cyan")
llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")
if total_stats["cost"] > 0:
llm_stats_text.append("", style="dim white")
llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")
return llm_stats_text
def display_completion_message(args: argparse.Namespace, results_path: Path) -> None: def display_completion_message(args: argparse.Namespace, results_path: Path) -> None:
console = Console() console = Console()
tracer = get_global_tracer() tracer = get_global_tracer()
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
completion_text = Text() completion_text = Text()
completion_text.append("🦉 ", style="bold white") completion_text.append("🦉 ", style="bold white")
completion_text.append("AGENT FINISHED", style="bold green") completion_text.append("AGENT FINISHED", style="bold green")
completion_text.append("", style="dim white") completion_text.append("", style="dim white")
completion_text.append("Penetration test completed", style="white") completion_text.append("Penetration test completed", style="white")
stats_text = _build_stats_text(tracer) stats_text = build_stats_text(tracer)
llm_stats_text = _build_llm_stats_text(tracer) llm_stats_text = build_llm_stats_text(tracer)
target_text = Text() target_text = Text()
if len(args.targets_info) == 1:
target_text.append("🎯 Target: ", style="bold cyan") target_text.append("🎯 Target: ", style="bold cyan")
target_text.append(str(target_value), style="bold white") target_text.append(args.targets_info[0]["original"], style="bold white")
else:
target_text.append("🎯 Targets: ", style="bold cyan")
target_text.append(f"{len(args.targets_info)} targets\n", style="bold white")
for i, target_info in enumerate(args.targets_info):
target_text.append("", style="dim white")
target_text.append(target_info["original"], style="white")
if i < len(args.targets_info) - 1:
target_text.append("\n")
results_text = Text() results_text = Text()
results_text.append("📊 Results Saved To: ", style="bold cyan") results_text.append("📊 Results Saved To: ", style="bold cyan")
@@ -575,19 +366,19 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
stats_text, stats_text,
"\n", "\n",
llm_stats_text, llm_stats_text,
"\n", "\n\n",
results_text, results_text,
) )
else: else:
panel_content = Text.assemble( panel_content = Text.assemble(
completion_text, "\n\n", target_text, "\n", stats_text, "\n", results_text completion_text, "\n\n", target_text, "\n", stats_text, "\n\n", results_text
) )
elif llm_stats_text.plain: elif llm_stats_text.plain:
panel_content = Text.assemble( panel_content = Text.assemble(
completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n", results_text completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n\n", results_text
) )
else: else:
panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n", results_text) panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n\n", results_text)
panel = Panel( panel = Panel(
panel_content, panel_content,
@@ -602,86 +393,11 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
console.print() console.print()
def _check_docker_connection() -> Any:
try:
return docker.from_env()
except DockerException:
console = Console()
error_text = Text()
error_text.append("", style="bold red")
error_text.append("DOCKER NOT AVAILABLE", style="bold red")
error_text.append("\n\n", style="white")
error_text.append("Cannot connect to Docker daemon.\n", style="white")
error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
error_text.append("Try running: ", style="dim white")
error_text.append("sudo systemctl start docker", style="dim cyan")
panel = Panel(
error_text,
title="[bold red]🛡️ STRIX STARTUP ERROR",
title_align="center",
border_style="red",
padding=(1, 2),
)
console.print("\n", panel, "\n")
sys.exit(1)
def _image_exists(client: Any) -> bool:
try:
client.images.get(STRIX_IMAGE)
except docker.errors.ImageNotFound:
return False
else:
return True
def _update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None:
if "Pull complete" in layer_status or "Already exists" in layer_status:
layers_info[layer_id] = ""
elif "Downloading" in layer_status:
layers_info[layer_id] = ""
elif "Extracting" in layer_status:
layers_info[layer_id] = "📦"
elif "Waiting" in layer_status:
layers_info[layer_id] = ""
else:
layers_info[layer_id] = ""
def _process_pull_line(
line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str
) -> str:
if "id" in line and "status" in line:
layer_id = line["id"]
_update_layer_status(layers_info, layer_id, line["status"])
completed = sum(1 for v in layers_info.values() if v == "")
total = len(layers_info)
if total > 0:
update_msg = f"[bold cyan]Progress: {completed}/{total} layers complete"
if update_msg != last_update:
status.update(update_msg)
return update_msg
elif "status" in line and "id" not in line:
global_status = line["status"]
if "Pulling from" in global_status:
status.update("[bold cyan]Fetching image manifest...")
elif "Digest:" in global_status:
status.update("[bold cyan]Verifying image...")
elif "Status:" in global_status:
status.update("[bold cyan]Finalizing...")
return last_update
def pull_docker_image() -> None: def pull_docker_image() -> None:
console = Console() console = Console()
client = _check_docker_connection() client = check_docker_connection()
if _image_exists(client): if image_exists(client, STRIX_IMAGE):
return return
console.print() console.print()
@@ -695,7 +411,7 @@ def pull_docker_image() -> None:
last_update = "" last_update = ""
for line in client.api.pull(STRIX_IMAGE, stream=True, decode=True): for line in client.api.pull(STRIX_IMAGE, stream=True, decode=True):
last_update = _process_pull_line(line, layers_info, status, last_update) last_update = process_pull_line(line, layers_info, status, last_update)
except DockerException as e: except DockerException as e:
console.print() console.print()
@@ -738,11 +454,14 @@ def main() -> None:
if not args.run_name: if not args.run_name:
args.run_name = generate_run_name() args.run_name = generate_run_name()
if args.target_type == "repository": for target_info in args.targets_info:
repo_url = args.target_dict["target_repo"] if target_info["type"] == "repository":
cloned_path = clone_repository(repo_url, args.run_name) repo_url = target_info["details"]["target_repo"]
dest_name = target_info["details"].get("workspace_subdir")
cloned_path = clone_repository(repo_url, args.run_name, dest_name)
target_info["details"]["cloned_repo_path"] = cloned_path
args.target_dict["cloned_repo_path"] = cloned_path args.local_sources = collect_local_sources(args.targets_info)
if args.non_interactive: if args.non_interactive:
asyncio.run(run_cli(args)) asyncio.run(run_cli(args))

View File

@@ -16,23 +16,28 @@ class ScanStartInfoRenderer(BaseToolRenderer):
args = tool_data.get("args", {}) args = tool_data.get("args", {})
status = tool_data.get("status", "unknown") status = tool_data.get("status", "unknown")
target = args.get("target", {}) targets = args.get("targets", [])
target_display = cls._build_target_display(target) if len(targets) == 1:
target_display = cls._build_single_target_display(targets[0])
content = f"🚀 Starting scan on {target_display}" content = f"🚀 Starting penetration test on {target_display}"
elif len(targets) > 1:
content = f"🚀 Starting penetration test on {len(targets)} targets"
for target_info in targets:
target_display = cls._build_single_target_display(target_info)
content += f"\n{target_display}"
else:
content = "🚀 Starting penetration test"
css_classes = cls.get_css_classes(status) css_classes = cls.get_css_classes(status)
return Static(content, classes=css_classes) return Static(content, classes=css_classes)
@classmethod @classmethod
def _build_target_display(cls, target: dict[str, Any]) -> str: def _build_single_target_display(cls, target_info: dict[str, Any]) -> str:
if target_url := target.get("target_url"): original = target_info.get("original")
return cls.escape_markup(str(target_url)) if original:
if target_repo := target.get("target_repo"): return cls.escape_markup(str(original))
return cls.escape_markup(str(target_repo))
if target_path := target.get("target_path"):
return cls.escape_markup(str(target_path))
return "unknown target" return "unknown target"

View File

@@ -312,8 +312,7 @@ class StrixTUIApp(App): # type: ignore[misc]
def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]: def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]:
return { return {
"scan_id": args.run_name, "scan_id": args.run_name,
"scan_type": args.target_type, "targets": args.targets_info,
"target": args.target_dict,
"user_instructions": args.instruction or "", "user_instructions": args.instruction or "",
"run_name": args.run_name, "run_name": args.run_name,
} }
@@ -326,10 +325,8 @@ class StrixTUIApp(App): # type: ignore[misc]
"max_iterations": 300, "max_iterations": 300,
} }
if args.target_type == "local_code" and "target_path" in args.target_dict: if getattr(args, "local_sources", None):
config["local_source_path"] = args.target_dict["target_path"] config["local_sources"] = args.local_sources
elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict:
config["local_source_path"] = args.target_dict["cloned_repo_path"]
return config return config

434
strix/interface/utils.py Normal file
View File

@@ -0,0 +1,434 @@
import re
import secrets
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import docker
from docker.errors import DockerException, ImageNotFound
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
# Token formatting utilities
def format_token_count(count: float) -> str:
count = int(count)
if count >= 1_000_000:
return f"{count / 1_000_000:.1f}M"
if count >= 1_000:
return f"{count / 1_000:.1f}K"
return str(count)
# Display utilities
def get_severity_color(severity: str) -> str:
severity_colors = {
"critical": "#dc2626",
"high": "#ea580c",
"medium": "#d97706",
"low": "#65a30d",
"info": "#0284c7",
}
return severity_colors.get(severity, "#6b7280")
def build_stats_text(tracer: Any) -> Text:
stats_text = Text()
if not tracer:
return stats_text
vuln_count = len(tracer.vulnerability_reports)
tool_count = tracer.get_real_tool_count()
agent_count = len(tracer.agents)
if vuln_count > 0:
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
for report in tracer.vulnerability_reports:
severity = report.get("severity", "").lower()
if severity in severity_counts:
severity_counts[severity] += 1
stats_text.append("🔍 Vulnerabilities Found: ", style="bold red")
severity_parts = []
for severity in ["critical", "high", "medium", "low", "info"]:
count = severity_counts[severity]
if count > 0:
severity_color = get_severity_color(severity)
severity_text = Text()
severity_text.append(f"{severity.upper()}: ", style=severity_color)
severity_text.append(str(count), style=f"bold {severity_color}")
severity_parts.append(severity_text)
for i, part in enumerate(severity_parts):
stats_text.append(part)
if i < len(severity_parts) - 1:
stats_text.append(" | ", style="dim white")
stats_text.append(" (Total: ", style="dim white")
stats_text.append(str(vuln_count), style="bold yellow")
stats_text.append(")", style="dim white")
stats_text.append("\n")
else:
stats_text.append("🔍 Vulnerabilities Found: ", style="bold green")
stats_text.append("0", style="bold white")
stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
stats_text.append("\n")
stats_text.append("🤖 Agents Used: ", style="bold cyan")
stats_text.append(str(agent_count), style="bold white")
stats_text.append("", style="dim white")
stats_text.append("🛠️ Tools Called: ", style="bold cyan")
stats_text.append(str(tool_count), style="bold white")
return stats_text
def build_llm_stats_text(tracer: Any) -> Text:
llm_stats_text = Text()
if not tracer:
return llm_stats_text
llm_stats = tracer.get_total_llm_stats()
total_stats = llm_stats["total"]
if total_stats["requests"] > 0:
llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
if total_stats["cached_tokens"] > 0:
llm_stats_text.append("", style="dim white")
llm_stats_text.append("⚡ Cached: ", style="bold green")
llm_stats_text.append(
format_token_count(total_stats["cached_tokens"]), style="bold green"
)
llm_stats_text.append("", style="dim white")
llm_stats_text.append("📤 Output Tokens: ", style="bold cyan")
llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")
if total_stats["cost"] > 0:
llm_stats_text.append("", style="dim white")
llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")
return llm_stats_text
# Name generation utilities
def generate_run_name() -> str:
# fmt: off
adjectives = [
"stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent",
"rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary",
"encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous"
]
nouns = [
"exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan",
"worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler",
"scanner", "sniffer", "honeypot", "firewall", "breach"
]
# fmt: on
adj = secrets.choice(adjectives)
noun = secrets.choice(nouns)
number = secrets.randbelow(900) + 100
return f"{adj}-{noun}-{number}"
# Target processing utilities
def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
if not target or not isinstance(target, str):
raise ValueError("Target must be a non-empty string")
target = target.strip()
lower_target = target.lower()
bare_repo_prefixes = (
"github.com/",
"www.github.com/",
"gitlab.com/",
"www.gitlab.com/",
"bitbucket.org/",
"www.bitbucket.org/",
)
if any(lower_target.startswith(p) for p in bare_repo_prefixes):
return "repository", {"target_repo": f"https://{target}"}
parsed = urlparse(target)
if parsed.scheme in ("http", "https"):
if any(
host in parsed.netloc.lower() for host in ["github.com", "gitlab.com", "bitbucket.org"]
):
return "repository", {"target_repo": target}
return "web_application", {"target_url": target}
path = Path(target).expanduser()
try:
if path.exists():
if path.is_dir():
resolved = path.resolve()
return "local_code", {"target_path": str(resolved)}
raise ValueError(f"Path exists but is not a directory: {target}")
except (OSError, RuntimeError) as e:
raise ValueError(f"Invalid path: {target} - {e!s}") from e
if target.startswith("git@") or target.endswith(".git"):
return "repository", {"target_repo": target}
if "." in target and "/" not in target and not target.startswith("."):
parts = target.split(".")
if len(parts) >= 2 and all(p and p.strip() for p in parts):
return "web_application", {"target_url": f"https://{target}"}
raise ValueError(
f"Invalid target: {target}\n"
"Target must be one of:\n"
"- A valid URL (http:// or https://)\n"
"- A Git repository URL (https://github.com/... or git@github.com:...)\n"
"- A local directory path\n"
"- A domain name (e.g., example.com)"
)
def sanitize_name(name: str) -> str:
sanitized = re.sub(r"[^A-Za-z0-9._-]", "-", name.strip())
return sanitized or "target"
def derive_repo_base_name(repo_url: str) -> str:
if repo_url.endswith("/"):
repo_url = repo_url[:-1]
if ":" in repo_url and repo_url.startswith("git@"):
path_part = repo_url.split(":", 1)[1]
else:
path_part = urlparse(repo_url).path or repo_url
candidate = path_part.split("/")[-1]
if candidate.endswith(".git"):
candidate = candidate[:-4]
return sanitize_name(candidate or "repository")
def derive_local_base_name(path_str: str) -> str:
try:
base = Path(path_str).resolve().name
except (OSError, RuntimeError):
base = Path(path_str).name
return sanitize_name(base or "workspace")
def assign_workspace_subdirs(targets_info: list[dict[str, Any]]) -> None:
name_counts: dict[str, int] = {}
for target in targets_info:
target_type = target["type"]
details = target["details"]
base_name: str | None = None
if target_type == "repository":
base_name = derive_repo_base_name(details["target_repo"])
elif target_type == "local_code":
base_name = derive_local_base_name(details.get("target_path", "local"))
if base_name is None:
continue
count = name_counts.get(base_name, 0) + 1
name_counts[base_name] = count
workspace_subdir = base_name if count == 1 else f"{base_name}-{count}"
details["workspace_subdir"] = workspace_subdir
def collect_local_sources(targets_info: list[dict[str, Any]]) -> list[dict[str, str]]:
local_sources: list[dict[str, str]] = []
for target_info in targets_info:
details = target_info["details"]
workspace_subdir = details.get("workspace_subdir")
if target_info["type"] == "local_code" and "target_path" in details:
local_sources.append(
{
"source_path": details["target_path"],
"workspace_subdir": workspace_subdir,
}
)
elif target_info["type"] == "repository" and "cloned_repo_path" in details:
local_sources.append(
{
"source_path": details["cloned_repo_path"],
"workspace_subdir": workspace_subdir,
}
)
return local_sources
# Repository utilities
def clone_repository(repo_url: str, run_name: str, dest_name: str | None = None) -> str:
console = Console()
git_executable = shutil.which("git")
if git_executable is None:
raise FileNotFoundError("Git executable not found in PATH")
temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name
temp_dir.mkdir(parents=True, exist_ok=True)
if dest_name:
repo_name = dest_name
else:
repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name
clone_path = temp_dir / repo_name
if clone_path.exists():
shutil.rmtree(clone_path)
try:
with console.status(f"[bold cyan]Cloning repository {repo_url}...", spinner="dots"):
subprocess.run( # noqa: S603
[
git_executable,
"clone",
repo_url,
str(clone_path),
],
capture_output=True,
text=True,
check=True,
)
return str(clone_path.absolute())
except subprocess.CalledProcessError as e:
error_text = Text()
error_text.append("", style="bold red")
error_text.append("REPOSITORY CLONE FAILED", style="bold red")
error_text.append("\n\n", style="white")
error_text.append(f"Could not clone repository: {repo_url}\n", style="white")
error_text.append(
f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red"
)
panel = Panel(
error_text,
title="[bold red]🛡️ STRIX CLONE ERROR",
title_align="center",
border_style="red",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
console.print()
raise
except FileNotFoundError:
error_text = Text()
error_text.append("", style="bold red")
error_text.append("GIT NOT FOUND", style="bold red")
error_text.append("\n\n", style="white")
error_text.append("Git is not installed or not available in PATH.\n", style="white")
error_text.append("Please install Git to clone repositories.\n", style="white")
panel = Panel(
error_text,
title="[bold red]🛡️ STRIX CLONE ERROR",
title_align="center",
border_style="red",
padding=(1, 2),
)
console.print("\n")
console.print(panel)
console.print()
raise
# Docker utilities
def check_docker_connection() -> Any:
try:
return docker.from_env()
except DockerException:
console = Console()
error_text = Text()
error_text.append("", style="bold red")
error_text.append("DOCKER NOT AVAILABLE", style="bold red")
error_text.append("\n\n", style="white")
error_text.append("Cannot connect to Docker daemon.\n", style="white")
error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
error_text.append("Try running: ", style="dim white")
error_text.append("sudo systemctl start docker", style="dim cyan")
panel = Panel(
error_text,
title="[bold red]🛡️ STRIX STARTUP ERROR",
title_align="center",
border_style="red",
padding=(1, 2),
)
console.print("\n", panel, "\n")
raise RuntimeError("Docker not available") from None
def image_exists(client: Any, image_name: str) -> bool:
try:
client.images.get(image_name)
except ImageNotFound:
return False
else:
return True
def update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None:
if "Pull complete" in layer_status or "Already exists" in layer_status:
layers_info[layer_id] = ""
elif "Downloading" in layer_status:
layers_info[layer_id] = ""
elif "Extracting" in layer_status:
layers_info[layer_id] = "📦"
elif "Waiting" in layer_status:
layers_info[layer_id] = ""
else:
layers_info[layer_id] = ""
def process_pull_line(
line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str
) -> str:
if "id" in line and "status" in line:
layer_id = line["id"]
update_layer_status(layers_info, layer_id, line["status"])
completed = sum(1 for v in layers_info.values() if v == "")
total = len(layers_info)
if total > 0:
update_msg = f"[bold cyan]Progress: {completed}/{total} layers complete"
if update_msg != last_update:
status.update(update_msg)
return update_msg
elif "status" in line and "id" not in line:
global_status = line["status"]
if "Pulling from" in global_status:
status.update("[bold cyan]Fetching image manifest...")
elif "Digest:" in global_status:
status.update("[bold cyan]Verifying image...")
elif "Status:" in global_status:
status.update("[bold cyan]Finalizing...")
return last_update
# LLM utilities
def validate_llm_response(response: Any) -> None:
if not response or not response.choices or not response.choices[0].message.content:
raise RuntimeError("Invalid response from LLM")

View File

@@ -250,7 +250,9 @@ class DockerRuntime(AbstractRuntime):
time.sleep(5) time.sleep(5)
def _copy_local_directory_to_container(self, container: Container, local_path: str) -> None: def _copy_local_directory_to_container(
self, container: Container, local_path: str, target_name: str | None = None
) -> None:
import tarfile import tarfile
from io import BytesIO from io import BytesIO
@@ -260,13 +262,20 @@ class DockerRuntime(AbstractRuntime):
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}") logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
return return
if target_name:
logger.info(
f"Copying local directory {local_path_obj} to container at "
f"/workspace/{target_name}"
)
else:
logger.info(f"Copying local directory {local_path_obj} to container") logger.info(f"Copying local directory {local_path_obj} to container")
tar_buffer = BytesIO() tar_buffer = BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w") as tar: with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
for item in local_path_obj.rglob("*"): for item in local_path_obj.rglob("*"):
if item.is_file(): if item.is_file():
arcname = item.relative_to(local_path_obj) rel_path = item.relative_to(local_path_obj)
arcname = Path(target_name) / rel_path if target_name else rel_path
tar.add(item, arcname=arcname) tar.add(item, arcname=arcname)
tar_buffer.seek(0) tar_buffer.seek(0)
@@ -283,14 +292,26 @@ class DockerRuntime(AbstractRuntime):
logger.exception("Failed to copy local directory to container") logger.exception("Failed to copy local directory to container")
async def create_sandbox( async def create_sandbox(
self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None self,
agent_id: str,
existing_token: str | None = None,
local_sources: list[dict[str, str]] | None = None,
) -> SandboxInfo: ) -> SandboxInfo:
scan_id = self._get_scan_id(agent_id) scan_id = self._get_scan_id(agent_id)
container = self._get_or_create_scan_container(scan_id) container = self._get_or_create_scan_container(scan_id)
source_copied_key = f"_source_copied_{scan_id}" source_copied_key = f"_source_copied_{scan_id}"
if local_source_path and not hasattr(self, source_copied_key): if local_sources and not hasattr(self, source_copied_key):
self._copy_local_directory_to_container(container, local_source_path) for index, source in enumerate(local_sources, start=1):
source_path = source.get("source_path")
if not source_path:
continue
target_name = source.get("workspace_subdir")
if not target_name:
target_name = Path(source_path).name or f"target_{index}"
self._copy_local_directory_to_container(container, source_path, target_name)
setattr(self, source_copied_key, True) setattr(self, source_copied_key, True)
container_id = container.id container_id = container.id

View File

@@ -13,7 +13,10 @@ class SandboxInfo(TypedDict):
class AbstractRuntime(ABC): class AbstractRuntime(ABC):
@abstractmethod @abstractmethod
async def create_sandbox( async def create_sandbox(
self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None self,
agent_id: str,
existing_token: str | None = None,
local_sources: list[dict[str, str]] | None = None,
) -> SandboxInfo: ) -> SandboxInfo:
raise NotImplementedError raise NotImplementedError

View File

@@ -44,8 +44,7 @@ class Tracer:
"run_name": self.run_name, "run_name": self.run_name,
"start_time": self.start_time, "start_time": self.start_time,
"end_time": None, "end_time": None,
"target": None, "targets": [],
"scan_type": None,
"status": "running", "status": "running",
} }
self._run_dir: Path | None = None self._run_dir: Path | None = None
@@ -193,8 +192,7 @@ class Tracer:
self.scan_config = config self.scan_config = config
self.run_metadata.update( self.run_metadata.update(
{ {
"target": config.get("target", {}), "targets": config.get("targets", []),
"scan_type": config.get("scan_type", "general"),
"user_instructions": config.get("user_instructions", ""), "user_instructions": config.get("user_instructions", ""),
"max_iterations": config.get("max_iterations", 200), "max_iterations": config.get("max_iterations", 200),
} }