From 738fdc2d49a3b9b0d108267d78325d53893dfd66 Mon Sep 17 00:00:00 2001 From: Ahmed Allam Date: Sat, 1 Nov 2025 01:25:07 +0200 Subject: [PATCH] feat: implement multi-target scanning --- strix/agents/StrixAgent/strix_agent.py | 87 ++-- strix/agents/base_agent.py | 4 +- strix/interface/cli.py | 41 +- strix/interface/main.py | 413 +++-------------- .../tool_components/scan_info_renderer.py | 27 +- strix/interface/tui.py | 9 +- strix/interface/utils.py | 434 ++++++++++++++++++ strix/runtime/docker_runtime.py | 33 +- strix/runtime/runtime.py | 5 +- strix/telemetry/tracer.py | 6 +- 10 files changed, 619 insertions(+), 440 deletions(-) create mode 100644 strix/interface/utils.py diff --git a/strix/agents/StrixAgent/strix_agent.py b/strix/agents/StrixAgent/strix_agent.py index bac3018..07fc8e6 100644 --- a/strix/agents/StrixAgent/strix_agent.py +++ b/strix/agents/StrixAgent/strix_agent.py @@ -19,55 +19,64 @@ class StrixAgent(BaseAgent): super().__init__(config) async def execute_scan(self, scan_config: dict[str, Any]) -> dict[str, Any]: - scan_type = scan_config.get("scan_type", "general") - target = scan_config.get("target", {}) user_instructions = scan_config.get("user_instructions", "") + targets = scan_config.get("targets", []) + + repositories = [] + local_code = [] + urls = [] + + for target in targets: + target_type = target["type"] + details = target["details"] + workspace_subdir = details.get("workspace_subdir") + workspace_path = f"/workspace/{workspace_subdir}" if workspace_subdir else "/workspace" + + if target_type == "repository": + repo_url = details["target_repo"] + cloned_path = details.get("cloned_repo_path") + repositories.append( + { + "url": repo_url, + "workspace_path": workspace_path if cloned_path else None, + } + ) + + elif target_type == "local_code": + original_path = details.get("target_path", "unknown") + local_code.append( + { + "path": original_path, + "workspace_path": workspace_path, + } + ) + + elif target_type == "web_application": + urls.append(details["target_url"]) task_parts = [] - if scan_type == "repository": - repo_url = target["target_repo"] - cloned_path = target.get("cloned_repo_path") + if repositories: + task_parts.append("\n\nRepositories:") + for repo in repositories: + if repo["workspace_path"]: + task_parts.append(f"- {repo['url']} (available at: {repo['workspace_path']})") + else: + task_parts.append(f"- {repo['url']}") - if cloned_path: - workspace_path = "/workspace" - task_parts.append( - f"Perform a security assessment of the Git repository: {repo_url}. " - f"The repository has been cloned from '{repo_url}' to '{cloned_path}' " - f"(host path) and then copied to '{workspace_path}' in your environment." - f"Analyze the codebase at: {workspace_path}" - ) - else: - task_parts.append( - f"Perform a security assessment of the Git repository: {repo_url}" - ) - - elif scan_type == "web_application": - task_parts.append( - f"Perform a security assessment of the web application: {target['target_url']}" + if local_code: + task_parts.append("\n\nLocal Codebases:") + task_parts.extend( + f"- {code['path']} (available at: {code['workspace_path']})" for code in local_code ) - elif scan_type == "local_code": - original_path = target.get("target_path", "unknown") - workspace_path = "/workspace" - task_parts.append( - f"Perform a security assessment of the local codebase. " - f"The code from '{original_path}' (user host path) has been copied to " - f"'{workspace_path}' in your environment. " - f"Analyze the codebase at: {workspace_path}" - ) - - else: - task_parts.append( - f"Perform a general security assessment of: {next(iter(target.values()))}" - ) + if urls: + task_parts.append("\n\nURLs:") + task_parts.extend(f"- {url}" for url in urls) task_description = " ".join(task_parts) if user_instructions: - task_description += ( - f"\n\nSpecial instructions from the system that must be followed: " - f"{user_instructions}" - ) + task_description += f"\n\nSpecial instructions: {user_instructions}" return await self.agent_loop(task=task_description) diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 72f51a5..7591388 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -54,7 +54,7 @@ class BaseAgent(metaclass=AgentMeta): def __init__(self, config: dict[str, Any]): self.config = config - self.local_source_path = config.get("local_source_path") + self.local_sources = config.get("local_sources", []) self.non_interactive = config.get("non_interactive", False) if "max_iterations" in config: @@ -317,7 +317,7 @@ class BaseAgent(metaclass=AgentMeta): runtime = get_runtime() sandbox_info = await runtime.create_sandbox( - self.state.agent_id, self.state.sandbox_token, self.local_source_path + self.state.agent_id, self.state.sandbox_token, self.local_sources ) self.state.sandbox_id = sandbox_info["workspace_id"] self.state.sandbox_token = sandbox_info["auth_token"] diff --git a/strix/interface/cli.py b/strix/interface/cli.py index 4824be7..1dc5523 100644 --- a/strix/interface/cli.py +++ b/strix/interface/cli.py @@ -11,6 +11,8 @@ from strix.agents.StrixAgent import StrixAgent from strix.llm.config import LLMConfig from strix.telemetry.tracer import Tracer, set_global_tracer +from .utils import get_severity_color + async def run_cli(args: Any) -> None: # noqa: PLR0915 console = Console() @@ -19,15 +21,18 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 start_text.append("🦉 ", style="bold white") start_text.append("STRIX CYBERSECURITY AGENT", style="bold green") - target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target target_text = Text() - target_text.append("🎯 Target: ", style="bold cyan") - target_text.append(str(target_value), style="bold white") - - instructions_text = Text() - if args.instruction: - instructions_text.append("📋 Instructions: ", style="bold cyan") - instructions_text.append(args.instruction, style="white") + if len(args.targets_info) == 1: + target_text.append("🎯 Target: ", style="bold cyan") + target_text.append(args.targets_info[0]["original"], style="bold white") + else: + target_text.append("🎯 Targets: ", style="bold cyan") + target_text.append(f"{len(args.targets_info)} targets\n", style="bold white") + for i, target_info in enumerate(args.targets_info): + target_text.append(" • ", style="dim white") + target_text.append(target_info["original"], style="white") + if i < len(args.targets_info) - 1: + target_text.append("\n") results_text = Text() results_text.append("📊 Results will be saved to: ", style="bold cyan") @@ -44,8 +49,6 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 start_text, "\n\n", target_text, - "\n" if args.instruction else "", - instructions_text if args.instruction else "", "\n", results_text, note_text, @@ -62,8 +65,7 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 scan_config = { "scan_id": args.run_name, - "scan_type": args.target_type, - "target": args.target_dict, + "targets": args.targets_info, "user_instructions": args.instruction or "", "run_name": args.run_name, } @@ -75,23 +77,14 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 "non_interactive": True, } - if args.target_type == "local_code" and "target_path" in args.target_dict: - agent_config["local_source_path"] = args.target_dict["target_path"] - elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict: - agent_config["local_source_path"] = args.target_dict["cloned_repo_path"] + if getattr(args, "local_sources", None): + agent_config["local_sources"] = args.local_sources tracer = Tracer(args.run_name) tracer.set_scan_config(scan_config) def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None: - severity_colors = { - "critical": "#dc2626", - "high": "#ea580c", - "medium": "#d97706", - "low": "#65a30d", - "info": "#0284c7", - } - severity_color = severity_colors.get(severity.lower(), "#6b7280") + severity_color = get_severity_color(severity.lower()) vuln_text = Text() vuln_text.append("🐞 ", style="bold red") diff --git a/strix/interface/main.py b/strix/interface/main.py index 6c6e64d..f5ada04 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -7,16 +7,10 @@ import argparse import asyncio import logging import os -import secrets import shutil -import subprocess import sys -import tempfile from pathlib import Path -from typing import Any -from urllib.parse import urlparse -import docker import litellm from docker.errors import DockerException from rich.console import Console @@ -25,6 +19,19 @@ from rich.text import Text from strix.interface.cli import run_cli from strix.interface.tui import run_tui +from strix.interface.utils import ( + assign_workspace_subdirs, + build_llm_stats_text, + build_stats_text, + check_docker_connection, + clone_repository, + collect_local_sources, + generate_run_name, + image_exists, + infer_target_type, + process_pull_line, + validate_llm_response, +) from strix.runtime.docker_runtime import STRIX_IMAGE from strix.telemetry.tracer import get_global_tracer @@ -32,15 +39,6 @@ from strix.telemetry.tracer import get_global_tracer logging.getLogger().setLevel(logging.ERROR) -def format_token_count(count: float) -> str: - count = int(count) - if count >= 1_000_000: - return f"{count / 1_000_000:.1f}M" - if count >= 1_000: - return f"{count / 1_000:.1f}K" - return str(count) - - def validate_environment() -> None: # noqa: PLR0912, PLR0915 console = Console() missing_required_vars = [] @@ -163,11 +161,6 @@ def validate_environment() -> None: # noqa: PLR0912, PLR0915 sys.exit(1) -def _validate_llm_response(response: Any) -> None: - if not response or not response.choices or not response.choices[0].message.content: - raise RuntimeError("Invalid response from LLM") - - def check_docker_installed() -> None: if shutil.which("docker") is None: console = Console() @@ -220,7 +213,7 @@ async def warm_up_llm() -> None: messages=test_messages, ) - _validate_llm_response(response) + validate_llm_response(response) except Exception as e: # noqa: BLE001 error_text = Text() @@ -245,141 +238,6 @@ async def warm_up_llm() -> None: sys.exit(1) -def generate_run_name() -> str: - # fmt: off - adjectives = [ - "stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent", - "rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary", - "encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous" - ] - nouns = [ - "exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan", - "worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler", - "scanner", "sniffer", "honeypot", "firewall", "breach" - ] - # fmt: on - adj = secrets.choice(adjectives) - noun = secrets.choice(nouns) - number = secrets.randbelow(900) + 100 - return f"{adj}-{noun}-{number}" - - -def clone_repository(repo_url: str, run_name: str) -> str: - console = Console() - - git_executable = shutil.which("git") - if git_executable is None: - raise FileNotFoundError("Git executable not found in PATH") - - temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name - temp_dir.mkdir(parents=True, exist_ok=True) - - repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name - - clone_path = temp_dir / repo_name - - if clone_path.exists(): - shutil.rmtree(clone_path) - - try: - with console.status(f"[bold cyan]Cloning repository {repo_name}...", spinner="dots"): - subprocess.run( # noqa: S603 - [ - git_executable, - "clone", - repo_url, - str(clone_path), - ], - capture_output=True, - text=True, - check=True, - ) - - return str(clone_path.absolute()) - - except subprocess.CalledProcessError as e: - error_text = Text() - error_text.append("❌ ", style="bold red") - error_text.append("REPOSITORY CLONE FAILED", style="bold red") - error_text.append("\n\n", style="white") - error_text.append(f"Could not clone repository: {repo_url}\n", style="white") - error_text.append( - f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red" - ) - - panel = Panel( - error_text, - title="[bold red]🛡️ STRIX CLONE ERROR", - title_align="center", - border_style="red", - padding=(1, 2), - ) - console.print("\n") - console.print(panel) - console.print() - sys.exit(1) - except FileNotFoundError: - error_text = Text() - error_text.append("❌ ", style="bold red") - error_text.append("GIT NOT FOUND", style="bold red") - error_text.append("\n\n", style="white") - error_text.append("Git is not installed or not available in PATH.\n", style="white") - error_text.append("Please install Git to clone repositories.\n", style="white") - - panel = Panel( - error_text, - title="[bold red]🛡️ STRIX CLONE ERROR", - title_align="center", - border_style="red", - padding=(1, 2), - ) - console.print("\n") - console.print(panel) - console.print() - sys.exit(1) - - -def infer_target_type(target: str) -> tuple[str, dict[str, str]]: - if not target or not isinstance(target, str): - raise ValueError("Target must be a non-empty string") - - target = target.strip() - - parsed = urlparse(target) - if parsed.scheme in ("http", "https"): - if any( - host in parsed.netloc.lower() for host in ["github.com", "gitlab.com", "bitbucket.org"] - ): - return "repository", {"target_repo": target} - return "web_application", {"target_url": target} - - path = Path(target) - try: - if path.exists(): - if path.is_dir(): - return "local_code", {"target_path": str(path.absolute())} - raise ValueError(f"Path exists but is not a directory: {target}") - except (OSError, RuntimeError) as e: - raise ValueError(f"Invalid path: {target} - {e!s}") from e - - if target.startswith("git@") or target.endswith(".git"): - return "repository", {"target_repo": target} - - if "." in target and "/" not in target and not target.startswith("."): - parts = target.split(".") - if len(parts) >= 2 and all(p and p.strip() for p in parts): - return "web_application", {"target_url": f"https://{target}"} - - raise ValueError( - f"Invalid target: {target}\n" - "Target must be one of:\n" - "- A valid URL (http:// or https://)\n" - "- A Git repository URL (https://github.com/... or git@github.com:...)\n" - "- A local directory path\n" - "- A domain name (e.g., example.com)" - ) - - def parse_arguments() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Strix Multi-Agent Cybersecurity Penetration Testing Tool", @@ -399,16 +257,23 @@ Examples: # Domain penetration test strix --target example.com + # Multiple targets (e.g., white-box testing with source and deployed app) + strix --target https://github.com/user/repo --target https://example.com + strix --target ./my-project --target https://staging.example.com --target https://prod.example.com + # Custom instructions strix --target example.com --instruction "Focus on authentication vulnerabilities" """, ) parser.add_argument( + "-t", "--target", type=str, required=True, - help="Target to test (URL, repository, local directory path, or domain name)", + action="append", + help="Target to test (URL, repository, local directory path, or domain name). " + "Can be specified multiple times for multi-target scans.", ) parser.add_argument( "--instruction", @@ -439,127 +304,53 @@ Examples: args = parser.parse_args() - try: - args.target_type, args.target_dict = infer_target_type(args.target) - except ValueError as e: - parser.error(str(e)) + args.targets_info = [] + for target in args.target: + try: + target_type, target_dict = infer_target_type(target) + + if target_type == "local_code": + display_target = target_dict.get("target_path", target) + else: + display_target = target + + args.targets_info.append( + {"type": target_type, "details": target_dict, "original": display_target} + ) + except ValueError: + parser.error(f"Invalid target '{target}'") + + assign_workspace_subdirs(args.targets_info) return args -def _get_severity_color(severity: str) -> str: - severity_colors = { - "critical": "#dc2626", - "high": "#ea580c", - "medium": "#d97706", - "low": "#65a30d", - "info": "#0284c7", - } - return severity_colors.get(severity, "#6b7280") - - -def _build_stats_text(tracer: Any) -> Text: - stats_text = Text() - if not tracer: - return stats_text - - vuln_count = len(tracer.vulnerability_reports) - tool_count = tracer.get_real_tool_count() - agent_count = len(tracer.agents) - - if vuln_count > 0: - severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} - for report in tracer.vulnerability_reports: - severity = report.get("severity", "").lower() - if severity in severity_counts: - severity_counts[severity] += 1 - - stats_text.append("🔍 Vulnerabilities Found: ", style="bold red") - - severity_parts = [] - for severity in ["critical", "high", "medium", "low", "info"]: - count = severity_counts[severity] - if count > 0: - severity_color = _get_severity_color(severity) - severity_text = Text() - severity_text.append(f"{severity.upper()}: ", style=severity_color) - severity_text.append(str(count), style=f"bold {severity_color}") - severity_parts.append(severity_text) - - for i, part in enumerate(severity_parts): - stats_text.append(part) - if i < len(severity_parts) - 1: - stats_text.append(" | ", style="dim white") - - stats_text.append(" (Total: ", style="dim white") - stats_text.append(str(vuln_count), style="bold yellow") - stats_text.append(")", style="dim white") - stats_text.append("\n") - else: - stats_text.append("🔍 Vulnerabilities Found: ", style="bold green") - stats_text.append("0", style="bold white") - stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green") - stats_text.append("\n") - - stats_text.append("🤖 Agents Used: ", style="bold cyan") - stats_text.append(str(agent_count), style="bold white") - stats_text.append(" • ", style="dim white") - stats_text.append("🛠️ Tools Called: ", style="bold cyan") - stats_text.append(str(tool_count), style="bold white") - - return stats_text - - -def _build_llm_stats_text(tracer: Any) -> Text: - llm_stats_text = Text() - if not tracer: - return llm_stats_text - - llm_stats = tracer.get_total_llm_stats() - total_stats = llm_stats["total"] - - if total_stats["requests"] > 0: - llm_stats_text.append("📥 Input Tokens: ", style="bold cyan") - llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white") - - if total_stats["cached_tokens"] > 0: - llm_stats_text.append(" • ", style="dim white") - llm_stats_text.append("⚡ Cached: ", style="bold green") - llm_stats_text.append( - format_token_count(total_stats["cached_tokens"]), style="bold green" - ) - - llm_stats_text.append(" • ", style="dim white") - llm_stats_text.append("📤 Output Tokens: ", style="bold cyan") - llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white") - - if total_stats["cost"] > 0: - llm_stats_text.append(" • ", style="dim white") - llm_stats_text.append("💰 Total Cost: $", style="bold cyan") - llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow") - - return llm_stats_text - - def display_completion_message(args: argparse.Namespace, results_path: Path) -> None: console = Console() tracer = get_global_tracer() - target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target - completion_text = Text() completion_text.append("🦉 ", style="bold white") completion_text.append("AGENT FINISHED", style="bold green") completion_text.append(" • ", style="dim white") completion_text.append("Penetration test completed", style="white") - stats_text = _build_stats_text(tracer) + stats_text = build_stats_text(tracer) - llm_stats_text = _build_llm_stats_text(tracer) + llm_stats_text = build_llm_stats_text(tracer) target_text = Text() - target_text.append("🎯 Target: ", style="bold cyan") - target_text.append(str(target_value), style="bold white") + if len(args.targets_info) == 1: + target_text.append("🎯 Target: ", style="bold cyan") + target_text.append(args.targets_info[0]["original"], style="bold white") + else: + target_text.append("🎯 Targets: ", style="bold cyan") + target_text.append(f"{len(args.targets_info)} targets\n", style="bold white") + for i, target_info in enumerate(args.targets_info): + target_text.append(" • ", style="dim white") + target_text.append(target_info["original"], style="white") + if i < len(args.targets_info) - 1: + target_text.append("\n") results_text = Text() results_text.append("📊 Results Saved To: ", style="bold cyan") @@ -575,19 +366,19 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) -> stats_text, "\n", llm_stats_text, - "\n", + "\n\n", results_text, ) else: panel_content = Text.assemble( - completion_text, "\n\n", target_text, "\n", stats_text, "\n", results_text + completion_text, "\n\n", target_text, "\n", stats_text, "\n\n", results_text ) elif llm_stats_text.plain: panel_content = Text.assemble( - completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n", results_text + completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n\n", results_text ) else: - panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n", results_text) + panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n\n", results_text) panel = Panel( panel_content, @@ -602,86 +393,11 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) -> console.print() -def _check_docker_connection() -> Any: - try: - return docker.from_env() - except DockerException: - console = Console() - error_text = Text() - error_text.append("❌ ", style="bold red") - error_text.append("DOCKER NOT AVAILABLE", style="bold red") - error_text.append("\n\n", style="white") - error_text.append("Cannot connect to Docker daemon.\n", style="white") - error_text.append("Please ensure Docker is installed and running.\n\n", style="white") - error_text.append("Try running: ", style="dim white") - error_text.append("sudo systemctl start docker", style="dim cyan") - - panel = Panel( - error_text, - title="[bold red]🛡️ STRIX STARTUP ERROR", - title_align="center", - border_style="red", - padding=(1, 2), - ) - console.print("\n", panel, "\n") - sys.exit(1) - - -def _image_exists(client: Any) -> bool: - try: - client.images.get(STRIX_IMAGE) - except docker.errors.ImageNotFound: - return False - else: - return True - - -def _update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None: - if "Pull complete" in layer_status or "Already exists" in layer_status: - layers_info[layer_id] = "✓" - elif "Downloading" in layer_status: - layers_info[layer_id] = "↓" - elif "Extracting" in layer_status: - layers_info[layer_id] = "📦" - elif "Waiting" in layer_status: - layers_info[layer_id] = "⏳" - else: - layers_info[layer_id] = "•" - - -def _process_pull_line( - line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str -) -> str: - if "id" in line and "status" in line: - layer_id = line["id"] - _update_layer_status(layers_info, layer_id, line["status"]) - - completed = sum(1 for v in layers_info.values() if v == "✓") - total = len(layers_info) - - if total > 0: - update_msg = f"[bold cyan]Progress: {completed}/{total} layers complete" - if update_msg != last_update: - status.update(update_msg) - return update_msg - - elif "status" in line and "id" not in line: - global_status = line["status"] - if "Pulling from" in global_status: - status.update("[bold cyan]Fetching image manifest...") - elif "Digest:" in global_status: - status.update("[bold cyan]Verifying image...") - elif "Status:" in global_status: - status.update("[bold cyan]Finalizing...") - - return last_update - - def pull_docker_image() -> None: console = Console() - client = _check_docker_connection() + client = check_docker_connection() - if _image_exists(client): + if image_exists(client, STRIX_IMAGE): return console.print() @@ -695,7 +411,7 @@ def pull_docker_image() -> None: last_update = "" for line in client.api.pull(STRIX_IMAGE, stream=True, decode=True): - last_update = _process_pull_line(line, layers_info, status, last_update) + last_update = process_pull_line(line, layers_info, status, last_update) except DockerException as e: console.print() @@ -738,11 +454,14 @@ def main() -> None: if not args.run_name: args.run_name = generate_run_name() - if args.target_type == "repository": - repo_url = args.target_dict["target_repo"] - cloned_path = clone_repository(repo_url, args.run_name) + for target_info in args.targets_info: + if target_info["type"] == "repository": + repo_url = target_info["details"]["target_repo"] + dest_name = target_info["details"].get("workspace_subdir") + cloned_path = clone_repository(repo_url, args.run_name, dest_name) + target_info["details"]["cloned_repo_path"] = cloned_path - args.target_dict["cloned_repo_path"] = cloned_path + args.local_sources = collect_local_sources(args.targets_info) if args.non_interactive: asyncio.run(run_cli(args)) diff --git a/strix/interface/tool_components/scan_info_renderer.py b/strix/interface/tool_components/scan_info_renderer.py index 31d852e..602fb80 100644 --- a/strix/interface/tool_components/scan_info_renderer.py +++ b/strix/interface/tool_components/scan_info_renderer.py @@ -16,23 +16,28 @@ class ScanStartInfoRenderer(BaseToolRenderer): args = tool_data.get("args", {}) status = tool_data.get("status", "unknown") - target = args.get("target", {}) + targets = args.get("targets", []) - target_display = cls._build_target_display(target) - - content = f"🚀 Starting scan on {target_display}" + if len(targets) == 1: + target_display = cls._build_single_target_display(targets[0]) + content = f"🚀 Starting penetration test on {target_display}" + elif len(targets) > 1: + content = f"🚀 Starting penetration test on {len(targets)} targets" + for target_info in targets: + target_display = cls._build_single_target_display(target_info) + content += f"\n • {target_display}" + else: + content = "🚀 Starting penetration test" css_classes = cls.get_css_classes(status) return Static(content, classes=css_classes) @classmethod - def _build_target_display(cls, target: dict[str, Any]) -> str: - if target_url := target.get("target_url"): - return cls.escape_markup(str(target_url)) - if target_repo := target.get("target_repo"): - return cls.escape_markup(str(target_repo)) - if target_path := target.get("target_path"): - return cls.escape_markup(str(target_path)) + def _build_single_target_display(cls, target_info: dict[str, Any]) -> str: + original = target_info.get("original") + if original: + return cls.escape_markup(str(original)) + return "unknown target" diff --git a/strix/interface/tui.py b/strix/interface/tui.py index ab0645c..ff0a255 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -312,8 +312,7 @@ class StrixTUIApp(App): # type: ignore[misc] def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]: return { "scan_id": args.run_name, - "scan_type": args.target_type, - "target": args.target_dict, + "targets": args.targets_info, "user_instructions": args.instruction or "", "run_name": args.run_name, } @@ -326,10 +325,8 @@ class StrixTUIApp(App): # type: ignore[misc] "max_iterations": 300, } - if args.target_type == "local_code" and "target_path" in args.target_dict: - config["local_source_path"] = args.target_dict["target_path"] - elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict: - config["local_source_path"] = args.target_dict["cloned_repo_path"] + if getattr(args, "local_sources", None): + config["local_sources"] = args.local_sources return config diff --git a/strix/interface/utils.py b/strix/interface/utils.py new file mode 100644 index 0000000..32eb182 --- /dev/null +++ b/strix/interface/utils.py @@ -0,0 +1,434 @@ +import re +import secrets +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import docker +from docker.errors import DockerException, ImageNotFound +from rich.console import Console +from rich.panel import Panel +from rich.text import Text + + +# Token formatting utilities +def format_token_count(count: float) -> str: + count = int(count) + if count >= 1_000_000: + return f"{count / 1_000_000:.1f}M" + if count >= 1_000: + return f"{count / 1_000:.1f}K" + return str(count) + + +# Display utilities +def get_severity_color(severity: str) -> str: + severity_colors = { + "critical": "#dc2626", + "high": "#ea580c", + "medium": "#d97706", + "low": "#65a30d", + "info": "#0284c7", + } + return severity_colors.get(severity, "#6b7280") + + +def build_stats_text(tracer: Any) -> Text: + stats_text = Text() + if not tracer: + return stats_text + + vuln_count = len(tracer.vulnerability_reports) + tool_count = tracer.get_real_tool_count() + agent_count = len(tracer.agents) + + if vuln_count > 0: + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} + for report in tracer.vulnerability_reports: + severity = report.get("severity", "").lower() + if severity in severity_counts: + severity_counts[severity] += 1 + + stats_text.append("🔍 Vulnerabilities Found: ", style="bold red") + + severity_parts = [] + for severity in ["critical", "high", "medium", "low", "info"]: + count = severity_counts[severity] + if count > 0: + severity_color = get_severity_color(severity) + severity_text = Text() + severity_text.append(f"{severity.upper()}: ", style=severity_color) + severity_text.append(str(count), style=f"bold {severity_color}") + severity_parts.append(severity_text) + + for i, part in enumerate(severity_parts): + stats_text.append(part) + if i < len(severity_parts) - 1: + stats_text.append(" | ", style="dim white") + + stats_text.append(" (Total: ", style="dim white") + stats_text.append(str(vuln_count), style="bold yellow") + stats_text.append(")", style="dim white") + stats_text.append("\n") + else: + stats_text.append("🔍 Vulnerabilities Found: ", style="bold green") + stats_text.append("0", style="bold white") + stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green") + stats_text.append("\n") + + stats_text.append("🤖 Agents Used: ", style="bold cyan") + stats_text.append(str(agent_count), style="bold white") + stats_text.append(" • ", style="dim white") + stats_text.append("🛠️ Tools Called: ", style="bold cyan") + stats_text.append(str(tool_count), style="bold white") + + return stats_text + + +def build_llm_stats_text(tracer: Any) -> Text: + llm_stats_text = Text() + if not tracer: + return llm_stats_text + + llm_stats = tracer.get_total_llm_stats() + total_stats = llm_stats["total"] + + if total_stats["requests"] > 0: + llm_stats_text.append("📥 Input Tokens: ", style="bold cyan") + llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white") + + if total_stats["cached_tokens"] > 0: + llm_stats_text.append(" • ", style="dim white") + llm_stats_text.append("⚡ Cached: ", style="bold green") + llm_stats_text.append( + format_token_count(total_stats["cached_tokens"]), style="bold green" + ) + + llm_stats_text.append(" • ", style="dim white") + llm_stats_text.append("📤 Output Tokens: ", style="bold cyan") + llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white") + + if total_stats["cost"] > 0: + llm_stats_text.append(" • ", style="dim white") + llm_stats_text.append("💰 Total Cost: $", style="bold cyan") + llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow") + + return llm_stats_text + + +# Name generation utilities +def generate_run_name() -> str: + # fmt: off + adjectives = [ + "stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent", + "rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary", + "encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous" + ] + nouns = [ + "exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan", + "worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler", + "scanner", "sniffer", "honeypot", "firewall", "breach" + ] + # fmt: on + adj = secrets.choice(adjectives) + noun = secrets.choice(nouns) + number = secrets.randbelow(900) + 100 + return f"{adj}-{noun}-{number}" + + +# Target processing utilities +def infer_target_type(target: str) -> tuple[str, dict[str, str]]: + if not target or not isinstance(target, str): + raise ValueError("Target must be a non-empty string") + + target = target.strip() + + lower_target = target.lower() + bare_repo_prefixes = ( + "github.com/", + "www.github.com/", + "gitlab.com/", + "www.gitlab.com/", + "bitbucket.org/", + "www.bitbucket.org/", + ) + if any(lower_target.startswith(p) for p in bare_repo_prefixes): + return "repository", {"target_repo": f"https://{target}"} + + parsed = urlparse(target) + if parsed.scheme in ("http", "https"): + if any( + host in parsed.netloc.lower() for host in ["github.com", "gitlab.com", "bitbucket.org"] + ): + return "repository", {"target_repo": target} + return "web_application", {"target_url": target} + + path = Path(target).expanduser() + try: + if path.exists(): + if path.is_dir(): + resolved = path.resolve() + return "local_code", {"target_path": str(resolved)} + raise ValueError(f"Path exists but is not a directory: {target}") + except (OSError, RuntimeError) as e: + raise ValueError(f"Invalid path: {target} - {e!s}") from e + + if target.startswith("git@") or target.endswith(".git"): + return "repository", {"target_repo": target} + + if "." in target and "/" not in target and not target.startswith("."): + parts = target.split(".") + if len(parts) >= 2 and all(p and p.strip() for p in parts): + return "web_application", {"target_url": f"https://{target}"} + + raise ValueError( + f"Invalid target: {target}\n" + "Target must be one of:\n" + "- A valid URL (http:// or https://)\n" + "- A Git repository URL (https://github.com/... or git@github.com:...)\n" + "- A local directory path\n" + "- A domain name (e.g., example.com)" + ) + + +def sanitize_name(name: str) -> str: + sanitized = re.sub(r"[^A-Za-z0-9._-]", "-", name.strip()) + return sanitized or "target" + + +def derive_repo_base_name(repo_url: str) -> str: + if repo_url.endswith("/"): + repo_url = repo_url[:-1] + + if ":" in repo_url and repo_url.startswith("git@"): + path_part = repo_url.split(":", 1)[1] + else: + path_part = urlparse(repo_url).path or repo_url + + candidate = path_part.split("/")[-1] + if candidate.endswith(".git"): + candidate = candidate[:-4] + + return sanitize_name(candidate or "repository") + + +def derive_local_base_name(path_str: str) -> str: + try: + base = Path(path_str).resolve().name + except (OSError, RuntimeError): + base = Path(path_str).name + return sanitize_name(base or "workspace") + + +def assign_workspace_subdirs(targets_info: list[dict[str, Any]]) -> None: + name_counts: dict[str, int] = {} + + for target in targets_info: + target_type = target["type"] + details = target["details"] + + base_name: str | None = None + if target_type == "repository": + base_name = derive_repo_base_name(details["target_repo"]) + elif target_type == "local_code": + base_name = derive_local_base_name(details.get("target_path", "local")) + + if base_name is None: + continue + + count = name_counts.get(base_name, 0) + 1 + name_counts[base_name] = count + + workspace_subdir = base_name if count == 1 else f"{base_name}-{count}" + + details["workspace_subdir"] = workspace_subdir + + +def collect_local_sources(targets_info: list[dict[str, Any]]) -> list[dict[str, str]]: + local_sources: list[dict[str, str]] = [] + + for target_info in targets_info: + details = target_info["details"] + workspace_subdir = details.get("workspace_subdir") + + if target_info["type"] == "local_code" and "target_path" in details: + local_sources.append( + { + "source_path": details["target_path"], + "workspace_subdir": workspace_subdir, + } + ) + + elif target_info["type"] == "repository" and "cloned_repo_path" in details: + local_sources.append( + { + "source_path": details["cloned_repo_path"], + "workspace_subdir": workspace_subdir, + } + ) + + return local_sources + + +# Repository utilities +def clone_repository(repo_url: str, run_name: str, dest_name: str | None = None) -> str: + console = Console() + + git_executable = shutil.which("git") + if git_executable is None: + raise FileNotFoundError("Git executable not found in PATH") + + temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name + temp_dir.mkdir(parents=True, exist_ok=True) + + if dest_name: + repo_name = dest_name + else: + repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name + + clone_path = temp_dir / repo_name + + if clone_path.exists(): + shutil.rmtree(clone_path) + + try: + with console.status(f"[bold cyan]Cloning repository {repo_url}...", spinner="dots"): + subprocess.run( # noqa: S603 + [ + git_executable, + "clone", + repo_url, + str(clone_path), + ], + capture_output=True, + text=True, + check=True, + ) + + return str(clone_path.absolute()) + + except subprocess.CalledProcessError as e: + error_text = Text() + error_text.append("❌ ", style="bold red") + error_text.append("REPOSITORY CLONE FAILED", style="bold red") + error_text.append("\n\n", style="white") + error_text.append(f"Could not clone repository: {repo_url}\n", style="white") + error_text.append( + f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red" + ) + + panel = Panel( + error_text, + title="[bold red]🛡️ STRIX CLONE ERROR", + title_align="center", + border_style="red", + padding=(1, 2), + ) + console.print("\n") + console.print(panel) + console.print() + raise + except FileNotFoundError: + error_text = Text() + error_text.append("❌ ", style="bold red") + error_text.append("GIT NOT FOUND", style="bold red") + error_text.append("\n\n", style="white") + error_text.append("Git is not installed or not available in PATH.\n", style="white") + error_text.append("Please install Git to clone repositories.\n", style="white") + + panel = Panel( + error_text, + title="[bold red]🛡️ STRIX CLONE ERROR", + title_align="center", + border_style="red", + padding=(1, 2), + ) + console.print("\n") + console.print(panel) + console.print() + raise + + +# Docker utilities +def check_docker_connection() -> Any: + try: + return docker.from_env() + except DockerException: + console = Console() + error_text = Text() + error_text.append("❌ ", style="bold red") + error_text.append("DOCKER NOT AVAILABLE", style="bold red") + error_text.append("\n\n", style="white") + error_text.append("Cannot connect to Docker daemon.\n", style="white") + error_text.append("Please ensure Docker is installed and running.\n\n", style="white") + error_text.append("Try running: ", style="dim white") + error_text.append("sudo systemctl start docker", style="dim cyan") + + panel = Panel( + error_text, + title="[bold red]🛡️ STRIX STARTUP ERROR", + title_align="center", + border_style="red", + padding=(1, 2), + ) + console.print("\n", panel, "\n") + raise RuntimeError("Docker not available") from None + + +def image_exists(client: Any, image_name: str) -> bool: + try: + client.images.get(image_name) + except ImageNotFound: + return False + else: + return True + + +def update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None: + if "Pull complete" in layer_status or "Already exists" in layer_status: + layers_info[layer_id] = "✓" + elif "Downloading" in layer_status: + layers_info[layer_id] = "↓" + elif "Extracting" in layer_status: + layers_info[layer_id] = "📦" + elif "Waiting" in layer_status: + layers_info[layer_id] = "⏳" + else: + layers_info[layer_id] = "•" + + +def process_pull_line( + line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str +) -> str: + if "id" in line and "status" in line: + layer_id = line["id"] + update_layer_status(layers_info, layer_id, line["status"]) + + completed = sum(1 for v in layers_info.values() if v == "✓") + total = len(layers_info) + + if total > 0: + update_msg = f"[bold cyan]Progress: {completed}/{total} layers complete" + if update_msg != last_update: + status.update(update_msg) + return update_msg + + elif "status" in line and "id" not in line: + global_status = line["status"] + if "Pulling from" in global_status: + status.update("[bold cyan]Fetching image manifest...") + elif "Digest:" in global_status: + status.update("[bold cyan]Verifying image...") + elif "Status:" in global_status: + status.update("[bold cyan]Finalizing...") + + return last_update + + +# LLM utilities +def validate_llm_response(response: Any) -> None: + if not response or not response.choices or not response.choices[0].message.content: + raise RuntimeError("Invalid response from LLM") diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index fa7b5bf..32cc625 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -250,7 +250,9 @@ class DockerRuntime(AbstractRuntime): time.sleep(5) - def _copy_local_directory_to_container(self, container: Container, local_path: str) -> None: + def _copy_local_directory_to_container( + self, container: Container, local_path: str, target_name: str | None = None + ) -> None: import tarfile from io import BytesIO @@ -260,13 +262,20 @@ class DockerRuntime(AbstractRuntime): logger.warning(f"Local path does not exist or is not directory: {local_path_obj}") return - logger.info(f"Copying local directory {local_path_obj} to container") + if target_name: + logger.info( + f"Copying local directory {local_path_obj} to container at " + f"/workspace/{target_name}" + ) + else: + logger.info(f"Copying local directory {local_path_obj} to container") tar_buffer = BytesIO() with tarfile.open(fileobj=tar_buffer, mode="w") as tar: for item in local_path_obj.rglob("*"): if item.is_file(): - arcname = item.relative_to(local_path_obj) + rel_path = item.relative_to(local_path_obj) + arcname = Path(target_name) / rel_path if target_name else rel_path tar.add(item, arcname=arcname) tar_buffer.seek(0) @@ -283,14 +292,26 @@ class DockerRuntime(AbstractRuntime): logger.exception("Failed to copy local directory to container") async def create_sandbox( - self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None + self, + agent_id: str, + existing_token: str | None = None, + local_sources: list[dict[str, str]] | None = None, ) -> SandboxInfo: scan_id = self._get_scan_id(agent_id) container = self._get_or_create_scan_container(scan_id) source_copied_key = f"_source_copied_{scan_id}" - if local_source_path and not hasattr(self, source_copied_key): - self._copy_local_directory_to_container(container, local_source_path) + if local_sources and not hasattr(self, source_copied_key): + for index, source in enumerate(local_sources, start=1): + source_path = source.get("source_path") + if not source_path: + continue + + target_name = source.get("workspace_subdir") + if not target_name: + target_name = Path(source_path).name or f"target_{index}" + + self._copy_local_directory_to_container(container, source_path, target_name) setattr(self, source_copied_key, True) container_id = container.id diff --git a/strix/runtime/runtime.py b/strix/runtime/runtime.py index 328a757..9de1afe 100644 --- a/strix/runtime/runtime.py +++ b/strix/runtime/runtime.py @@ -13,7 +13,10 @@ class SandboxInfo(TypedDict): class AbstractRuntime(ABC): @abstractmethod async def create_sandbox( - self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None + self, + agent_id: str, + existing_token: str | None = None, + local_sources: list[dict[str, str]] | None = None, ) -> SandboxInfo: raise NotImplementedError diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index 20853bd..15a4b42 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -44,8 +44,7 @@ class Tracer: "run_name": self.run_name, "start_time": self.start_time, "end_time": None, - "target": None, - "scan_type": None, + "targets": [], "status": "running", } self._run_dir: Path | None = None @@ -193,8 +192,7 @@ class Tracer: self.scan_config = config self.run_metadata.update( { - "target": config.get("target", {}), - "scan_type": config.get("scan_type", "general"), + "targets": config.get("targets", []), "user_instructions": config.get("user_instructions", ""), "max_iterations": config.get("max_iterations", 200), }