diff --git a/containers/Dockerfile b/containers/Dockerfile index b8cbdeb..d11d819 100644 --- a/containers/Dockerfile +++ b/containers/Dockerfile @@ -153,7 +153,7 @@ ENV PYTHONPATH=/app ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt -RUN mkdir -p /shared_workspace /workspace && chown -R pentester:pentester /shared_workspace /workspace /app +RUN mkdir -p /workspace && chown -R pentester:pentester /workspace /app COPY pyproject.toml poetry.lock ./ diff --git a/containers/docker-entrypoint.sh b/containers/docker-entrypoint.sh index 10eea7a..c4178cd 100644 --- a/containers/docker-entrypoint.sh +++ b/containers/docker-entrypoint.sh @@ -1,8 +1,8 @@ #!/bin/bash set -e -if [ -z "$CAIDO_PORT" ] || [ -z "$STRIX_TOOL_SERVER_PORT" ]; then - echo "Error: CAIDO_PORT and STRIX_TOOL_SERVER_PORT must be set." +if [ -z "$CAIDO_PORT" ]; then + echo "Error: CAIDO_PORT must be set." exit 1 fi @@ -114,14 +114,8 @@ sudo -u pentester certutil -N -d sql:/home/pentester/.pki/nssdb --empty-password sudo -u pentester certutil -A -n "Testing Root CA" -t "C,," -i /app/certs/ca.crt -d sql:/home/pentester/.pki/nssdb echo "✅ CA added to browser trust store" -echo "Starting tool server..." -cd /app && \ -STRIX_SANDBOX_MODE=true \ -STRIX_SANDBOX_TOKEN=${STRIX_SANDBOX_TOKEN} \ -CAIDO_API_TOKEN=${TOKEN} \ -poetry run uvicorn strix.runtime.tool_server:app --host 0.0.0.0 --port ${STRIX_TOOL_SERVER_PORT} & - -echo "✅ Tool server started in background" +echo "Container initialization complete - agents will start their own tool servers as needed" +echo "✅ Shared container ready for multi-agent use" cd /workspace diff --git a/pyproject.toml b/pyproject.toml index 2de982a..ddb65e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "strix-agent" -version = "0.1.9" +version = "0.1.10" description = "Open-source AI Hackers for your apps" authors = ["Strix "] readme = "README.md" diff --git a/strix/agents/StrixAgent/strix_agent.py b/strix/agents/StrixAgent/strix_agent.py index 2616091..f83d857 100644 --- a/strix/agents/StrixAgent/strix_agent.py +++ b/strix/agents/StrixAgent/strix_agent.py @@ -30,12 +30,12 @@ class StrixAgent(BaseAgent): cloned_path = target.get("cloned_repo_path") if cloned_path: - shared_workspace_path = "/shared_workspace" + workspace_path = "/workspace" task_parts.append( f"Perform a security assessment of the Git repository: {repo_url}. " f"The repository has been cloned from '{repo_url}' to '{cloned_path}' " - f"(host path) and then copied to '{shared_workspace_path}' in your environment." - f"Analyze the codebase at: {shared_workspace_path}" + f"(host path) and then copied to '{workspace_path}' in your environment." + f"Analyze the codebase at: {workspace_path}" ) else: task_parts.append( @@ -49,12 +49,12 @@ class StrixAgent(BaseAgent): elif scan_type == "local_code": original_path = target.get("target_path", "unknown") - shared_workspace_path = "/shared_workspace" + workspace_path = "/workspace" task_parts.append( f"Perform a security assessment of the local codebase. " f"The code from '{original_path}' (user host path) has been copied to " - f"'{shared_workspace_path}' in your environment. " - f"Analyze the codebase at: {shared_workspace_path}" + f"'{workspace_path}' in your environment. " + f"Analyze the codebase at: {workspace_path}" ) else: diff --git a/strix/agents/StrixAgent/system_prompt.jinja b/strix/agents/StrixAgent/system_prompt.jinja index 3fa206b..3edf209 100644 --- a/strix/agents/StrixAgent/system_prompt.jinja +++ b/strix/agents/StrixAgent/system_prompt.jinja @@ -145,11 +145,10 @@ Remember: A single high-impact vulnerability is worth more than dozens of low-se AGENT ISOLATION & SANDBOXING: -- Each subagent runs in a completely isolated sandbox environment -- Each agent has its own: browser sessions, terminal sessions, proxy (history and scope rules), /workspace directory, environment variables, running processes -- Agents cannot share network ports or interfere with each other's processes -- Only shared resource is /shared_workspace for collaboration and file exchange -- Use /shared_workspace to pass files, reports, and coordination data between agents +- All agents run in the same shared Docker container for efficiency +- Each agent has its own: browser sessions, terminal sessions +- All agents share the same /workspace directory and proxy history +- Agents can see each other's files and proxy traffic for better collaboration SIMPLE WORKFLOW RULES: @@ -312,8 +311,7 @@ PROGRAMMING: - You can install any additional tools/packages needed based on the task/context using package managers (apt, pip, npm, go install, etc.) Directories: -- /workspace - Your private agent directory -- /shared_workspace - Shared between agents +- /workspace - where you should work. - /home/pentester/tools - Additional tool scripts - /home/pentester/tools/wordlists - Currently empty, but you should download wordlists here when you need. diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 07d2cd1..e769788 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -239,6 +239,9 @@ class BaseAgent(metaclass=AgentMeta): self.state.sandbox_token = sandbox_info["auth_token"] self.state.sandbox_info = sandbox_info + if "agent_id" in sandbox_info: + self.state.sandbox_info["agent_id"] = sandbox_info["agent_id"] + if not self.state.task: self.state.task = task diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index 355471d..4b514ec 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -7,19 +7,15 @@ from pathlib import Path from typing import cast import docker -from docker.errors import DockerException, NotFound +from docker.errors import DockerException, ImageNotFound, NotFound from docker.models.containers import Container from .runtime import AbstractRuntime, SandboxInfo -STRIX_AGENT_LABEL = "StrixAgent_ID" -STRIX_SCAN_LABEL = "StrixScan_ID" -STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.4") +STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10") logger = logging.getLogger(__name__) -_initialized_volumes: set[str] = set() - class DockerRuntime(AbstractRuntime): def __init__(self) -> None: @@ -29,9 +25,18 @@ class DockerRuntime(AbstractRuntime): logger.exception("Failed to connect to Docker daemon") raise RuntimeError("Docker is not available or not configured correctly.") from e + self._scan_container: Container | None = None + self._tool_server_port: int | None = None + self._tool_server_token: str | None = None + def _generate_sandbox_token(self) -> str: return secrets.token_urlsafe(32) + def _find_available_port(self) -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return cast("int", s.getsockname()[1]) + def _get_scan_id(self, agent_id: str) -> str: try: from strix.cli.tracer import get_global_tracer @@ -46,37 +51,151 @@ class DockerRuntime(AbstractRuntime): return f"scan-{agent_id.split('-')[0]}" - def _find_available_port(self) -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return cast("int", s.getsockname()[1]) + def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None: + def _validate_image(image: docker.models.images.Image) -> None: + if not image.id or not image.attrs: + raise ImageNotFound(f"Image {image_name} metadata incomplete") - def _get_workspace_volume_name(self, scan_id: str) -> str: - return f"strix-workspace-{scan_id}" + for attempt in range(max_retries): + try: + image = self.client.images.get(image_name) + _validate_image(image) + except ImageNotFound: + if attempt == max_retries - 1: + logger.exception(f"Image {image_name} not found after {max_retries} attempts") + raise + logger.warning(f"Image {image_name} not ready, attempt {attempt + 1}/{max_retries}") + time.sleep(2**attempt) + except DockerException: + if attempt == max_retries - 1: + logger.exception(f"Failed to verify image {image_name}") + raise + logger.warning(f"Docker error verifying image, attempt {attempt + 1}/{max_retries}") + time.sleep(2**attempt) + else: + logger.debug(f"Image {image_name} verified as available") + return - def _get_sandbox_by_agent_id(self, agent_id: str) -> Container | None: - try: - containers = self.client.containers.list( - filters={"label": f"{STRIX_AGENT_LABEL}={agent_id}"} - ) - if not containers: - return None - if len(containers) > 1: - logger.warning( - "Multiple sandboxes found for agent ID %s, using the first one.", agent_id + def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Container: + last_exception = None + + for attempt in range(max_retries): + try: + self._verify_image_available(STRIX_IMAGE) + + caido_port = self._find_available_port() + tool_server_port = self._find_available_port() + tool_server_token = self._generate_sandbox_token() + + self._tool_server_port = tool_server_port + self._tool_server_token = tool_server_token + + container = self.client.containers.run( + STRIX_IMAGE, + command="sleep infinity", + detach=True, + name=f"strix-scan-{scan_id}", + hostname=f"strix-scan-{scan_id}", + ports={ + f"{caido_port}/tcp": caido_port, + f"{tool_server_port}/tcp": tool_server_port, + }, + cap_add=["NET_ADMIN", "NET_RAW"], + labels={"strix-scan-id": scan_id}, + environment={ + "PYTHONUNBUFFERED": "1", + "CAIDO_PORT": str(caido_port), + "TOOL_SERVER_PORT": str(tool_server_port), + "TOOL_SERVER_TOKEN": tool_server_token, + }, + tty=True, ) - return cast("Container", containers[0]) - except DockerException as e: - logger.warning("Failed to get sandbox by agent ID %s: %s", agent_id, e) - return None - def _ensure_workspace_volume(self, volume_name: str) -> None: + self._scan_container = container + logger.info("Created container %s for scan %s", container.id, scan_id) + + self._initialize_container( + container, caido_port, tool_server_port, tool_server_token + ) + except DockerException as e: + last_exception = e + if attempt == max_retries - 1: + logger.exception(f"Failed to create container after {max_retries} attempts") + break + + logger.warning(f"Container creation attempt {attempt + 1}/{max_retries} failed") + + self._tool_server_port = None + self._tool_server_token = None + + sleep_time = (2**attempt) + (0.1 * attempt) + time.sleep(sleep_time) + else: + return container + + raise RuntimeError( + f"Failed to create Docker container after {max_retries} attempts: {last_exception}" + ) from last_exception + + def _get_or_create_scan_container(self, scan_id: str) -> Container: + if self._scan_container: + try: + self._scan_container.reload() + if self._scan_container.status == "running": + return self._scan_container + except NotFound: + self._scan_container = None + self._tool_server_port = None + self._tool_server_token = None + try: - self.client.volumes.get(volume_name) - logger.info(f"Using existing workspace volume: {volume_name}") - except NotFound: - self.client.volumes.create(name=volume_name, driver="local") - logger.info(f"Created new workspace volume: {volume_name}") + containers = self.client.containers.list(filters={"label": f"strix-scan-id={scan_id}"}) + if containers: + container = cast("Container", containers[0]) + if container.status != "running": + container.start() + time.sleep(2) + self._scan_container = container + + for env_var in container.attrs["Config"]["Env"]: + if env_var.startswith("TOOL_SERVER_PORT="): + self._tool_server_port = int(env_var.split("=")[1]) + elif env_var.startswith("TOOL_SERVER_TOKEN="): + self._tool_server_token = env_var.split("=")[1] + + return container + except DockerException as e: + logger.warning("Failed to find existing container for scan %s: %s", scan_id, e) + + logger.info("Creating new Docker container for scan %s", scan_id) + return self._create_container_with_retry(scan_id) + + def _initialize_container( + self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str + ) -> None: + logger.info("Initializing Caido proxy on port %s", caido_port) + result = container.exec_run( + f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'", + detach=False, + ) + + time.sleep(5) + + result = container.exec_run( + "bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester" + ) + caido_token = result.output.decode().strip() if result.exit_code == 0 else "" + + container.exec_run( + f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && " + f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} " + f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} " + f"--host 0.0.0.0 --port {tool_server_port} &'", + detach=True, + user="pentester", + ) + + time.sleep(5) def _copy_local_directory_to_container(self, container: Container, local_path: str) -> None: import tarfile @@ -85,10 +204,10 @@ class DockerRuntime(AbstractRuntime): try: local_path_obj = Path(local_path).resolve() if not local_path_obj.exists() or not local_path_obj.is_dir(): - logger.warning(f"Local path does not exist or is not a directory: {local_path_obj}") + logger.warning(f"Local path does not exist or is not directory: {local_path_obj}") return - logger.info(f"Copying local directory {local_path_obj} to container {container.id}") + logger.info(f"Copying local directory {local_path_obj} to container") tar_buffer = BytesIO() with tarfile.open(fileobj=tar_buffer, mode="w") as tar: @@ -98,18 +217,14 @@ class DockerRuntime(AbstractRuntime): tar.add(item, arcname=arcname) tar_buffer.seek(0) - - container.put_archive("/shared_workspace", tar_buffer.getvalue()) + container.put_archive("/workspace", tar_buffer.getvalue()) container.exec_run( - "chown -R pentester:pentester /shared_workspace && chmod -R 755 /shared_workspace", + "chown -R pentester:pentester /workspace && chmod -R 755 /workspace", user="root", ) - logger.info( - f"Successfully copied {local_path_obj} to /shared_workspace in container " - f"{container.id}" - ) + logger.info("Successfully copied local directory to /workspace") except (OSError, DockerException): logger.exception("Failed to copy local directory to container") @@ -117,94 +232,56 @@ class DockerRuntime(AbstractRuntime): async def create_sandbox( self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None ) -> SandboxInfo: - sandbox = self._get_sandbox_by_agent_id(agent_id) - auth_token = existing_token or self._generate_sandbox_token() - scan_id = self._get_scan_id(agent_id) - volume_name = self._get_workspace_volume_name(scan_id) + container = self._get_or_create_scan_container(scan_id) - self._ensure_workspace_volume(volume_name) + source_copied_key = f"_source_copied_{scan_id}" + if local_source_path and not hasattr(self, source_copied_key): + self._copy_local_directory_to_container(container, local_source_path) + setattr(self, source_copied_key, True) - if not sandbox: - logger.info("Creating new Docker sandbox for agent %s", agent_id) - try: - tool_server_port = self._find_available_port() - caido_port = self._find_available_port() - - volumes_config = {volume_name: {"bind": "/shared_workspace", "mode": "rw"}} - container_name = f"strix-{agent_id}" - - sandbox = self.client.containers.run( - STRIX_IMAGE, - command="sleep infinity", - detach=True, - name=container_name, - hostname=container_name, - ports={ - f"{tool_server_port}/tcp": tool_server_port, - f"{caido_port}/tcp": caido_port, - }, - cap_add=["NET_ADMIN", "NET_RAW"], - labels={ - STRIX_AGENT_LABEL: agent_id, - STRIX_SCAN_LABEL: scan_id, - }, - environment={ - "PYTHONUNBUFFERED": "1", - "STRIX_AGENT_ID": agent_id, - "STRIX_SANDBOX_TOKEN": auth_token, - "STRIX_TOOL_SERVER_PORT": str(tool_server_port), - "CAIDO_PORT": str(caido_port), - }, - volumes=volumes_config, - tty=True, - ) - logger.info( - "Created new sandbox %s for agent %s with shared workspace %s", - sandbox.id, - agent_id, - volume_name, - ) - except DockerException as e: - raise RuntimeError(f"Failed to create Docker sandbox: {e}") from e - - assert sandbox is not None - if sandbox.status != "running": - sandbox.start() - time.sleep(15) - - if local_source_path and volume_name not in _initialized_volumes: - self._copy_local_directory_to_container(sandbox, local_source_path) - _initialized_volumes.add(volume_name) - - sandbox_id = sandbox.id - if sandbox_id is None: + container_id = container.id + if container_id is None: raise RuntimeError("Docker container ID is unexpectedly None") - tool_server_port_str = sandbox.attrs["Config"]["Env"][ - next( - ( - i - for i, s in enumerate(sandbox.attrs["Config"]["Env"]) - if s.startswith("STRIX_TOOL_SERVER_PORT=") - ), - -1, - ) - ].split("=")[1] - tool_server_port = int(tool_server_port_str) + token = existing_token if existing_token is not None else self._tool_server_token - api_url = await self.get_sandbox_url(sandbox_id, tool_server_port) + if self._tool_server_port is None or token is None: + raise RuntimeError("Tool server not initialized or no token available") + + api_url = await self.get_sandbox_url(container_id, self._tool_server_port) + + await self._register_agent_with_tool_server(api_url, agent_id, token) return { - "workspace_id": sandbox_id, + "workspace_id": container_id, "api_url": api_url, - "auth_token": auth_token, - "tool_server_port": tool_server_port, + "auth_token": token, + "tool_server_port": self._tool_server_port, + "agent_id": agent_id, } - async def get_sandbox_url(self, sandbox_id: str, port: int) -> str: + async def _register_agent_with_tool_server( + self, api_url: str, agent_id: str, token: str + ) -> None: + import httpx + try: - container = self.client.containers.get(sandbox_id) + async with httpx.AsyncClient() as client: + response = await client.post( + f"{api_url}/register_agent", + params={"agent_id": agent_id}, + headers={"Authorization": f"Bearer {token}"}, + timeout=30, + ) + response.raise_for_status() + logger.info(f"Registered agent {agent_id} with tool server") + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.warning(f"Failed to register agent {agent_id}: {e}") + + async def get_sandbox_url(self, container_id: str, port: int) -> str: + try: + container = self.client.containers.get(container_id) container.reload() host = "localhost" @@ -214,58 +291,25 @@ class DockerRuntime(AbstractRuntime): host = docker_host.split("://")[1].split(":")[0] except NotFound: - raise ValueError(f"Sandbox {sandbox_id} not found.") from None + raise ValueError(f"Container {container_id} not found.") from None except DockerException as e: - raise RuntimeError(f"Failed to get sandbox URL for {sandbox_id}: {e}") from e + raise RuntimeError(f"Failed to get container URL for {container_id}: {e}") from e else: return f"http://{host}:{port}" - async def destroy_sandbox(self, sandbox_id: str) -> None: - logger.info("Destroying Docker sandbox %s", sandbox_id) + async def destroy_sandbox(self, container_id: str) -> None: + logger.info("Destroying scan container %s", container_id) try: - container = self.client.containers.get(sandbox_id) - - scan_id = None - if container.labels and STRIX_SCAN_LABEL in container.labels: - scan_id = container.labels[STRIX_SCAN_LABEL] - + container = self.client.containers.get(container_id) container.stop() container.remove() - logger.info("Successfully destroyed sandbox %s", sandbox_id) + logger.info("Successfully destroyed container %s", container_id) - if scan_id: - await self._cleanup_workspace_if_empty(scan_id) + self._scan_container = None + self._tool_server_port = None + self._tool_server_token = None except NotFound: - logger.warning("Sandbox %s not found for destruction.", sandbox_id) + logger.warning("Container %s not found for destruction.", container_id) except DockerException as e: - logger.warning("Failed to destroy sandbox %s: %s", sandbox_id, e) - - async def _cleanup_workspace_if_empty(self, scan_id: str) -> None: - try: - volume_name = self._get_workspace_volume_name(scan_id) - - containers = self.client.containers.list( - all=True, filters={"label": f"{STRIX_SCAN_LABEL}={scan_id}"} - ) - - if not containers: - try: - volume = self.client.volumes.get(volume_name) - volume.remove() - logger.info( - f"Cleaned up workspace volume {volume_name} for completed scan {scan_id}" - ) - - _initialized_volumes.discard(volume_name) - - except NotFound: - logger.debug(f"Volume {volume_name} already removed") - except DockerException as e: - logger.warning(f"Failed to remove volume {volume_name}: {e}") - - except DockerException as e: - logger.warning("Error during workspace cleanup for scan %s: %s", scan_id, e) - - async def cleanup_scan_workspace(self, scan_id: str) -> None: - await self._cleanup_workspace_if_empty(scan_id) + logger.warning("Failed to destroy container %s: %s", container_id, e) diff --git a/strix/runtime/runtime.py b/strix/runtime/runtime.py index 493e9e5..328a757 100644 --- a/strix/runtime/runtime.py +++ b/strix/runtime/runtime.py @@ -7,6 +7,7 @@ class SandboxInfo(TypedDict): api_url: str auth_token: str | None tool_server_port: int + agent_id: str class AbstractRuntime(ABC): @@ -17,9 +18,9 @@ class AbstractRuntime(ABC): raise NotImplementedError @abstractmethod - async def get_sandbox_url(self, sandbox_id: str, port: int) -> str: + async def get_sandbox_url(self, container_id: str, port: int) -> str: raise NotImplementedError @abstractmethod - async def destroy_sandbox(self, sandbox_id: str) -> None: + async def destroy_sandbox(self, container_id: str) -> None: raise NotImplementedError diff --git a/strix/runtime/tool_server.py b/strix/runtime/tool_server.py index b826f79..6461f8c 100644 --- a/strix/runtime/tool_server.py +++ b/strix/runtime/tool_server.py @@ -1,7 +1,15 @@ +from __future__ import annotations + +import argparse +import asyncio import logging import os +import signal +import sys +from multiprocessing import Process, Queue from typing import Any +import uvicorn from fastapi import Depends, FastAPI, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel, ValidationError @@ -11,20 +19,25 @@ SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" if not SANDBOX_MODE: raise RuntimeError("Tool server should only run in sandbox mode (STRIX_SANDBOX_MODE=true)") -EXPECTED_TOKEN = os.getenv("STRIX_SANDBOX_TOKEN") -if not EXPECTED_TOKEN: - raise RuntimeError("STRIX_SANDBOX_TOKEN environment variable is required in sandbox mode") +parser = argparse.ArgumentParser(description="Start Strix tool server") +parser.add_argument("--token", required=True, help="Authentication token") +parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec +parser.add_argument("--port", type=int, required=True, help="Port to bind to") + +args = parser.parse_args() +EXPECTED_TOKEN = args.token app = FastAPI() -logger = logging.getLogger(__name__) security = HTTPBearer() security_dependency = Depends(security) +agent_processes: dict[str, dict[str, Any]] = {} +agent_queues: dict[str, dict[str, Queue[Any]]] = {} + def verify_token(credentials: HTTPAuthorizationCredentials) -> str: if not credentials or credentials.scheme != "Bearer": - logger.warning("Authentication failed: Invalid or missing Bearer token scheme") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication scheme. Bearer token required.", @@ -32,18 +45,17 @@ def verify_token(credentials: HTTPAuthorizationCredentials) -> str: ) if credentials.credentials != EXPECTED_TOKEN: - logger.warning("Authentication failed: Invalid token provided from remote host") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication token", headers={"WWW-Authenticate": "Bearer"}, ) - logger.debug("Authentication successful for tool execution request") return credentials.credentials class ToolExecutionRequest(BaseModel): + agent_id: str tool_name: str kwargs: dict[str, Any] @@ -53,45 +65,141 @@ class ToolExecutionResponse(BaseModel): error: str | None = None +def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None: + null_handler = logging.NullHandler() + + root_logger = logging.getLogger() + root_logger.handlers = [null_handler] + root_logger.setLevel(logging.CRITICAL) + + from strix.tools.argument_parser import ArgumentConversionError, convert_arguments + from strix.tools.registry import get_tool_by_name + + while True: + try: + request = request_queue.get() + + if request is None: + break + + tool_name = request["tool_name"] + kwargs = request["kwargs"] + + try: + tool_func = get_tool_by_name(tool_name) + if not tool_func: + response_queue.put({"error": f"Tool '{tool_name}' not found"}) + continue + + converted_kwargs = convert_arguments(tool_func, kwargs) + result = tool_func(**converted_kwargs) + + response_queue.put({"result": result}) + + except (ArgumentConversionError, ValidationError) as e: + response_queue.put({"error": f"Invalid arguments: {e}"}) + except (RuntimeError, ValueError, ImportError) as e: + response_queue.put({"error": f"Tool execution error: {e}"}) + + except (RuntimeError, ValueError, ImportError) as e: + response_queue.put({"error": f"Worker error: {e}"}) + + +def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]: + if agent_id not in agent_processes: + request_queue: Queue[Any] = Queue() + response_queue: Queue[Any] = Queue() + + process = Process( + target=agent_worker, args=(agent_id, request_queue, response_queue), daemon=True + ) + process.start() + + agent_processes[agent_id] = {"process": process, "pid": process.pid} + agent_queues[agent_id] = {"request": request_queue, "response": response_queue} + + return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"] + + @app.post("/execute", response_model=ToolExecutionResponse) async def execute_tool( request: ToolExecutionRequest, credentials: HTTPAuthorizationCredentials = security_dependency ) -> ToolExecutionResponse: verify_token(credentials) - from strix.tools.argument_parser import ArgumentConversionError, convert_arguments - from strix.tools.registry import get_tool_by_name + request_queue, response_queue = ensure_agent_process(request.agent_id) + + request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs}) try: - tool_func = get_tool_by_name(request.tool_name) - if not tool_func: - return ToolExecutionResponse(error=f"Tool '{request.tool_name}' not found") + loop = asyncio.get_event_loop() + response = await loop.run_in_executor(None, response_queue.get) - converted_kwargs = convert_arguments(tool_func, request.kwargs) + if "error" in response: + return ToolExecutionResponse(error=response["error"]) + return ToolExecutionResponse(result=response.get("result")) - result = tool_func(**converted_kwargs) + except (RuntimeError, ValueError, OSError) as e: + return ToolExecutionResponse(error=f"Worker error: {e}") - return ToolExecutionResponse(result=result) - except (ArgumentConversionError, ValidationError) as e: - logger.warning("Invalid tool arguments: %s", e) - return ToolExecutionResponse(error=f"Invalid arguments: {e}") - except TypeError as e: - logger.warning("Tool execution type error: %s", e) - return ToolExecutionResponse(error=f"Tool execution error: {e}") - except ValueError as e: - logger.warning("Tool execution value error: %s", e) - return ToolExecutionResponse(error=f"Tool execution error: {e}") - except Exception: - logger.exception("Unexpected error during tool execution") - return ToolExecutionResponse(error="Internal server error") +@app.post("/register_agent") +async def register_agent( + agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency +) -> dict[str, str]: + verify_token(credentials) + + ensure_agent_process(agent_id) + return {"status": "registered", "agent_id": agent_id} @app.get("/health") -async def health_check() -> dict[str, str]: +async def health_check() -> dict[str, Any]: return { "status": "healthy", "sandbox_mode": str(SANDBOX_MODE), "environment": "sandbox" if SANDBOX_MODE else "main", "auth_configured": "true" if EXPECTED_TOKEN else "false", + "active_agents": len(agent_processes), + "agents": list(agent_processes.keys()), } + + +def cleanup_all_agents() -> None: + for agent_id in list(agent_processes.keys()): + try: + agent_queues[agent_id]["request"].put(None) + process = agent_processes[agent_id]["process"] + + process.join(timeout=1) + + if process.is_alive(): + process.terminate() + process.join(timeout=1) + + if process.is_alive(): + process.kill() + + except (BrokenPipeError, EOFError, OSError): + pass + except (RuntimeError, ValueError) as e: + logging.getLogger(__name__).debug(f"Error during agent cleanup: {e}") + + +def signal_handler(_signum: int, _frame: Any) -> None: + signal.signal(signal.SIGPIPE, signal.SIG_IGN) if hasattr(signal, "SIGPIPE") else None + cleanup_all_agents() + sys.exit(0) + + +if hasattr(signal, "SIGPIPE"): + signal.signal(signal.SIGPIPE, signal.SIG_IGN) + +signal.signal(signal.SIGTERM, signal_handler) +signal.signal(signal.SIGINT, signal_handler) + +if __name__ == "__main__": + try: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + finally: + cleanup_all_agents() diff --git a/strix/tools/agents_graph/agents_graph_actions.py b/strix/tools/agents_graph/agents_graph_actions.py index f6b2255..15dd7d3 100644 --- a/strix/tools/agents_graph/agents_graph_actions.py +++ b/strix/tools/agents_graph/agents_graph_actions.py @@ -57,10 +57,10 @@ def _run_agent_in_thread( - Work independently with your own approach - Use agent_finish when complete to report back to parent - You are a SPECIALIST for this specific task - - The previous browser, sessions, proxy history, and files in /workspace were for your - parent agent. Do not depend on them. - - You are starting with a fresh context. Fresh proxy, browser, and files. - Only stuff in /shared_workspace is passed to you from context. + - You share the same container as other agents but have your own tool server instance + - All agents share /workspace directory and proxy history for better collaboration + - You can see files created by other agents and proxy traffic from previous work + - Build upon previous work but focus on your specific delegated task """ diff --git a/strix/tools/executor.py b/strix/tools/executor.py index 3f78da4..1ea15db 100644 --- a/strix/tools/executor.py +++ b/strix/tools/executor.py @@ -49,7 +49,10 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A server_url = await runtime.get_sandbox_url(agent_state.sandbox_id, tool_server_port) request_url = f"{server_url}/execute" + agent_id = getattr(agent_state, "agent_id", "unknown") + request_data = { + "agent_id": agent_id, "tool_name": tool_name, "kwargs": kwargs, } diff --git a/strix/tools/terminal/terminal_actions_schema.xml b/strix/tools/terminal/terminal_actions_schema.xml index eda6edc..d4e2fec 100644 --- a/strix/tools/terminal/terminal_actions_schema.xml +++ b/strix/tools/terminal/terminal_actions_schema.xml @@ -3,7 +3,7 @@ Execute a bash command in a persistent terminal session. The terminal maintains state (environment variables, current directory, running processes) between commands. - The bash command to execute. Cannot be empty - must provide a valid command or special key sequence. + The bash command to execute. Can be empty to check output of running commands (will wait for timeout period to collect output). Supported special keys and sequences (based on official tmux key names): - Control sequences: C-c, C-d, C-z, C-a, C-e, C-k, C-l, C-u, C-w, etc. (also ^c, ^d, etc.) @@ -16,13 +16,16 @@ - Shift sequences: S-key (e.g., S-F6, S-Tab, S-Left) - Combined modifiers: C-S-key, C-M-key, S-M-key, etc. + Special keys work automatically - no need to set is_input=true for keys like C-c, C-d, etc. These are useful for interacting with vim, emacs, REPLs, and other interactive applications. - If true, the command is sent as input to a currently running process. If false (default), the command is executed as a new bash command. Use this to interact with running processes. + If true, the command is sent as input to a currently running process. If false (default), the command is executed as a new bash command. + Note: Special keys (C-c, C-d, etc.) automatically work when a process is running - you don't need to set is_input=true for them. + Use is_input=true for regular text input to running processes. - Optional timeout in seconds for command execution. If not provided, uses default timeout behavior. Set to higher values for long-running commands like installations or tests. + Optional timeout in seconds for command execution. If not provided, uses default timeout behavior. Set to higher values for long-running commands like installations or tests. Default is 10 seconds. Identifier for the terminal session. Defaults to "default". Use different IDs to manage multiple concurrent terminal sessions. @@ -44,7 +47,7 @@ - exit_code: Exit code of the command (only for completed commands) - command: The executed command - terminal_id: The terminal session ID - - status: Command status ('completed', 'timeout', 'running') + - status: Command status ('completed' or 'running') - working_dir: Current working directory after command execution @@ -56,22 +59,25 @@ && or ; operators, or make separate tool calls. 3. LONG-RUNNING COMMANDS: - - For commands that run indefinitely, run them in background: 'python app.py > server.log 2>&1 &' - - For commands that take time, set appropriate timeout parameter - - Use is_input=true to interact with running processes + - Commands never get killed automatically - they keep running in background + - Set timeout to control how long to wait for output before returning + - Use empty command "" to check progress (waits for timeout period to collect output) + - Use C-c, C-d, C-z to interrupt processes (works automatically, no is_input needed) 4. TIMEOUT HANDLING: - - Commands have a default soft timeout (30 seconds of no output changes) - - Set custom timeout for longer operations - - When timeout occurs, you can send empty command to get more output - - Use control sequences (C-c, C-d, C-z) to interrupt processes + - Timeout controls how long to wait before returning current output + - Commands are NEVER killed on timeout - they keep running + - After timeout, you can run new commands or check progress with empty command + - All commands return status "completed" - you have full control 5. MULTIPLE TERMINALS: Use different terminal_id values to run multiple concurrent sessions. - 6. INTERACTIVE PROCESSES: Use is_input=true to send input to running processes like: - - Interactive shells, REPLs, or prompts - - Long-running applications waiting for input - - Background processes that need interaction + 6. INTERACTIVE PROCESSES: + - Special keys (C-c, C-d, etc.) work automatically when a process is running + - Use is_input=true for regular text input to running processes like: + * Interactive shells, REPLs, or prompts + * Long-running applications waiting for input + * Background processes that need interaction - Use no_enter=true for stuff like Vim navigation, password typing, or multi-step commands 7. WORKING DIRECTORY: The terminal tracks and returns the current working directory. @@ -92,6 +98,12 @@ 120 + # Check progress of running command (waits for timeout to collect output) + + + 5 + + # Start a background service python app.py > server.log 2>&1 & @@ -103,7 +115,7 @@ true - # Interrupt a running process + # Interrupt a running process (special keys work automatically) C-c diff --git a/strix/tools/terminal/terminal_session.py b/strix/tools/terminal/terminal_session.py index 711d340..2ed9b74 100644 --- a/strix/tools/terminal/terminal_session.py +++ b/strix/tools/terminal/terminal_session.py @@ -33,7 +33,6 @@ class TerminalSession: self.work_dir = str(Path(work_dir).resolve()) self._closed = False self._cwd = self.work_dir - self.NO_CHANGE_TIMEOUT_SECONDS = 30 self.server: libtmux.Server | None = None self.session: libtmux.Session | None = None @@ -200,55 +199,126 @@ class TerminalSession: except (ValueError, IndexError): return None - def execute( - self, command: str, is_input: bool = False, timeout: float = 30.0, no_enter: bool = False + def _handle_empty_command( + self, + cur_pane_output: str, + ps1_matches: list[re.Match[str]], + is_command_running: bool, + timeout: float, ) -> dict[str, Any]: - if not self._initialized: - raise RuntimeError("Bash session is not initialized") - - if command == "" or command.strip() == "": + if not is_command_running: + raw_command_output = self._combine_outputs_between_matches(cur_pane_output, ps1_matches) + command_output = self._get_command_output("", raw_command_output) return { - "content": ( - "Command cannot be empty - must provide a valid command or control sequence" - ), + "content": command_output, + "status": "completed", + "exit_code": 0, + "working_dir": self._cwd, + } + + start_time = time.time() + last_pane_output = cur_pane_output + + while True: + cur_pane_output = self._get_pane_content() + ps1_matches = self._matches_ps1_metadata(cur_pane_output) + + if cur_pane_output.rstrip().endswith(self.PS1_END.rstrip()) or len(ps1_matches) > 0: + exit_code = self._extract_exit_code_from_matches(ps1_matches) + raw_command_output = self._combine_outputs_between_matches( + cur_pane_output, ps1_matches + ) + command_output = self._get_command_output("", raw_command_output) + self.prev_status = BashCommandStatus.COMPLETED + self.prev_output = "" + self._ready_for_next_command() + return { + "content": command_output, + "status": "completed", + "exit_code": exit_code or 0, + "working_dir": self._cwd, + } + + elapsed_time = time.time() - start_time + if elapsed_time >= timeout: + raw_command_output = self._combine_outputs_between_matches( + cur_pane_output, ps1_matches + ) + command_output = self._get_command_output("", raw_command_output) + return { + "content": command_output + + f"\n[Command still running after {timeout}s - showing output so far]", + "status": "running", + "exit_code": None, + "working_dir": self._cwd, + } + + if cur_pane_output != last_pane_output: + last_pane_output = cur_pane_output + + time.sleep(self.POLL_INTERVAL) + + def _handle_input_command( + self, command: str, no_enter: bool, is_command_running: bool + ) -> dict[str, Any]: + if not is_command_running: + return { + "content": "No command is currently running. Cannot send input.", "status": "error", "exit_code": None, "working_dir": self._cwd, } - if ( - self.prev_status - in { - BashCommandStatus.HARD_TIMEOUT, - BashCommandStatus.NO_CHANGE_TIMEOUT, - } - and not is_input - and command != "" - ): + if not self.pane: + raise RuntimeError("Terminal session not properly initialized") + + is_special_key = self._is_special_key(command) + should_add_enter = not is_special_key and not no_enter + self.pane.send_keys(command, enter=should_add_enter) + + time.sleep(2) + cur_pane_output = self._get_pane_content() + ps1_matches = self._matches_ps1_metadata(cur_pane_output) + raw_command_output = self._combine_outputs_between_matches(cur_pane_output, ps1_matches) + command_output = self._get_command_output(command, raw_command_output) + + is_still_running = not ( + cur_pane_output.rstrip().endswith(self.PS1_END.rstrip()) or len(ps1_matches) > 0 + ) + + if is_still_running: return { - "content": ( - f'Previous command still running. Cannot execute "{command}". ' - "Use is_input=True to interact with running process." - ), - "status": "error", + "content": command_output, + "status": "running", "exit_code": None, "working_dir": self._cwd, } + exit_code = self._extract_exit_code_from_matches(ps1_matches) + self.prev_status = BashCommandStatus.COMPLETED + self.prev_output = "" + self._ready_for_next_command() + return { + "content": command_output, + "status": "completed", + "exit_code": exit_code or 0, + "working_dir": self._cwd, + } + + def _execute_new_command(self, command: str, no_enter: bool, timeout: float) -> dict[str, Any]: + if not self.pane: + raise RuntimeError("Terminal session not properly initialized") + initial_pane_output = self._get_pane_content() initial_ps1_matches = self._matches_ps1_metadata(initial_pane_output) initial_ps1_count = len(initial_ps1_matches) start_time = time.time() - last_change_time = start_time last_pane_output = initial_pane_output - if command != "": - if not self.pane: - raise RuntimeError("Terminal session not properly initialized") - is_special_key = self._is_special_key(command) - should_add_enter = not is_special_key and not no_enter - self.pane.send_keys(command, enter=should_add_enter) + is_special_key = self._is_special_key(command) + should_add_enter = not is_special_key and not no_enter + self.pane.send_keys(command, enter=should_add_enter) while True: cur_pane_output = self._get_pane_content() @@ -257,7 +327,6 @@ class TerminalSession: if cur_pane_output != last_pane_output: last_pane_output = cur_pane_output - last_change_time = time.time() if current_ps1_count > initial_ps1_count or cur_pane_output.rstrip().endswith( self.PS1_END.rstrip() @@ -283,26 +352,6 @@ class TerminalSession: "working_dir": self._cwd, } - time_since_last_change = time.time() - last_change_time - if time_since_last_change >= self.NO_CHANGE_TIMEOUT_SECONDS: - raw_command_output = self._combine_outputs_between_matches( - cur_pane_output, ps1_matches - ) - command_output = self._get_command_output( - command, - raw_command_output, - continue_prefix="[Below is the output of the previous command.]\n", - ) - self.prev_status = BashCommandStatus.NO_CHANGE_TIMEOUT - - return { - "content": command_output + f"\n[Command timed out - no output change for " - f"{self.NO_CHANGE_TIMEOUT_SECONDS} seconds]", - "status": "timeout", - "exit_code": -1, - "working_dir": self._cwd, - } - elapsed_time = time.time() - start_time if elapsed_time >= timeout: raw_command_output = self._combine_outputs_between_matches( @@ -313,17 +362,59 @@ class TerminalSession: raw_command_output, continue_prefix="[Below is the output of the previous command.]\n", ) - self.prev_status = BashCommandStatus.HARD_TIMEOUT + self.prev_status = BashCommandStatus.CONTINUE + timeout_msg = ( + f"\n[Command still running after {timeout}s - showing output so far. " + "Use C-c to interrupt if needed.]" + ) return { - "content": command_output + f"\n[Command timed out after {timeout} seconds]", - "status": "timeout", - "exit_code": -1, + "content": command_output + timeout_msg, + "status": "running", + "exit_code": None, "working_dir": self._cwd, } time.sleep(self.POLL_INTERVAL) + def execute( + self, command: str, is_input: bool = False, timeout: float = 10.0, no_enter: bool = False + ) -> dict[str, Any]: + if not self._initialized: + raise RuntimeError("Bash session is not initialized") + + cur_pane_output = self._get_pane_content() + ps1_matches = self._matches_ps1_metadata(cur_pane_output) + is_command_running = not ( + cur_pane_output.rstrip().endswith(self.PS1_END.rstrip()) or len(ps1_matches) > 0 + ) + + if command.strip() == "": + return self._handle_empty_command( + cur_pane_output, ps1_matches, is_command_running, timeout + ) + + is_special_key = self._is_special_key(command) + + if is_input: + return self._handle_input_command(command, no_enter, is_command_running) + + if is_special_key and is_command_running: + return self._handle_input_command(command, no_enter, is_command_running) + + if is_command_running: + return { + "content": ( + "A command is already running. Use is_input=true to send input to it, " + "or interrupt it first (e.g., with C-c)." + ), + "status": "error", + "exit_code": None, + "working_dir": self._cwd, + } + + return self._execute_new_command(command, no_enter, timeout) + def _ready_for_next_command(self) -> None: self._clear_screen()