feat(interface): Introduce non-interactive CLI mode and restructure UI layer

This commit is contained in:
Ahmed Allam
2025-10-29 06:14:08 +03:00
committed by Ahmed Allam
parent 85209bfc20
commit 86dd6f5330
29 changed files with 254 additions and 46 deletions

View File

@@ -39,7 +39,7 @@ include = [
] ]
[tool.poetry.scripts] [tool.poetry.scripts]
strix = "strix.cli.main:main" strix = "strix.interface.main:main"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.12" python = "^3.12"

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from strix.cli.tracer import Tracer from strix.interface.tracer import Tracer
from jinja2 import ( from jinja2 import (
Environment, Environment,
@@ -55,6 +55,7 @@ class BaseAgent(metaclass=AgentMeta):
self.config = config self.config = config
self.local_source_path = config.get("local_source_path") self.local_source_path = config.get("local_source_path")
self.non_interactive = config.get("non_interactive", False)
if "max_iterations" in config: if "max_iterations" in config:
self.max_iterations = config["max_iterations"] self.max_iterations = config["max_iterations"]
@@ -76,7 +77,7 @@ class BaseAgent(metaclass=AgentMeta):
self._current_task: asyncio.Task[Any] | None = None self._current_task: asyncio.Task[Any] | None = None
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer:
@@ -146,10 +147,10 @@ class BaseAgent(metaclass=AgentMeta):
self._current_task.cancel() self._current_task.cancel()
self._current_task = None self._current_task = None
async def agent_loop(self, task: str) -> dict[str, Any]: async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
await self._initialize_sandbox_and_state(task) await self._initialize_sandbox_and_state(task)
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
@@ -161,6 +162,8 @@ class BaseAgent(metaclass=AgentMeta):
continue continue
if self.state.should_stop(): if self.state.should_stop():
if self.non_interactive:
return self.state.final_result or {}
await self._enter_waiting_state(tracer) await self._enter_waiting_state(tracer)
continue continue
@@ -173,10 +176,17 @@ class BaseAgent(metaclass=AgentMeta):
try: try:
should_finish = await self._process_iteration(tracer) should_finish = await self._process_iteration(tracer)
if should_finish: if should_finish:
if self.non_interactive:
self.state.set_completed({"success": True})
if tracer:
tracer.update_agent_status(self.state.agent_id, "completed")
return self.state.final_result or {}
await self._enter_waiting_state(tracer, task_completed=True) await self._enter_waiting_state(tracer, task_completed=True)
continue continue
except asyncio.CancelledError: except asyncio.CancelledError:
if self.non_interactive:
raise
await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True) await self._enter_waiting_state(tracer, error_occurred=False, was_cancelled=True)
continue continue
@@ -200,6 +210,11 @@ class BaseAgent(metaclass=AgentMeta):
except (RuntimeError, ValueError, TypeError) as e: except (RuntimeError, ValueError, TypeError) as e:
if not await self._handle_iteration_error(e, tracer): if not await self._handle_iteration_error(e, tracer):
if self.non_interactive:
self.state.set_completed({"success": False, "error": str(e)})
if tracer:
tracer.update_agent_status(self.state.agent_id, "failed")
raise
await self._enter_waiting_state(tracer, error_occurred=True) await self._enter_waiting_state(tracer, error_occurred=True)
continue continue
@@ -210,7 +225,7 @@ class BaseAgent(metaclass=AgentMeta):
self.state.resume_from_waiting() self.state.resume_from_waiting()
self.state.add_message("assistant", "Waiting timeout reached. Resuming execution.") self.state.add_message("assistant", "Waiting timeout reached. Resuming execution.")
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer:
@@ -353,6 +368,8 @@ class BaseAgent(metaclass=AgentMeta):
self.state.set_completed({"success": True}) self.state.set_completed({"success": True})
if tracer: if tracer:
tracer.update_agent_status(self.state.agent_id, "completed") tracer.update_agent_status(self.state.agent_id, "completed")
if self.non_interactive and self.state.parent_id is None:
return True
return True return True
return False return False
@@ -390,7 +407,7 @@ class BaseAgent(metaclass=AgentMeta):
state.resume_from_waiting() state.resume_from_waiting()
has_new_messages = True has_new_messages = True
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer:
@@ -399,7 +416,7 @@ class BaseAgent(metaclass=AgentMeta):
state.resume_from_waiting() state.resume_from_waiting()
has_new_messages = True has_new_messages = True
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer:
@@ -441,7 +458,7 @@ class BaseAgent(metaclass=AgentMeta):
message["read"] = True message["read"] = True
if has_new_messages and not state.is_waiting_for_input(): if has_new_messages and not state.is_waiting_for_input():
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer:

158
strix/interface/cli.py Normal file
View File

@@ -0,0 +1,158 @@
import atexit
import signal
import sys
from typing import Any
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from strix.agents.StrixAgent import StrixAgent
from strix.interface.tracer import Tracer, set_global_tracer
from strix.llm.config import LLMConfig
async def run_cli(args: Any) -> None: # noqa: PLR0915
console = Console()
start_text = Text()
start_text.append("🦉 ", style="bold white")
start_text.append("STRIX CYBERSECURITY AGENT", style="bold green")
target_value = next(iter(args.target_dict.values())) if args.target_dict else args.target
target_text = Text()
target_text.append("🎯 Target: ", style="bold cyan")
target_text.append(str(target_value), style="bold white")
instructions_text = Text()
if args.instruction:
instructions_text.append("📋 Instructions: ", style="bold cyan")
instructions_text.append(args.instruction, style="white")
startup_panel = Panel(
Text.assemble(
start_text,
"\n\n",
target_text,
"\n" if args.instruction else "",
instructions_text if args.instruction else "",
),
title="[bold green]🛡️ STRIX PENETRATION TEST INITIATED",
title_align="center",
border_style="green",
padding=(1, 2),
)
console.print("\n")
console.print(startup_panel)
console.print()
scan_config = {
"scan_id": args.run_name,
"scan_type": args.target_type,
"target": args.target_dict,
"user_instructions": args.instruction or "",
"run_name": args.run_name,
}
llm_config = LLMConfig()
agent_config = {
"llm_config": llm_config,
"max_iterations": 200,
"non_interactive": True,
}
if args.target_type == "local_code" and "target_path" in args.target_dict:
agent_config["local_source_path"] = args.target_dict["target_path"]
elif args.target_type == "repository" and "cloned_repo_path" in args.target_dict:
agent_config["local_source_path"] = args.target_dict["cloned_repo_path"]
tracer = Tracer(args.run_name)
tracer.set_scan_config(scan_config)
def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None:
severity_colors = {
"critical": "#dc2626",
"high": "#ea580c",
"medium": "#d97706",
"low": "#65a30d",
"info": "#0284c7",
}
severity_color = severity_colors.get(severity.lower(), "#6b7280")
vuln_text = Text()
vuln_text.append("🐞 ", style="bold red")
vuln_text.append("VULNERABILITY FOUND", style="bold red")
vuln_text.append("", style="dim white")
vuln_text.append(title, style="bold white")
severity_text = Text()
severity_text.append("Severity: ", style="dim white")
severity_text.append(severity.upper(), style=f"bold {severity_color}")
vuln_panel = Panel(
Text.assemble(
vuln_text,
"\n\n",
severity_text,
"\n\n",
content,
),
title=f"[bold red]🔍 {report_id.upper()}",
title_align="left",
border_style="red",
padding=(1, 2),
)
console.print(vuln_panel)
console.print()
tracer.vulnerability_found_callback = display_vulnerability
def cleanup_on_exit() -> None:
tracer.cleanup()
def signal_handler(_signum: int, _frame: Any) -> None:
console.print("\n[bold yellow]Interrupted! Saving reports...[/bold yellow]")
tracer.cleanup()
sys.exit(0)
atexit.register(cleanup_on_exit)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, signal_handler)
set_global_tracer(tracer)
try:
console.print()
with console.status("[bold cyan]Running penetration test...", spinner="dots") as status:
agent = StrixAgent(agent_config)
await agent.execute_scan(scan_config)
status.stop()
except Exception as e:
console.print(f"[bold red]Error during penetration test:[/] {e}")
raise
if tracer.final_scan_result:
console.print()
final_report_text = Text()
final_report_text.append("📄 ", style="bold cyan")
final_report_text.append("FINAL PENETRATION TEST REPORT", style="bold cyan")
final_report_panel = Panel(
Text.assemble(
final_report_text,
"\n\n",
tracer.final_scan_result,
),
title="[bold cyan]📊 PENETRATION TEST SUMMARY",
title_align="center",
border_style="cyan",
padding=(1, 2),
)
console.print(final_report_panel)
console.print()

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Strix Agent Command Line Interface Strix Agent Interface
""" """
import argparse import argparse
@@ -23,8 +23,9 @@ from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.text import Text from rich.text import Text
from strix.cli.app import run_strix_cli from strix.interface.cli import run_cli
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
from strix.interface.tui import run_tui
from strix.runtime.docker_runtime import STRIX_IMAGE from strix.runtime.docker_runtime import STRIX_IMAGE
@@ -381,11 +382,11 @@ def infer_target_type(target: str) -> tuple[str, dict[str, str]]:
def parse_arguments() -> argparse.Namespace: def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Strix Multi-Agent Cybersecurity Scanner", description="Strix Multi-Agent Cybersecurity Penetration Testing Tool",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" epilog="""
Examples: Examples:
# Web application scan # Web application penetration test
strix --target https://example.com strix --target https://example.com
# GitHub repository analysis # GitHub repository analysis
@@ -395,7 +396,7 @@ Examples:
# Local code analysis # Local code analysis
strix --target ./my-project strix --target ./my-project
# Domain scan # Domain penetration test
strix --target example.com strix --target example.com
# Custom instructions # Custom instructions
@@ -407,12 +408,12 @@ Examples:
"--target", "--target",
type=str, type=str,
required=True, required=True,
help="Target to scan (URL, repository, local directory path, or domain name)", help="Target to test (URL, repository, local directory path, or domain name)",
) )
parser.add_argument( parser.add_argument(
"--instruction", "--instruction",
type=str, type=str,
help="Custom instructions for the scan. This can be " help="Custom instructions for the penetration test. This can be "
"specific vulnerability types to focus on (e.g., 'Focus on IDOR and XSS'), " "specific vulnerability types to focus on (e.g., 'Focus on IDOR and XSS'), "
"testing approaches (e.g., 'Perform thorough authentication testing'), " "testing approaches (e.g., 'Perform thorough authentication testing'), "
"test credentials (e.g., 'Use the following credentials to access the app: " "test credentials (e.g., 'Use the following credentials to access the app: "
@@ -423,7 +424,17 @@ Examples:
parser.add_argument( parser.add_argument(
"--run-name", "--run-name",
type=str, type=str,
help="Custom name for this scan run", help="Custom name for this penetration test run",
)
parser.add_argument(
"-n",
"--non-interactive",
action="store_true",
help=(
"Run in non-interactive mode (no TUI, exits on completion). "
"Default is interactive mode with TUI."
),
) )
args = parser.parse_args() args = parser.parse_args()
@@ -540,7 +551,7 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) ->
completion_text.append("🦉 ", style="bold white") completion_text.append("🦉 ", style="bold white")
completion_text.append("AGENT FINISHED", style="bold green") completion_text.append("AGENT FINISHED", style="bold green")
completion_text.append("", style="dim white") completion_text.append("", style="dim white")
completion_text.append("Security assessment completed", style="white") completion_text.append("Penetration test completed", style="white")
stats_text = _build_stats_text(tracer) stats_text = _build_stats_text(tracer)
@@ -733,7 +744,10 @@ def main() -> None:
args.target_dict["cloned_repo_path"] = cloned_path args.target_dict["cloned_repo_path"] = cloned_path
asyncio.run(run_strix_cli(args)) if args.non_interactive:
asyncio.run(run_cli(args))
else:
asyncio.run(run_tui(args))
results_path = Path("agent_runs") / args.run_name results_path = Path("agent_runs") / args.run_name
display_completion_message(args, results_path) display_completion_message(args, results_path)

View File

@@ -1,10 +1,14 @@
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import TYPE_CHECKING, Any, Optional
from uuid import uuid4 from uuid import uuid4
if TYPE_CHECKING:
from collections.abc import Callable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_global_tracer: Optional["Tracer"] = None _global_tracer: Optional["Tracer"] = None
@@ -48,6 +52,8 @@ class Tracer:
self._next_execution_id = 1 self._next_execution_id = 1
self._next_message_id = 1 self._next_message_id = 1
self.vulnerability_found_callback: Callable[[str, str, str, str], None] | None = None
def set_run_name(self, run_name: str) -> None: def set_run_name(self, run_name: str) -> None:
self.run_name = run_name self.run_name = run_name
self.run_id = run_name self.run_id = run_name
@@ -81,6 +87,12 @@ class Tracer:
self.vulnerability_reports.append(report) self.vulnerability_reports.append(report)
logger.info(f"Added vulnerability report: {report_id} - {title}") logger.info(f"Added vulnerability report: {report_id} - {title}")
if self.vulnerability_found_callback:
self.vulnerability_found_callback(
report_id, title.strip(), content.strip(), severity.lower().strip()
)
return report_id return report_id
def set_final_scan_result( def set_final_scan_result(
@@ -194,14 +206,16 @@ class Tracer:
self.end_time = datetime.now(UTC).isoformat() self.end_time = datetime.now(UTC).isoformat()
if self.final_scan_result: if self.final_scan_result:
scan_report_file = run_dir / "scan_report.md" penetration_test_report_file = run_dir / "penetration_test_report.md"
with scan_report_file.open("w", encoding="utf-8") as f: with penetration_test_report_file.open("w", encoding="utf-8") as f:
f.write("# Security Scan Report\n\n") f.write("# Security Penetration Test Report\n\n")
f.write( f.write(
f"**Generated:** {datetime.now(UTC).strftime('%Y-%m-%d %H:%M:%S UTC')}\n\n" f"**Generated:** {datetime.now(UTC).strftime('%Y-%m-%d %H:%M:%S UTC')}\n\n"
) )
f.write(f"{self.final_scan_result}\n") f.write(f"{self.final_scan_result}\n")
logger.info(f"Saved final scan report to: {scan_report_file}") logger.info(
f"Saved final penetration test report to: {penetration_test_report_file}"
)
if self.vulnerability_reports: if self.vulnerability_reports:
vuln_dir = run_dir / "vulnerabilities" vuln_dir = run_dir / "vulnerabilities"

View File

@@ -31,7 +31,7 @@ from textual.widgets import Button, Label, Static, TextArea, Tree
from textual.widgets.tree import TreeNode from textual.widgets.tree import TreeNode
from strix.agents.StrixAgent import StrixAgent from strix.agents.StrixAgent import StrixAgent
from strix.cli.tracer import Tracer, set_global_tracer from strix.interface.tracer import Tracer, set_global_tracer
from strix.llm.config import LLMConfig from strix.llm.config import LLMConfig
@@ -49,9 +49,9 @@ def get_package_version() -> str:
class ChatTextArea(TextArea): # type: ignore[misc] class ChatTextArea(TextArea): # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._app_reference: StrixCLIApp | None = None self._app_reference: StrixTUIApp | None = None
def set_app_reference(self, app: "StrixCLIApp") -> None: def set_app_reference(self, app: "StrixTUIApp") -> None:
self._app_reference = app self._app_reference = app
def _on_key(self, event: events.Key) -> None: def _on_key(self, event: events.Key) -> None:
@@ -260,8 +260,8 @@ class QuitScreen(ModalScreen): # type: ignore[misc]
self.app.pop_screen() self.app.pop_screen()
class StrixCLIApp(App): # type: ignore[misc] class StrixTUIApp(App): # type: ignore[misc]
CSS_PATH = "assets/cli.tcss" CSS_PATH = "assets/tui_styles.tcss"
selected_agent_id: reactive[str | None] = reactive(default=None) selected_agent_id: reactive[str | None] = reactive(default=None)
show_splash: reactive[bool] = reactive(default=True) show_splash: reactive[bool] = reactive(default=True)
@@ -962,7 +962,7 @@ class StrixCLIApp(App): # type: ignore[misc]
return "" return ""
if role == "user": if role == "user":
from strix.cli.tool_components.user_message_renderer import UserMessageRenderer from strix.interface.tool_components.user_message_renderer import UserMessageRenderer
return UserMessageRenderer.render_simple(content) return UserMessageRenderer.render_simple(content)
return content return content
@@ -992,7 +992,7 @@ class StrixCLIApp(App): # type: ignore[misc]
color = tool_colors.get(tool_name, "#737373") color = tool_colors.get(tool_name, "#737373")
from strix.cli.tool_components.registry import get_tool_renderer from strix.interface.tool_components.registry import get_tool_renderer
renderer = get_tool_renderer(tool_name) renderer = get_tool_renderer(tool_name)
@@ -1237,6 +1237,7 @@ class StrixCLIApp(App): # type: ignore[misc]
widget.update(plain_text) widget.update(plain_text)
async def run_strix_cli(args: argparse.Namespace) -> None: async def run_tui(args: argparse.Namespace) -> None:
app = StrixCLIApp(args) """Run strix in interactive TUI mode with textual."""
app = StrixTUIApp(args)
await app.run_async() await app.run_async()

View File

@@ -40,7 +40,7 @@ class DockerRuntime(AbstractRuntime):
def _get_scan_id(self, agent_id: str) -> str: def _get_scan_id(self, agent_id: str) -> str:
try: try:
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer and tracer.scan_config: if tracer and tracer.scan_config:

View File

@@ -231,12 +231,16 @@ def create_agent(
state = AgentState(task=task, agent_name=name, parent_id=parent_id, max_iterations=200) state = AgentState(task=task, agent_name=name, parent_id=parent_id, max_iterations=200)
llm_config = LLMConfig(prompt_modules=module_list) llm_config = LLMConfig(prompt_modules=module_list)
agent = StrixAgent(
{ parent_agent = _agent_instances.get(parent_id)
agent_config = {
"llm_config": llm_config, "llm_config": llm_config,
"state": state, "state": state,
} }
) if parent_agent and hasattr(parent_agent, "non_interactive"):
agent_config["non_interactive"] = parent_agent.non_interactive
agent = StrixAgent(agent_config)
inherited_messages = [] inherited_messages = []
if inherit_context: if inherit_context:
@@ -487,7 +491,7 @@ def stop_agent(agent_id: str) -> dict[str, Any]:
agent_node["status"] = "stopping" agent_node["status"] = "stopping"
try: try:
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer:
@@ -578,7 +582,7 @@ def wait_for_message(
_agent_graph["nodes"][agent_id]["waiting_reason"] = reason _agent_graph["nodes"][agent_id]["waiting_reason"] = reason
try: try:
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer:

View File

@@ -240,7 +240,7 @@ async def _execute_single_tool(
def _get_tracer_and_agent_id(agent_state: Any | None) -> tuple[Any | None, str]: def _get_tracer_and_agent_id(agent_state: Any | None) -> tuple[Any | None, str]:
try: try:
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
agent_id = agent_state.agent_id if agent_state else "unknown_agent" agent_id = agent_state.agent_id if agent_state else "unknown_agent"

View File

@@ -107,7 +107,7 @@ def _check_active_agents(agent_state: Any = None) -> dict[str, Any] | None:
def _finalize_with_tracer(content: str, success: bool) -> dict[str, Any]: def _finalize_with_tracer(content: str, success: bool) -> dict[str, Any]:
try: try:
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer:

View File

@@ -27,7 +27,7 @@ def create_vulnerability_report(
return {"success": False, "message": validation_error} return {"success": False, "message": validation_error}
try: try:
from strix.cli.tracer import get_global_tracer from strix.interface.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer: if tracer: