758 lines
25 KiB
Python
758 lines
25 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Strix Agent Interface
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import secrets
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib.parse import urlparse
|
|
|
|
import docker
|
|
import litellm
|
|
from docker.errors import DockerException
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.text import Text
|
|
|
|
from strix.interface.cli import run_cli
|
|
from strix.interface.tui import run_tui
|
|
from strix.runtime.docker_runtime import STRIX_IMAGE
|
|
from strix.telemetry.tracer import get_global_tracer
|
|
|
|
|
|
logging.getLogger().setLevel(logging.ERROR)
|
|
|
|
|
|
def format_token_count(count: float) -> str:
|
|
count = int(count)
|
|
if count >= 1_000_000:
|
|
return f"{count / 1_000_000:.1f}M"
|
|
if count >= 1_000:
|
|
return f"{count / 1_000:.1f}K"
|
|
return str(count)
|
|
|
|
|
|
def validate_environment() -> None: # noqa: PLR0912, PLR0915
|
|
console = Console()
|
|
missing_required_vars = []
|
|
missing_optional_vars = []
|
|
|
|
if not os.getenv("STRIX_LLM"):
|
|
missing_required_vars.append("STRIX_LLM")
|
|
|
|
has_base_url = any(
|
|
[
|
|
os.getenv("LLM_API_BASE"),
|
|
os.getenv("OPENAI_API_BASE"),
|
|
os.getenv("LITELLM_BASE_URL"),
|
|
os.getenv("OLLAMA_API_BASE"),
|
|
]
|
|
)
|
|
|
|
if not os.getenv("LLM_API_KEY"):
|
|
if not has_base_url:
|
|
missing_required_vars.append("LLM_API_KEY")
|
|
else:
|
|
missing_optional_vars.append("LLM_API_KEY")
|
|
|
|
if not has_base_url:
|
|
missing_optional_vars.append("LLM_API_BASE")
|
|
|
|
if not os.getenv("PERPLEXITY_API_KEY"):
|
|
missing_optional_vars.append("PERPLEXITY_API_KEY")
|
|
|
|
if missing_required_vars:
|
|
error_text = Text()
|
|
error_text.append("❌ ", style="bold red")
|
|
error_text.append("MISSING REQUIRED ENVIRONMENT VARIABLES", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
|
|
for var in missing_required_vars:
|
|
error_text.append(f"• {var}", style="bold yellow")
|
|
error_text.append(" is not set\n", style="white")
|
|
|
|
if missing_optional_vars:
|
|
error_text.append("\nOptional environment variables:\n", style="dim white")
|
|
for var in missing_optional_vars:
|
|
error_text.append(f"• {var}", style="dim yellow")
|
|
error_text.append(" is not set\n", style="dim white")
|
|
|
|
error_text.append("\nRequired environment variables:\n", style="white")
|
|
for var in missing_required_vars:
|
|
if var == "STRIX_LLM":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("STRIX_LLM", style="bold cyan")
|
|
error_text.append(
|
|
" - Model name to use with litellm (e.g., 'openai/gpt-5')\n",
|
|
style="white",
|
|
)
|
|
elif var == "LLM_API_KEY":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("LLM_API_KEY", style="bold cyan")
|
|
error_text.append(
|
|
" - API key for the LLM provider (required for cloud providers)\n",
|
|
style="white",
|
|
)
|
|
|
|
if missing_optional_vars:
|
|
error_text.append("\nOptional environment variables:\n", style="white")
|
|
for var in missing_optional_vars:
|
|
if var == "LLM_API_KEY":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("LLM_API_KEY", style="bold cyan")
|
|
error_text.append(" - API key for the LLM provider\n", style="white")
|
|
elif var == "LLM_API_BASE":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("LLM_API_BASE", style="bold cyan")
|
|
error_text.append(
|
|
" - Custom API base URL if using local models (e.g., Ollama, LMStudio)\n",
|
|
style="white",
|
|
)
|
|
elif var == "PERPLEXITY_API_KEY":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("PERPLEXITY_API_KEY", style="bold cyan")
|
|
error_text.append(
|
|
" - API key for Perplexity AI web search (enables real-time research)\n",
|
|
style="white",
|
|
)
|
|
|
|
error_text.append("\nExample setup:\n", style="white")
|
|
error_text.append("export STRIX_LLM='openai/gpt-5'\n", style="dim white")
|
|
|
|
if "LLM_API_KEY" in missing_required_vars:
|
|
error_text.append("export LLM_API_KEY='your-api-key-here'\n", style="dim white")
|
|
|
|
if missing_optional_vars:
|
|
for var in missing_optional_vars:
|
|
if var == "LLM_API_KEY":
|
|
error_text.append(
|
|
"export LLM_API_KEY='your-api-key-here' # optional with local models\n",
|
|
style="dim white",
|
|
)
|
|
elif var == "LLM_API_BASE":
|
|
error_text.append(
|
|
"export LLM_API_BASE='http://localhost:11434' "
|
|
"# needed for local models only\n",
|
|
style="dim white",
|
|
)
|
|
elif var == "PERPLEXITY_API_KEY":
|
|
error_text.append(
|
|
"export PERPLEXITY_API_KEY='your-perplexity-key-here'\n", style="dim white"
|
|
)
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold red]🛡️ STRIX CONFIGURATION ERROR",
|
|
title_align="center",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
|
|
console.print("\n")
|
|
console.print(panel)
|
|
console.print()
|
|
sys.exit(1)
|
|
|
|
|
|
def _validate_llm_response(response: Any) -> None:
|
|
if not response or not response.choices or not response.choices[0].message.content:
|
|
raise RuntimeError("Invalid response from LLM")
|
|
|
|
|
|
def check_docker_installed() -> None:
|
|
if shutil.which("docker") is None:
|
|
console = Console()
|
|
error_text = Text()
|
|
error_text.append("❌ ", style="bold red")
|
|
error_text.append("DOCKER NOT INSTALLED", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append("The 'docker' CLI was not found in your PATH.\n", style="white")
|
|
error_text.append(
|
|
"Please install Docker and ensure the 'docker' command is available.\n\n", style="white"
|
|
)
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold red]🛡️ STRIX STARTUP ERROR",
|
|
title_align="center",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
console.print("\n", panel, "\n")
|
|
sys.exit(1)
|
|
|
|
|
|
async def warm_up_llm() -> None:
|
|
console = Console()
|
|
|
|
try:
|
|
model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
|
|
api_key = os.getenv("LLM_API_KEY")
|
|
|
|
if api_key:
|
|
litellm.api_key = api_key
|
|
|
|
api_base = (
|
|
os.getenv("LLM_API_BASE")
|
|
or os.getenv("OPENAI_API_BASE")
|
|
or os.getenv("LITELLM_BASE_URL")
|
|
or os.getenv("OLLAMA_API_BASE")
|
|
)
|
|
if api_base:
|
|
litellm.api_base = api_base
|
|
|
|
test_messages = [
|
|
{"role": "system", "content": "You are a helpful assistant."},
|
|
{"role": "user", "content": "Reply with just 'OK'."},
|
|
]
|
|
|
|
response = litellm.completion(
|
|
model=model_name,
|
|
messages=test_messages,
|
|
)
|
|
|
|
_validate_llm_response(response)
|
|
|
|
except Exception as e: # noqa: BLE001
|
|
error_text = Text()
|
|
error_text.append("❌ ", style="bold red")
|
|
error_text.append("LLM CONNECTION FAILED", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append("Could not establish connection to the language model.\n", style="white")
|
|
error_text.append("Please check your configuration and try again.\n", style="white")
|
|
error_text.append(f"\nError: {e}", style="dim white")
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold red]🛡️ STRIX STARTUP ERROR",
|
|
title_align="center",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
|
|
console.print("\n")
|
|
console.print(panel)
|
|
console.print()
|
|
sys.exit(1)
|
|
|
|
|
|
def generate_run_name() -> str:
|
|
# fmt: off
|
|
adjectives = [
|
|
"stealthy", "sneaky", "crafty", "elite", "phantom", "shadow", "silent",
|
|
"rogue", "covert", "ninja", "ghost", "cyber", "digital", "binary",
|
|
"encrypted", "obfuscated", "masked", "cloaked", "invisible", "anonymous"
|
|
]
|
|
nouns = [
|
|
"exploit", "payload", "backdoor", "rootkit", "keylogger", "botnet", "trojan",
|
|
"worm", "virus", "packet", "buffer", "shell", "daemon", "spider", "crawler",
|
|
"scanner", "sniffer", "honeypot", "firewall", "breach"
|
|
]
|
|
# fmt: on
|
|
adj = secrets.choice(adjectives)
|
|
noun = secrets.choice(nouns)
|
|
number = secrets.randbelow(900) + 100
|
|
return f"{adj}-{noun}-{number}"
|
|
|
|
|
|
def clone_repository(repo_url: str, run_name: str) -> str:
|
|
console = Console()
|
|
|
|
git_executable = shutil.which("git")
|
|
if git_executable is None:
|
|
raise FileNotFoundError("Git executable not found in PATH")
|
|
|
|
temp_dir = Path(tempfile.gettempdir()) / "strix_repos" / run_name
|
|
temp_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
repo_name = Path(repo_url).stem if repo_url.endswith(".git") else Path(repo_url).name
|
|
|
|
clone_path = temp_dir / repo_name
|
|
|
|
if clone_path.exists():
|
|
shutil.rmtree(clone_path)
|
|
|
|
try:
|
|
with console.status(f"[bold cyan]Cloning repository {repo_name}...", spinner="dots"):
|
|
subprocess.run( # noqa: S603
|
|
[
|
|
git_executable,
|
|
"clone",
|
|
repo_url,
|
|
str(clone_path),
|
|
],
|
|
capture_output=True,
|
|
text=True,
|
|
check=True,
|
|
)
|
|
|
|
return str(clone_path.absolute())
|
|
|
|
except subprocess.CalledProcessError as e:
|
|
error_text = Text()
|
|
error_text.append("❌ ", style="bold red")
|
|
error_text.append("REPOSITORY CLONE FAILED", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append(f"Could not clone repository: {repo_url}\n", style="white")
|
|
error_text.append(
|
|
f"Error: {e.stderr if hasattr(e, 'stderr') and e.stderr else str(e)}", style="dim red"
|
|
)
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold red]🛡️ STRIX CLONE ERROR",
|
|
title_align="center",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
console.print("\n")
|
|
console.print(panel)
|
|
console.print()
|
|
sys.exit(1)
|
|
except FileNotFoundError:
|
|
error_text = Text()
|
|
error_text.append("❌ ", style="bold red")
|
|
error_text.append("GIT NOT FOUND", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append("Git is not installed or not available in PATH.\n", style="white")
|
|
error_text.append("Please install Git to clone repositories.\n", style="white")
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold red]🛡️ STRIX CLONE ERROR",
|
|
title_align="center",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
console.print("\n")
|
|
console.print(panel)
|
|
console.print()
|
|
sys.exit(1)
|
|
|
|
|
|
def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
|
|
if not target or not isinstance(target, str):
|
|
raise ValueError("Target must be a non-empty string")
|
|
|
|
target = target.strip()
|
|
|
|
parsed = urlparse(target)
|
|
if parsed.scheme in ("http", "https"):
|
|
if any(
|
|
host in parsed.netloc.lower() for host in ["github.com", "gitlab.com", "bitbucket.org"]
|
|
):
|
|
return "repository", {"target_repo": target}
|
|
return "web_application", {"target_url": target}
|
|
|
|
path = Path(target)
|
|
try:
|
|
if path.exists():
|
|
if path.is_dir():
|
|
return "local_code", {"target_path": str(path.absolute())}
|
|
raise ValueError(f"Path exists but is not a directory: {target}")
|
|
except (OSError, RuntimeError) as e:
|
|
raise ValueError(f"Invalid path: {target} - {e!s}") from e
|
|
|
|
if target.startswith("git@") or target.endswith(".git"):
|
|
return "repository", {"target_repo": target}
|
|
|
|
if "." in target and "/" not in target and not target.startswith("."):
|
|
parts = target.split(".")
|
|
if len(parts) >= 2 and all(p and p.strip() for p in parts):
|
|
return "web_application", {"target_url": f"https://{target}"}
|
|
|
|
raise ValueError(
|
|
f"Invalid target: {target}\n"
|
|
"Target must be one of:\n"
|
|
"- A valid URL (http:// or https://)\n"
|
|
"- A Git repository URL (https://github.com/... or git@github.com:...)\n"
|
|
"- A local directory path\n"
|
|
"- A domain name (e.g., example.com)"
|
|
)
|
|
|
|
|
|
def parse_arguments() -> argparse.Namespace:
|
|
parser = argparse.ArgumentParser(
|
|
description="Strix Multi-Agent Cybersecurity Penetration Testing Tool",
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
epilog="""
|
|
Examples:
|
|
# Web application penetration test
|
|
strix --target https://example.com
|
|
|
|
# GitHub repository analysis
|
|
strix --target https://github.com/user/repo
|
|
strix --target git@github.com:user/repo.git
|
|
|
|
# Local code analysis
|
|
strix --target ./my-project
|
|
|
|
# Domain penetration test
|
|
strix --target example.com
|
|
|
|
# Custom instructions
|
|
strix --target example.com --instruction "Focus on authentication vulnerabilities"
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--target",
|
|
type=str,
|
|
required=True,
|
|
help="Target to test (URL, repository, local directory path, or domain name)",
|
|
)
|
|
parser.add_argument(
|
|
"--instruction",
|
|
type=str,
|
|
help="Custom instructions for the penetration test. This can be "
|
|
"specific vulnerability types to focus on (e.g., 'Focus on IDOR and XSS'), "
|
|
"testing approaches (e.g., 'Perform thorough authentication testing'), "
|
|
"test credentials (e.g., 'Use the following credentials to access the app: "
|
|
"admin:password123'), "
|
|
"or areas of interest (e.g., 'Check login API endpoint for security issues')",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--run-name",
|
|
type=str,
|
|
help="Custom name for this penetration test run",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-n",
|
|
"--non-interactive",
|
|
action="store_true",
|
|
help=(
|
|
"Run in non-interactive mode (no TUI, exits on completion). "
|
|
"Default is interactive mode with TUI."
|
|
),
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
try:
|
|
args.target_type, args.target_dict = infer_target_type(args.target)
|
|
except ValueError as e:
|
|
parser.error(str(e))
|
|
|
|
return args
|
|
|
|
|
|
def _get_severity_color(severity: str) -> str:
|
|
severity_colors = {
|
|
"critical": "#dc2626",
|
|
"high": "#ea580c",
|
|
"medium": "#d97706",
|
|
"low": "#65a30d",
|
|
"info": "#0284c7",
|
|
}
|
|
return severity_colors.get(severity, "#6b7280")
|
|
|
|
|
|
def _build_stats_text(tracer: Any) -> Text:
|
|
stats_text = Text()
|
|
if not tracer:
|
|
return stats_text
|
|
|
|
vuln_count = len(tracer.vulnerability_reports)
|
|
tool_count = tracer.get_real_tool_count()
|
|
agent_count = len(tracer.agents)
|
|
|
|
if vuln_count > 0:
|
|
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
|
|
for report in tracer.vulnerability_reports:
|
|
severity = report.get("severity", "").lower()
|
|
if severity in severity_counts:
|
|
severity_counts[severity] += 1
|
|
|
|
stats_text.append("🔍 Vulnerabilities Found: ", style="bold red")
|
|
|
|
severity_parts = []
|
|
for severity in ["critical", "high", "medium", "low", "info"]:
|
|
count = severity_counts[severity]
|
|
if count > 0:
|
|
severity_color = _get_severity_color(severity)
|
|
severity_text = Text()
|
|
severity_text.append(f"{severity.upper()}: ", style=severity_color)
|
|
severity_text.append(str(count), style=f"bold {severity_color}")
|
|
severity_parts.append(severity_text)
|
|
|
|
for i, part in enumerate(severity_parts):
|
|
stats_text.append(part)
|
|
if i < len(severity_parts) - 1:
|
|
stats_text.append(" | ", style="dim white")
|
|
|
|
stats_text.append(" (Total: ", style="dim white")
|
|
stats_text.append(str(vuln_count), style="bold yellow")
|
|
stats_text.append(")", style="dim white")
|
|
stats_text.append("\n")
|
|
else:
|
|
stats_text.append("🔍 Vulnerabilities Found: ", style="bold green")
|
|
stats_text.append("0", style="bold white")
|
|
stats_text.append(" (No exploitable vulnerabilities detected)", style="dim green")
|
|
stats_text.append("\n")
|
|
|
|
stats_text.append("🤖 Agents Used: ", style="bold cyan")
|
|
stats_text.append(str(agent_count), style="bold white")
|
|
stats_text.append(" • ", style="dim white")
|
|
stats_text.append("🛠️ Tools Called: ", style="bold cyan")
|
|
stats_text.append(str(tool_count), style="bold white")
|
|
|
|
return stats_text
|
|
|
|
|
|
def _build_llm_stats_text(tracer: Any) -> Text:
|
|
llm_stats_text = Text()
|
|
if not tracer:
|
|
return llm_stats_text
|
|
|
|
llm_stats = tracer.get_total_llm_stats()
|
|
total_stats = llm_stats["total"]
|
|
|
|
if total_stats["requests"] > 0:
|
|
llm_stats_text.append("📥 Input Tokens: ", style="bold cyan")
|
|
llm_stats_text.append(format_token_count(total_stats["input_tokens"]), style="bold white")
|
|
|
|
if total_stats["cached_tokens"] > 0:
|
|
llm_stats_text.append(" • ", style="dim white")
|
|
llm_stats_text.append("⚡ Cached: ", style="bold green")
|
|
llm_stats_text.append(
|
|
format_token_count(total_stats["cached_tokens"]), style="bold green"
|
|
)
|
|
|
|
llm_stats_text.append(" • ", style="dim white")
|
|
llm_stats_text.append("📤 Output Tokens: ", style="bold cyan")
|
|
llm_stats_text.append(format_token_count(total_stats["output_tokens"]), style="bold white")
|
|
|
|
if total_stats["cost"] > 0:
|
|
llm_stats_text.append(" • ", style="dim white")
|
|
llm_stats_text.append("💰 Total Cost: $", style="bold cyan")
|
|
llm_stats_text.append(f"{total_stats['cost']:.4f}", style="bold yellow")
|
|
|
|
return llm_stats_text
|
|
|
|
|
|
def display_completion_message(args: argparse.Namespace, results_path: Path) -> None:
|
|
console = Console()
|
|
tracer = get_global_tracer()
|
|
|
|
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
|
|
|
|
completion_text = Text()
|
|
completion_text.append("🦉 ", style="bold white")
|
|
completion_text.append("AGENT FINISHED", style="bold green")
|
|
completion_text.append(" • ", style="dim white")
|
|
completion_text.append("Penetration test completed", style="white")
|
|
|
|
stats_text = _build_stats_text(tracer)
|
|
|
|
llm_stats_text = _build_llm_stats_text(tracer)
|
|
|
|
target_text = Text()
|
|
target_text.append("🎯 Target: ", style="bold cyan")
|
|
target_text.append(str(target_value), style="bold white")
|
|
|
|
results_text = Text()
|
|
results_text.append("📊 Results Saved To: ", style="bold cyan")
|
|
results_text.append(str(results_path), style="bold yellow")
|
|
|
|
if stats_text.plain:
|
|
if llm_stats_text.plain:
|
|
panel_content = Text.assemble(
|
|
completion_text,
|
|
"\n\n",
|
|
target_text,
|
|
"\n",
|
|
stats_text,
|
|
"\n",
|
|
llm_stats_text,
|
|
"\n",
|
|
results_text,
|
|
)
|
|
else:
|
|
panel_content = Text.assemble(
|
|
completion_text, "\n\n", target_text, "\n", stats_text, "\n", results_text
|
|
)
|
|
elif llm_stats_text.plain:
|
|
panel_content = Text.assemble(
|
|
completion_text, "\n\n", target_text, "\n", llm_stats_text, "\n", results_text
|
|
)
|
|
else:
|
|
panel_content = Text.assemble(completion_text, "\n\n", target_text, "\n", results_text)
|
|
|
|
panel = Panel(
|
|
panel_content,
|
|
title="[bold green]🛡️ STRIX CYBERSECURITY AGENT",
|
|
title_align="center",
|
|
border_style="green",
|
|
padding=(1, 2),
|
|
)
|
|
|
|
console.print("\n")
|
|
console.print(panel)
|
|
console.print()
|
|
|
|
|
|
def _check_docker_connection() -> Any:
|
|
try:
|
|
return docker.from_env()
|
|
except DockerException:
|
|
console = Console()
|
|
error_text = Text()
|
|
error_text.append("❌ ", style="bold red")
|
|
error_text.append("DOCKER NOT AVAILABLE", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append("Cannot connect to Docker daemon.\n", style="white")
|
|
error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
|
|
error_text.append("Try running: ", style="dim white")
|
|
error_text.append("sudo systemctl start docker", style="dim cyan")
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold red]🛡️ STRIX STARTUP ERROR",
|
|
title_align="center",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
console.print("\n", panel, "\n")
|
|
sys.exit(1)
|
|
|
|
|
|
def _image_exists(client: Any) -> bool:
|
|
try:
|
|
client.images.get(STRIX_IMAGE)
|
|
except docker.errors.ImageNotFound:
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
|
|
def _update_layer_status(layers_info: dict[str, str], layer_id: str, layer_status: str) -> None:
|
|
if "Pull complete" in layer_status or "Already exists" in layer_status:
|
|
layers_info[layer_id] = "✓"
|
|
elif "Downloading" in layer_status:
|
|
layers_info[layer_id] = "↓"
|
|
elif "Extracting" in layer_status:
|
|
layers_info[layer_id] = "📦"
|
|
elif "Waiting" in layer_status:
|
|
layers_info[layer_id] = "⏳"
|
|
else:
|
|
layers_info[layer_id] = "•"
|
|
|
|
|
|
def _process_pull_line(
|
|
line: dict[str, Any], layers_info: dict[str, str], status: Any, last_update: str
|
|
) -> str:
|
|
if "id" in line and "status" in line:
|
|
layer_id = line["id"]
|
|
_update_layer_status(layers_info, layer_id, line["status"])
|
|
|
|
completed = sum(1 for v in layers_info.values() if v == "✓")
|
|
total = len(layers_info)
|
|
|
|
if total > 0:
|
|
update_msg = f"[bold cyan]Progress: {completed}/{total} layers complete"
|
|
if update_msg != last_update:
|
|
status.update(update_msg)
|
|
return update_msg
|
|
|
|
elif "status" in line and "id" not in line:
|
|
global_status = line["status"]
|
|
if "Pulling from" in global_status:
|
|
status.update("[bold cyan]Fetching image manifest...")
|
|
elif "Digest:" in global_status:
|
|
status.update("[bold cyan]Verifying image...")
|
|
elif "Status:" in global_status:
|
|
status.update("[bold cyan]Finalizing...")
|
|
|
|
return last_update
|
|
|
|
|
|
def pull_docker_image() -> None:
|
|
console = Console()
|
|
client = _check_docker_connection()
|
|
|
|
if _image_exists(client):
|
|
return
|
|
|
|
console.print()
|
|
console.print(f"[bold cyan]🐳 Pulling Docker image:[/] {STRIX_IMAGE}")
|
|
console.print("[dim yellow]This only happens on first run and may take a few minutes...[/]")
|
|
console.print()
|
|
|
|
with console.status("[bold cyan]Downloading image layers...", spinner="dots") as status:
|
|
try:
|
|
layers_info: dict[str, str] = {}
|
|
last_update = ""
|
|
|
|
for line in client.api.pull(STRIX_IMAGE, stream=True, decode=True):
|
|
last_update = _process_pull_line(line, layers_info, status, last_update)
|
|
|
|
except DockerException as e:
|
|
console.print()
|
|
error_text = Text()
|
|
error_text.append("❌ ", style="bold red")
|
|
error_text.append("FAILED TO PULL IMAGE", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append(f"Could not download: {STRIX_IMAGE}\n", style="white")
|
|
error_text.append(str(e), style="dim red")
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold red]🛡️ DOCKER PULL ERROR",
|
|
title_align="center",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
console.print(panel, "\n")
|
|
sys.exit(1)
|
|
|
|
success_text = Text()
|
|
success_text.append("✅ ", style="bold green")
|
|
success_text.append("Successfully pulled Docker image", style="green")
|
|
console.print(success_text)
|
|
console.print()
|
|
|
|
|
|
def main() -> None:
|
|
if sys.platform == "win32":
|
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
|
|
args = parse_arguments()
|
|
|
|
check_docker_installed()
|
|
pull_docker_image()
|
|
|
|
validate_environment()
|
|
asyncio.run(warm_up_llm())
|
|
|
|
if not args.run_name:
|
|
args.run_name = generate_run_name()
|
|
|
|
if args.target_type == "repository":
|
|
repo_url = args.target_dict["target_repo"]
|
|
cloned_path = clone_repository(repo_url, args.run_name)
|
|
|
|
args.target_dict["cloned_repo_path"] = cloned_path
|
|
|
|
if args.non_interactive:
|
|
asyncio.run(run_cli(args))
|
|
else:
|
|
asyncio.run(run_tui(args))
|
|
|
|
results_path = Path("agent_runs") / args.run_name
|
|
display_completion_message(args, results_path)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|