586 lines
19 KiB
Python
586 lines
19 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Strix Agent Interface
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import shutil
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import litellm
|
|
from docker.errors import DockerException
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.text import Text
|
|
|
|
from strix.config import Config, apply_saved_config, save_current_config
|
|
from strix.config.config import resolve_llm_config
|
|
from strix.llm.utils import resolve_strix_model
|
|
|
|
|
|
apply_saved_config()
|
|
|
|
from strix.interface.cli import run_cli # noqa: E402
|
|
from strix.interface.tui import run_tui # noqa: E402
|
|
from strix.interface.utils import ( # noqa: E402
|
|
assign_workspace_subdirs,
|
|
build_final_stats_text,
|
|
check_docker_connection,
|
|
clone_repository,
|
|
collect_local_sources,
|
|
generate_run_name,
|
|
image_exists,
|
|
infer_target_type,
|
|
process_pull_line,
|
|
rewrite_localhost_targets,
|
|
validate_config_file,
|
|
validate_llm_response,
|
|
)
|
|
from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402
|
|
from strix.telemetry import posthog # noqa: E402
|
|
from strix.telemetry.tracer import get_global_tracer # noqa: E402
|
|
|
|
|
|
logging.getLogger().setLevel(logging.ERROR)
|
|
|
|
|
|
def validate_environment() -> None: # noqa: PLR0912, PLR0915
|
|
console = Console()
|
|
missing_required_vars = []
|
|
missing_optional_vars = []
|
|
|
|
strix_llm = Config.get("strix_llm")
|
|
uses_strix_models = strix_llm and strix_llm.startswith("strix/")
|
|
|
|
if not strix_llm:
|
|
missing_required_vars.append("STRIX_LLM")
|
|
|
|
has_base_url = uses_strix_models or any(
|
|
[
|
|
Config.get("llm_api_base"),
|
|
Config.get("openai_api_base"),
|
|
Config.get("litellm_base_url"),
|
|
Config.get("ollama_api_base"),
|
|
]
|
|
)
|
|
|
|
if not Config.get("llm_api_key"):
|
|
missing_optional_vars.append("LLM_API_KEY")
|
|
|
|
if not has_base_url:
|
|
missing_optional_vars.append("LLM_API_BASE")
|
|
|
|
if not Config.get("perplexity_api_key"):
|
|
missing_optional_vars.append("PERPLEXITY_API_KEY")
|
|
|
|
if not Config.get("strix_reasoning_effort"):
|
|
missing_optional_vars.append("STRIX_REASONING_EFFORT")
|
|
|
|
if missing_required_vars:
|
|
error_text = Text()
|
|
error_text.append("MISSING REQUIRED ENVIRONMENT VARIABLES", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
|
|
for var in missing_required_vars:
|
|
error_text.append(f"• {var}", style="bold yellow")
|
|
error_text.append(" is not set\n", style="white")
|
|
|
|
if missing_optional_vars:
|
|
error_text.append("\nOptional environment variables:\n", style="dim white")
|
|
for var in missing_optional_vars:
|
|
error_text.append(f"• {var}", style="dim yellow")
|
|
error_text.append(" is not set\n", style="dim white")
|
|
|
|
error_text.append("\nRequired environment variables:\n", style="white")
|
|
for var in missing_required_vars:
|
|
if var == "STRIX_LLM":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("STRIX_LLM", style="bold cyan")
|
|
error_text.append(
|
|
" - Model name to use with litellm (e.g., 'openai/gpt-5')\n",
|
|
style="white",
|
|
)
|
|
|
|
if missing_optional_vars:
|
|
error_text.append("\nOptional environment variables:\n", style="white")
|
|
for var in missing_optional_vars:
|
|
if var == "LLM_API_KEY":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("LLM_API_KEY", style="bold cyan")
|
|
error_text.append(
|
|
" - API key for the LLM provider "
|
|
"(not needed for local models, Vertex AI, AWS, etc.)\n",
|
|
style="white",
|
|
)
|
|
elif var == "LLM_API_BASE":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("LLM_API_BASE", style="bold cyan")
|
|
error_text.append(
|
|
" - Custom API base URL if using local models (e.g., Ollama, LMStudio)\n",
|
|
style="white",
|
|
)
|
|
elif var == "PERPLEXITY_API_KEY":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("PERPLEXITY_API_KEY", style="bold cyan")
|
|
error_text.append(
|
|
" - API key for Perplexity AI web search (enables real-time research)\n",
|
|
style="white",
|
|
)
|
|
elif var == "STRIX_REASONING_EFFORT":
|
|
error_text.append("• ", style="white")
|
|
error_text.append("STRIX_REASONING_EFFORT", style="bold cyan")
|
|
error_text.append(
|
|
" - Reasoning effort level: none, minimal, low, medium, high, xhigh "
|
|
"(default: high)\n",
|
|
style="white",
|
|
)
|
|
|
|
error_text.append("\nExample setup:\n", style="white")
|
|
if uses_strix_models:
|
|
error_text.append("export STRIX_LLM='strix/gpt-5'\n", style="dim white")
|
|
else:
|
|
error_text.append("export STRIX_LLM='openai/gpt-5'\n", style="dim white")
|
|
|
|
if missing_optional_vars:
|
|
for var in missing_optional_vars:
|
|
if var == "LLM_API_KEY":
|
|
error_text.append(
|
|
"export LLM_API_KEY='your-api-key-here' "
|
|
"# not needed for local models, Vertex AI, AWS, etc.\n",
|
|
style="dim white",
|
|
)
|
|
elif var == "LLM_API_BASE":
|
|
error_text.append(
|
|
"export LLM_API_BASE='http://localhost:11434' "
|
|
"# needed for local models only\n",
|
|
style="dim white",
|
|
)
|
|
elif var == "PERPLEXITY_API_KEY":
|
|
error_text.append(
|
|
"export PERPLEXITY_API_KEY='your-perplexity-key-here'\n", style="dim white"
|
|
)
|
|
elif var == "STRIX_REASONING_EFFORT":
|
|
error_text.append(
|
|
"export STRIX_REASONING_EFFORT='high'\n",
|
|
style="dim white",
|
|
)
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold white]STRIX",
|
|
title_align="left",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
|
|
console.print("\n")
|
|
console.print(panel)
|
|
console.print()
|
|
sys.exit(1)
|
|
|
|
|
|
def check_docker_installed() -> None:
|
|
if shutil.which("docker") is None:
|
|
console = Console()
|
|
error_text = Text()
|
|
error_text.append("DOCKER NOT INSTALLED", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append("The 'docker' CLI was not found in your PATH.\n", style="white")
|
|
error_text.append(
|
|
"Please install Docker and ensure the 'docker' command is available.\n\n", style="white"
|
|
)
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold white]STRIX",
|
|
title_align="left",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
console.print("\n", panel, "\n")
|
|
sys.exit(1)
|
|
|
|
|
|
async def warm_up_llm() -> None:
|
|
console = Console()
|
|
|
|
try:
|
|
model_name, api_key, api_base = resolve_llm_config()
|
|
litellm_model, _ = resolve_strix_model(model_name)
|
|
litellm_model = litellm_model or model_name
|
|
|
|
test_messages = [
|
|
{"role": "system", "content": "You are a helpful assistant."},
|
|
{"role": "user", "content": "Reply with just 'OK'."},
|
|
]
|
|
|
|
llm_timeout = int(Config.get("llm_timeout") or "300")
|
|
|
|
completion_kwargs: dict[str, Any] = {
|
|
"model": litellm_model,
|
|
"messages": test_messages,
|
|
"timeout": llm_timeout,
|
|
}
|
|
if api_key:
|
|
completion_kwargs["api_key"] = api_key
|
|
if api_base:
|
|
completion_kwargs["api_base"] = api_base
|
|
|
|
response = litellm.completion(**completion_kwargs)
|
|
|
|
validate_llm_response(response)
|
|
|
|
except Exception as e: # noqa: BLE001
|
|
error_text = Text()
|
|
error_text.append("LLM CONNECTION FAILED", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append("Could not establish connection to the language model.\n", style="white")
|
|
error_text.append("Please check your configuration and try again.\n", style="white")
|
|
error_text.append(f"\nError: {e}", style="dim white")
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold white]STRIX",
|
|
title_align="left",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
|
|
console.print("\n")
|
|
console.print(panel)
|
|
console.print()
|
|
sys.exit(1)
|
|
|
|
|
|
def get_version() -> str:
|
|
try:
|
|
from importlib.metadata import version
|
|
|
|
return version("strix-agent")
|
|
except Exception: # noqa: BLE001
|
|
return "unknown"
|
|
|
|
|
|
def parse_arguments() -> argparse.Namespace:
|
|
parser = argparse.ArgumentParser(
|
|
description="Strix Multi-Agent Cybersecurity Penetration Testing Tool",
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
epilog="""
|
|
Examples:
|
|
# Web application penetration test
|
|
strix --target https://example.com
|
|
|
|
# GitHub repository analysis
|
|
strix --target https://github.com/user/repo
|
|
strix --target git@github.com:user/repo.git
|
|
|
|
# Local code analysis
|
|
strix --target ./my-project
|
|
|
|
# Domain penetration test
|
|
strix --target example.com
|
|
|
|
# IP address penetration test
|
|
strix --target 192.168.1.42
|
|
|
|
# Multiple targets (e.g., white-box testing with source and deployed app)
|
|
strix --target https://github.com/user/repo --target https://example.com
|
|
strix --target ./my-project --target https://staging.example.com --target https://prod.example.com
|
|
|
|
# Custom instructions (inline)
|
|
strix --target example.com --instruction "Focus on authentication vulnerabilities"
|
|
|
|
# Custom instructions (from file)
|
|
strix --target example.com --instruction-file ./instructions.txt
|
|
strix --target https://app.com --instruction-file /path/to/detailed_instructions.md
|
|
""",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-v",
|
|
"--version",
|
|
action="version",
|
|
version=f"strix {get_version()}",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-t",
|
|
"--target",
|
|
type=str,
|
|
required=True,
|
|
action="append",
|
|
help="Target to test (URL, repository, local directory path, domain name, or IP address). "
|
|
"Can be specified multiple times for multi-target scans.",
|
|
)
|
|
parser.add_argument(
|
|
"--instruction",
|
|
type=str,
|
|
help="Custom instructions for the penetration test. This can be "
|
|
"specific vulnerability types to focus on (e.g., 'Focus on IDOR and XSS'), "
|
|
"testing approaches (e.g., 'Perform thorough authentication testing'), "
|
|
"test credentials (e.g., 'Use the following credentials to access the app: "
|
|
"admin:password123'), "
|
|
"or areas of interest (e.g., 'Check login API endpoint for security issues').",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--instruction-file",
|
|
type=str,
|
|
help="Path to a file containing detailed custom instructions for the penetration test. "
|
|
"Use this option when you have lengthy or complex instructions saved in a file "
|
|
"(e.g., '--instruction-file ./detailed_instructions.txt').",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-n",
|
|
"--non-interactive",
|
|
action="store_true",
|
|
help=(
|
|
"Run in non-interactive mode (no TUI, exits on completion). "
|
|
"Default is interactive mode with TUI."
|
|
),
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-m",
|
|
"--scan-mode",
|
|
type=str,
|
|
choices=["quick", "standard", "deep"],
|
|
default="deep",
|
|
help=(
|
|
"Scan mode: "
|
|
"'quick' for fast CI/CD checks, "
|
|
"'standard' for routine testing, "
|
|
"'deep' for thorough security reviews (default). "
|
|
"Default: deep."
|
|
),
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--config",
|
|
type=str,
|
|
help="Path to a custom config file (JSON) to use instead of ~/.strix/cli-config.json",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
if args.instruction and args.instruction_file:
|
|
parser.error(
|
|
"Cannot specify both --instruction and --instruction-file. Use one or the other."
|
|
)
|
|
|
|
if args.instruction_file:
|
|
instruction_path = Path(args.instruction_file)
|
|
try:
|
|
with instruction_path.open(encoding="utf-8") as f:
|
|
args.instruction = f.read().strip()
|
|
if not args.instruction:
|
|
parser.error(f"Instruction file '{instruction_path}' is empty")
|
|
except Exception as e: # noqa: BLE001
|
|
parser.error(f"Failed to read instruction file '{instruction_path}': {e}")
|
|
|
|
args.targets_info = []
|
|
for target in args.target:
|
|
try:
|
|
target_type, target_dict = infer_target_type(target)
|
|
|
|
if target_type == "local_code":
|
|
display_target = target_dict.get("target_path", target)
|
|
else:
|
|
display_target = target
|
|
|
|
args.targets_info.append(
|
|
{"type": target_type, "details": target_dict, "original": display_target}
|
|
)
|
|
except ValueError:
|
|
parser.error(f"Invalid target '{target}'")
|
|
|
|
assign_workspace_subdirs(args.targets_info)
|
|
rewrite_localhost_targets(args.targets_info, HOST_GATEWAY_HOSTNAME)
|
|
|
|
return args
|
|
|
|
|
|
def display_completion_message(args: argparse.Namespace, results_path: Path) -> None:
|
|
console = Console()
|
|
tracer = get_global_tracer()
|
|
|
|
scan_completed = False
|
|
if tracer and tracer.scan_results:
|
|
scan_completed = tracer.scan_results.get("scan_completed", False)
|
|
|
|
completion_text = Text()
|
|
if scan_completed:
|
|
completion_text.append("Penetration test completed", style="bold #22c55e")
|
|
else:
|
|
completion_text.append("SESSION ENDED", style="bold #eab308")
|
|
|
|
target_text = Text()
|
|
target_text.append("Target", style="dim")
|
|
target_text.append(" ")
|
|
if len(args.targets_info) == 1:
|
|
target_text.append(args.targets_info[0]["original"], style="bold white")
|
|
else:
|
|
target_text.append(f"{len(args.targets_info)} targets", style="bold white")
|
|
for target_info in args.targets_info:
|
|
target_text.append("\n ")
|
|
target_text.append(target_info["original"], style="white")
|
|
|
|
stats_text = build_final_stats_text(tracer)
|
|
|
|
panel_parts = [completion_text, "\n\n", target_text]
|
|
|
|
if stats_text.plain:
|
|
panel_parts.extend(["\n", stats_text])
|
|
|
|
results_text = Text()
|
|
results_text.append("\n")
|
|
results_text.append("Output", style="dim")
|
|
results_text.append(" ")
|
|
results_text.append(str(results_path), style="#60a5fa")
|
|
panel_parts.extend(["\n", results_text])
|
|
|
|
panel_content = Text.assemble(*panel_parts)
|
|
|
|
border_style = "#22c55e" if scan_completed else "#eab308"
|
|
|
|
panel = Panel(
|
|
panel_content,
|
|
title="[bold white]STRIX",
|
|
title_align="left",
|
|
border_style=border_style,
|
|
padding=(1, 2),
|
|
)
|
|
|
|
console.print("\n")
|
|
console.print(panel)
|
|
console.print()
|
|
console.print("[#60a5fa]models.strix.ai[/] [dim]·[/] [#60a5fa]discord.gg/strix-ai[/]")
|
|
console.print()
|
|
|
|
|
|
def pull_docker_image() -> None:
|
|
console = Console()
|
|
client = check_docker_connection()
|
|
|
|
if image_exists(client, Config.get("strix_image")): # type: ignore[arg-type]
|
|
return
|
|
|
|
console.print()
|
|
console.print(f"[dim]Pulling image[/] {Config.get('strix_image')}")
|
|
console.print("[dim yellow]This only happens on first run and may take a few minutes...[/]")
|
|
console.print()
|
|
|
|
with console.status("[bold cyan]Downloading image layers...", spinner="dots") as status:
|
|
try:
|
|
layers_info: dict[str, str] = {}
|
|
last_update = ""
|
|
|
|
for line in client.api.pull(Config.get("strix_image"), stream=True, decode=True):
|
|
last_update = process_pull_line(line, layers_info, status, last_update)
|
|
|
|
except DockerException as e:
|
|
console.print()
|
|
error_text = Text()
|
|
error_text.append("FAILED TO PULL IMAGE", style="bold red")
|
|
error_text.append("\n\n", style="white")
|
|
error_text.append(f"Could not download: {Config.get('strix_image')}\n", style="white")
|
|
error_text.append(str(e), style="dim red")
|
|
|
|
panel = Panel(
|
|
error_text,
|
|
title="[bold white]STRIX",
|
|
title_align="left",
|
|
border_style="red",
|
|
padding=(1, 2),
|
|
)
|
|
console.print(panel, "\n")
|
|
sys.exit(1)
|
|
|
|
success_text = Text()
|
|
success_text.append("Docker image ready", style="#22c55e")
|
|
console.print(success_text)
|
|
console.print()
|
|
|
|
|
|
def apply_config_override(config_path: str) -> None:
|
|
Config._config_file_override = validate_config_file(config_path)
|
|
apply_saved_config(force=True)
|
|
|
|
|
|
def persist_config() -> None:
|
|
if Config._config_file_override is None:
|
|
save_current_config()
|
|
|
|
|
|
def main() -> None:
|
|
if sys.platform == "win32":
|
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
|
|
args = parse_arguments()
|
|
|
|
if args.config:
|
|
apply_config_override(args.config)
|
|
|
|
check_docker_installed()
|
|
pull_docker_image()
|
|
|
|
validate_environment()
|
|
asyncio.run(warm_up_llm())
|
|
|
|
persist_config()
|
|
|
|
args.run_name = generate_run_name(args.targets_info)
|
|
|
|
for target_info in args.targets_info:
|
|
if target_info["type"] == "repository":
|
|
repo_url = target_info["details"]["target_repo"]
|
|
dest_name = target_info["details"].get("workspace_subdir")
|
|
cloned_path = clone_repository(repo_url, args.run_name, dest_name)
|
|
target_info["details"]["cloned_repo_path"] = cloned_path
|
|
|
|
args.local_sources = collect_local_sources(args.targets_info)
|
|
|
|
is_whitebox = bool(args.local_sources)
|
|
|
|
posthog.start(
|
|
model=Config.get("strix_llm"),
|
|
scan_mode=args.scan_mode,
|
|
is_whitebox=is_whitebox,
|
|
interactive=not args.non_interactive,
|
|
has_instructions=bool(args.instruction),
|
|
)
|
|
|
|
exit_reason = "user_exit"
|
|
try:
|
|
if args.non_interactive:
|
|
asyncio.run(run_cli(args))
|
|
else:
|
|
asyncio.run(run_tui(args))
|
|
except KeyboardInterrupt:
|
|
exit_reason = "interrupted"
|
|
except Exception as e:
|
|
exit_reason = "error"
|
|
posthog.error("unhandled_exception", str(e))
|
|
raise
|
|
finally:
|
|
tracer = get_global_tracer()
|
|
if tracer:
|
|
posthog.end(tracer, exit_reason=exit_reason)
|
|
|
|
results_path = Path("strix_runs") / args.run_name
|
|
display_completion_message(args, results_path)
|
|
|
|
if args.non_interactive:
|
|
tracer = get_global_tracer()
|
|
if tracer and tracer.vulnerability_reports:
|
|
sys.exit(2)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|