feat(interface): Introduce non-interactive CLI mode and restructure UI layer

This commit is contained in:
Ahmed Allam
2025-10-29 06:14:08 +03:00
committed by Ahmed Allam
parent 85209bfc20
commit 86dd6f5330
29 changed files with 254 additions and 46 deletions

View File

@@ -39,7 +39,7 @@ include = [
]
[tool.poetry.scripts]
strix = "strix.cli.main:main"
strix = "strix.interface.main:main"
[tool.poetry.dependencies]
python = "^3.12"

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from strix.cli.tracer import Tracer
from strix.interface.tracer import Tracer
from jinja2 import (
Environment,
@@ -55,6 +55,7 @@ class BaseAgent(metaclass=AgentMeta):
self.config = config
self.local_source_path = config.get("local_source_path")
self.non_interactive = config.get("non_interactive", False)
if "max_iterations" in config:
self.max_iterations = config["max_iterations"]
@@ -76,7 +77,7 @@ class BaseAgent(metaclass=AgentMeta):
self._current_task: asyncio.Task[Any] | None = None
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
@@ -146,10 +147,10 @@ class BaseAgent(metaclass=AgentMeta):
self._current_task.cancel()
self._current_task = None
async def agent_loop(self, task: str) -> dict[str, Any]:
async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
await self._initialize_sandbox_and_state(task)
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
@@ -161,6 +162,8 @@ class BaseAgent(metaclass=AgentMeta):
continue
if self.state.should_stop():
if self.non_interactive:
return self.state.final_result or {}
await self._enter_waiting_state(tracer)
continue
@@ -173,10 +176,17 @@ class BaseAgent(metaclass=AgentMeta):
try:
should_finish = await self._process_iteration(tracer)
if should_finish:
if self.non_interactive:
self.state.set_completed({"success": True})
if tracer:
tracer.update_agent_status(self.state.agent_id, "completed")
return self.state.final_result or {}
await self._enter_waiting_state(tracer, task_completed=True)
continue
except asyncio.CancelledError:
if self.non_interactive:
raise
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
continue
@@ -200,6 +210,11 @@ class BaseAgent(metaclass=AgentMeta):
except (RuntimeError, ValueError, TypeError) as e:
if not await self._handle_iteration_error(e, tracer):
if self.non_interactive:
self.state.set_completed({"success": False, "error": str(e)})
if tracer:
tracer.update_agent_status(self.state.agent_id, "failed")
raise
await self._enter_waiting_state(tracer, error_occurred=True)
continue
@@ -210,7 +225,7 @@ class BaseAgent(metaclass=AgentMeta):
self.state.resume_from_waiting()
self.state.add_message("assistant", "Waiting timeout reached. Resuming execution.")
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
@@ -353,6 +368,8 @@ class BaseAgent(metaclass=AgentMeta):
self.state.set_completed({"success": True})
if tracer:
tracer.update_agent_status(self.state.agent_id, "completed")
if self.non_interactive and self.state.parent_id is None:
return True
return True
return False
@@ -390,7 +407,7 @@ class BaseAgent(metaclass=AgentMeta):
state.resume_from_waiting()
has_new_messages = True
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
@@ -399,7 +416,7 @@ class BaseAgent(metaclass=AgentMeta):
state.resume_from_waiting()
has_new_messages = True
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
@@ -441,7 +458,7 @@ class BaseAgent(metaclass=AgentMeta):
message["read"] = True
if has_new_messages and not state.is_waiting_for_input():
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:

158
strix/interface/cli.py Normal file
View File

@@ -0,0 +1,158 @@
import atexit
import signal
import sys
from typing import Any
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from strix.agents.StrixAgent import StrixAgent
from strix.interface.tracer import Tracer, set_global_tracer
from strix.llm.config import LLMConfig
async def run_cli(args: Any) -> None: # noqa: PLR0915
console = Console()
start_text = Text()
start_text.append("🦉 ", style="bold white")
start_text.append("STRIX CYBERSECURITY AGENT", style="bold green")
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
target_text = Text()
target_text.append("🎯 Target: ", style="bold cyan")
target_text.append(str(target_value), style="bold white")
instructions_text = Text()
if args.instruction:
instructions_text.append("📋 Instructions: ", style="bold cyan")
instructions_text.append(args.instruction, style="white")
startup_panel = Panel(
Text.assemble(
start_text,
"\n\n",
target_text,
"\n" if args.instruction else "",
instructions_text if args.instruction else "",
),
title="[bold green]🛡️ STRIX PENETRATION TEST INITIATED",
title_align="center",
border_style="green",
padding=(1, 2),
)
console.print("\n")
console.print(startup_panel)
console.print()
scan_config = {
"scan_id": args.run_name,
"scan_type": args.target_type,
"target": args.target_dict,
"user_instructions": args.instruction or "",
"run_name": args.run_name,
}
llm_config = LLMConfig()
agent_config = {
"llm_config": llm_config,
"max_iterations": 200,
"non_interactive": True,
}
if args.target_type == "local_code" and "target_path" in args.target_dict:
agent_config["local_source_path"] = args.target_dict["target_path"]
elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict:
agent_config["local_source_path"] = args.target_dict["cloned_repo_path"]
tracer = Tracer(args.run_name)
tracer.set_scan_config(scan_config)
def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None:
severity_colors = {
"critical": "#dc2626",
"high": "#ea580c",
"medium": "#d97706",
"low": "#65a30d",
"info": "#0284c7",
}
severity_color = severity_colors.get(severity.lower(), "#6b7280")
vuln_text = Text()
vuln_text.append("🐞 ", style="bold red")
vuln_text.append("VULNERABILITY FOUND", style="bold red")
vuln_text.append("", style="dim white")
vuln_text.append(title, style="bold white")
severity_text = Text()
severity_text.append("Severity: ", style="dim white")
severity_text.append(severity.upper(), style=f"bold {severity_color}")
vuln_panel = Panel(
Text.assemble(
vuln_text,
"\n\n",
severity_text,
"\n\n",
content,
),
title=f"[bold red]🔍 {report_id.upper()}",
title_align="left",
border_style="red",
padding=(1, 2),
)
console.print(vuln_panel)
console.print()
tracer.vulnerability_found_callback = display_vulnerability
def cleanup_on_exit() -> None:
tracer.cleanup()
def signal_handler(_signum: int, _frame: Any) -> None:
console.print("\n[bold yellow]Interrupted! Saving reports...[/bold yellow]")
tracer.cleanup()
sys.exit(0)
atexit.register(cleanup_on_exit)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, signal_handler)
set_global_tracer(tracer)
try:
console.print()
with console.status("[bold cyan]Running penetration test...", spinner="dots") as status:
agent = StrixAgent(agent_config)
await agent.execute_scan(scan_config)
status.stop()
except Exception as e:
console.print(f"[bold red]Error during penetration test:[/] {e}")
raise
if tracer.final_scan_result:
console.print()
final_report_text = Text()
final_report_text.append("📄 ", style="bold cyan")
final_report_text.append("FINAL PENETRATION TEST REPORT", style="bold cyan")
final_report_panel = Panel(
Text.assemble(
final_report_text,
"\n\n",
tracer.final_scan_result,
),
title="[bold cyan]📊 PENETRATION TEST SUMMARY",
title_align="center",
border_style="cyan",
padding=(1, 2),
)
console.print(final_report_panel)
console.print()

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
"""
Strix Agent Command Line Interface
Strix Agent Interface
"""
import argparse
@@ -23,8 +23,9 @@ from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from strix.cli.app import run_strix_cli
from strix.cli.tracer import get_global_tracer
from strix.interface.cli import run_cli
from strix.interface.tracer import get_global_tracer
from strix.interface.tui import run_tui
from strix.runtime.docker_runtime import STRIX_IMAGE
@@ -381,11 +382,11 @@ def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Strix Multi-Agent Cybersecurity Scanner",
description="Strix Multi-Agent Cybersecurity Penetration Testing Tool",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Web application scan
# Web application penetration test
strix --target https://example.com
# GitHub repository analysis
@@ -395,7 +396,7 @@ Examples:
# Local code analysis
strix --target ./my-project
# Domain scan
# Domain penetration test
strix --target example.com
# Custom instructions
@@ -407,12 +408,12 @@ Examples:
"--target",
type=str,
required=True,
help="Target to scan (URL, repository, local directory path, or domain name)",
help="Target to test (URL, repository, local directory path, or domain name)",
)
parser.add_argument(
"--instruction",
type=str,
help="Custom instructions for the scan. This can be "
help="Custom instructions for the penetration test. This can be "
"specific vulnerability types to focus on (e.g., 'Focus on IDOR and XSS'), "
"testing approaches (e.g., 'Perform thorough authentication testing'), "
"test credentials (e.g., 'Use the following credentials to access the app: "
@@ -423,7 +424,17 @@ Examples:
parser.add_argument(
"--run-name",
type=str,
help="Custom name for this scan run",
help="Custom name for this penetration test run",
)
parser.add_argument(
"-n",
"--non-interactive",
action="store_true",
help=(
"Run in non-interactive mode (no TUI, exits on completion). "
"Default is interactive mode with TUI."
),
)
args = parser.parse_args()
@@ -540,7 +551,7 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
completion_text.append("🦉 ", style="bold white")
completion_text.append("AGENT FINISHED", style="bold green")
completion_text.append("", style="dim white")
completion_text.append("Security assessment completed", style="white")
completion_text.append("Penetration test completed", style="white")
stats_text = _build_stats_text(tracer)
@@ -733,7 +744,10 @@ def main() -> None:
args.target_dict["cloned_repo_path"] = cloned_path
asyncio.run(run_strix_cli(args))
if args.non_interactive:
asyncio.run(run_cli(args))
else:
asyncio.run(run_tui(args))
results_path = Path("agent_runs") / args.run_name
display_completion_message(args, results_path)

View File

@@ -1,10 +1,14 @@
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from uuid import uuid4
if TYPE_CHECKING:
from collections.abc import Callable
logger = logging.getLogger(__name__)
_global_tracer: Optional["Tracer"] = None
@@ -48,6 +52,8 @@ class Tracer:
self._next_execution_id = 1
self._next_message_id = 1
self.vulnerability_found_callback: Callable[[str, str, str, str], None] | None = None
def set_run_name(self, run_name: str) -> None:
self.run_name = run_name
self.run_id = run_name
@@ -81,6 +87,12 @@ class Tracer:
self.vulnerability_reports.append(report)
logger.info(f"Added vulnerability report: {report_id} - {title}")
if self.vulnerability_found_callback:
self.vulnerability_found_callback(
report_id, title.strip(), content.strip(), severity.lower().strip()
)
return report_id
def set_final_scan_result(
@@ -194,14 +206,16 @@ class Tracer:
self.end_time = datetime.now(UTC).isoformat()
if self.final_scan_result:
scan_report_file = run_dir / "scan_report.md"
with scan_report_file.open("w", encoding="utf-8") as f:
f.write("# Security Scan Report\n\n")
penetration_test_report_file = run_dir / "penetration_test_report.md"
with penetration_test_report_file.open("w", encoding="utf-8") as f:
f.write("# Security Penetration Test Report\n\n")
f.write(
f"**Generated:** {datetime.now(UTC).strftime('%Y-%m-%d %H:%M:%S UTC')}\n\n"
)
f.write(f"{self.final_scan_result}\n")
logger.info(f"Saved final scan report to: {scan_report_file}")
logger.info(
f"Saved final penetration test report to: {penetration_test_report_file}"
)
if self.vulnerability_reports:
vuln_dir = run_dir / "vulnerabilities"

View File

@@ -31,7 +31,7 @@ from textual.widgets import Button, Label, Static, TextArea, Tree
from textual.widgets.tree import TreeNode
from strix.agents.StrixAgent import StrixAgent
from strix.cli.tracer import Tracer, set_global_tracer
from strix.interface.tracer import Tracer, set_global_tracer
from strix.llm.config import LLMConfig
@@ -49,9 +49,9 @@ def get_package_version() -> str:
class ChatTextArea(TextArea): # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._app_reference: StrixCLIApp | None = None
self._app_reference: StrixTUIApp | None = None
def set_app_reference(self, app: "StrixCLIApp") -> None:
def set_app_reference(self, app: "StrixTUIApp") -> None:
self._app_reference = app
def _on_key(self, event: events.Key) -> None:
@@ -260,8 +260,8 @@ class QuitScreen(ModalScreen): # type: ignore[misc]
self.app.pop_screen()
class StrixCLIApp(App): # type: ignore[misc]
CSS_PATH = "assets/cli.tcss"
class StrixTUIApp(App): # type: ignore[misc]
CSS_PATH = "assets/tui_styles.tcss"
selected_agent_id: reactive[str | None] = reactive(default=None)
show_splash: reactive[bool] = reactive(default=True)
@@ -962,7 +962,7 @@ class StrixCLIApp(App): # type: ignore[misc]
return ""
if role == "user":
from strix.cli.tool_components.user_message_renderer import UserMessageRenderer
from strix.interface.tool_components.user_message_renderer import UserMessageRenderer
return UserMessageRenderer.render_simple(content)
return content
@@ -992,7 +992,7 @@ class StrixCLIApp(App): # type: ignore[misc]
color = tool_colors.get(tool_name, "#737373")
from strix.cli.tool_components.registry import get_tool_renderer
from strix.interface.tool_components.registry import get_tool_renderer
renderer = get_tool_renderer(tool_name)
@@ -1237,6 +1237,7 @@ class StrixCLIApp(App): # type: ignore[misc]
widget.update(plain_text)
async def run_strix_cli(args: argparse.Namespace) -> None:
app = StrixCLIApp(args)
async def run_tui(args: argparse.Namespace) -> None:
"""Run strix in interactive TUI mode with textual."""
app = StrixTUIApp(args)
await app.run_async()

View File

@@ -40,7 +40,7 @@ class DockerRuntime(AbstractRuntime):
def _get_scan_id(self, agent_id: str) -> str:
try:
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer and tracer.scan_config:

View File

@@ -231,12 +231,16 @@ def create_agent(
state = AgentState(task=task, agent_name=name, parent_id=parent_id, max_iterations=200)
llm_config = LLMConfig(prompt_modules=module_list)
agent = StrixAgent(
{
"llm_config": llm_config,
"state": state,
}
)
parent_agent = _agent_instances.get(parent_id)
agent_config = {
"llm_config": llm_config,
"state": state,
}
if parent_agent and hasattr(parent_agent, "non_interactive"):
agent_config["non_interactive"] = parent_agent.non_interactive
agent = StrixAgent(agent_config)
inherited_messages = []
if inherit_context:
@@ -487,7 +491,7 @@ def stop_agent(agent_id: str) -> dict[str, Any]:
agent_node["status"] = "stopping"
try:
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:
@@ -578,7 +582,7 @@ def wait_for_message(
_agent_graph["nodes"][agent_id]["waiting_reason"] = reason
try:
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:

View File

@@ -240,7 +240,7 @@ async def _execute_single_tool(
def _get_tracer_and_agent_id(agent_state: Any | None) -> tuple[Any | None, str]:
try:
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
agent_id = agent_state.agent_id if agent_state else "unknown_agent"

View File

@@ -107,7 +107,7 @@ def _check_active_agents(agent_state: Any = None) -> dict[str, Any] | None:
def _finalize_with_tracer(content: str, success: bool) -> dict[str, Any]:
try:
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer:

View File

@@ -27,7 +27,7 @@ def create_vulnerability_report(
return {"success": False, "message": validation_error}
try:
from strix.cli.tracer import get_global_tracer
from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer: