Running all agents under same container (#12)

This commit is contained in:
Ahmed Allam
2025-08-18 13:58:38 -07:00
committed by GitHub
parent 198a5e4a61
commit cb57426cc6
13 changed files with 546 additions and 292 deletions

View File

@@ -7,19 +7,15 @@ from pathlib import Path
from typing import cast
import docker
from docker.errors import DockerException, NotFound
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container
from .runtime import AbstractRuntime, SandboxInfo
STRIX_AGENT_LABEL = "StrixAgent_ID"
STRIX_SCAN_LABEL = "StrixScan_ID"
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.4")
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
logger = logging.getLogger(__name__)
_initialized_volumes: set[str] = set()
class DockerRuntime(AbstractRuntime):
def __init__(self) -> None:
@@ -29,9 +25,18 @@ class DockerRuntime(AbstractRuntime):
logger.exception("Failed to connect to Docker daemon")
raise RuntimeError("Docker is not available or not configured correctly.") from e
self._scan_container: Container | None = None
self._tool_server_port: int | None = None
self._tool_server_token: str | None = None
def _generate_sandbox_token(self) -> str:
return secrets.token_urlsafe(32)
def _find_available_port(self) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return cast("int", s.getsockname()[1])
def _get_scan_id(self, agent_id: str) -> str:
try:
from strix.cli.tracer import get_global_tracer
@@ -46,37 +51,151 @@ class DockerRuntime(AbstractRuntime):
return f"scan-{agent_id.split('-')[0]}"
# Return a TCP port number that the OS reports as currently free.
# Binding an IPv4 stream socket to port 0 lets the OS pick the port; the chosen
# number is then read back from getsockname().
# The probe socket is closed when the `with` block exits, so the port is not
# held open for the caller.
def _find_available_port(self) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return cast("int", s.getsockname()[1])
def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None:
def _validate_image(image: docker.models.images.Image) -> None:
if not image.id or not image.attrs:
raise ImageNotFound(f"Image {image_name} metadata incomplete")
def _get_workspace_volume_name(self, scan_id: str) -> str:
return f"strix-workspace-{scan_id}"
for attempt in range(max_retries):
try:
image = self.client.images.get(image_name)
_validate_image(image)
except ImageNotFound:
if attempt == max_retries - 1:
logger.exception(f"Image {image_name} not found after {max_retries} attempts")
raise
logger.warning(f"Image {image_name} not ready, attempt {attempt + 1}/{max_retries}")
time.sleep(2**attempt)
except DockerException:
if attempt == max_retries - 1:
logger.exception(f"Failed to verify image {image_name}")
raise
logger.warning(f"Docker error verifying image, attempt {attempt + 1}/{max_retries}")
time.sleep(2**attempt)
else:
logger.debug(f"Image {image_name} verified as available")
return
def _get_sandbox_by_agent_id(self, agent_id: str) -> Container | None:
try:
containers = self.client.containers.list(
filters={"label": f"{STRIX_AGENT_LABEL}={agent_id}"}
)
if not containers:
return None
if len(containers) > 1:
logger.warning(
"Multiple sandboxes found for agent ID %s, using the first one.", agent_id
def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Container:
last_exception = None
for attempt in range(max_retries):
try:
self._verify_image_available(STRIX_IMAGE)
caido_port = self._find_available_port()
tool_server_port = self._find_available_port()
tool_server_token = self._generate_sandbox_token()
self._tool_server_port = tool_server_port
self._tool_server_token = tool_server_token
container = self.client.containers.run(
STRIX_IMAGE,
command="sleep infinity",
detach=True,
name=f"strix-scan-{scan_id}",
hostname=f"strix-scan-{scan_id}",
ports={
f"{caido_port}/tcp": caido_port,
f"{tool_server_port}/tcp": tool_server_port,
},
cap_add=["NET_ADMIN", "NET_RAW"],
labels={"strix-scan-id": scan_id},
environment={
"PYTHONUNBUFFERED": "1",
"CAIDO_PORT": str(caido_port),
"TOOL_SERVER_PORT": str(tool_server_port),
"TOOL_SERVER_TOKEN": tool_server_token,
},
tty=True,
)
return cast("Container", containers[0])
except DockerException as e:
logger.warning("Failed to get sandbox by agent ID %s: %s", agent_id, e)
return None
def _ensure_workspace_volume(self, volume_name: str) -> None:
self._scan_container = container
logger.info("Created container %s for scan %s", container.id, scan_id)
self._initialize_container(
container, caido_port, tool_server_port, tool_server_token
)
except DockerException as e:
last_exception = e
if attempt == max_retries - 1:
logger.exception(f"Failed to create container after {max_retries} attempts")
break
logger.warning(f"Container creation attempt {attempt + 1}/{max_retries} failed")
self._tool_server_port = None
self._tool_server_token = None
sleep_time = (2**attempt) + (0.1 * attempt)
time.sleep(sleep_time)
else:
return container
raise RuntimeError(
f"Failed to create Docker container after {max_retries} attempts: {last_exception}"
) from last_exception
def _get_or_create_scan_container(self, scan_id: str) -> Container:
if self._scan_container:
try:
self._scan_container.reload()
if self._scan_container.status == "running":
return self._scan_container
except NotFound:
self._scan_container = None
self._tool_server_port = None
self._tool_server_token = None
try:
self.client.volumes.get(volume_name)
logger.info(f"Using existing workspace volume: {volume_name}")
except NotFound:
self.client.volumes.create(name=volume_name, driver="local")
logger.info(f"Created new workspace volume: {volume_name}")
containers = self.client.containers.list(filters={"label": f"strix-scan-id={scan_id}"})
if containers:
container = cast("Container", containers[0])
if container.status != "running":
container.start()
time.sleep(2)
self._scan_container = container
for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_PORT="):
self._tool_server_port = int(env_var.split("=")[1])
elif env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=")[1]
return container
except DockerException as e:
logger.warning("Failed to find existing container for scan %s: %s", scan_id, e)
logger.info("Creating new Docker container for scan %s", scan_id)
return self._create_container_with_retry(scan_id)
def _initialize_container(
    self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
) -> None:
    """Start the Caido proxy and the tool server inside ``container``.

    Steps, in order:

    1. Run ``/usr/local/bin/docker-entrypoint.sh`` with ``CAIDO_PORT`` exported.
    2. Read ``CAIDO_API_TOKEN`` (as user ``pentester``) after sourcing
       ``/etc/profile.d/proxy.sh``.
    3. Launch ``strix/runtime/tool_server.py`` in the background, bound to
       ``0.0.0.0:<tool_server_port>`` and authenticated with ``tool_server_token``.

    Failures of steps 1 and 2 are logged as warnings rather than raised, so the
    tool server is still started (with an empty Caido token if step 2 failed).

    Args:
        container: The running container to initialise.
        caido_port: Port exported as ``CAIDO_PORT`` for the proxy.
        tool_server_port: Port the tool server listens on.
        tool_server_token: Bearer token passed to the tool server via ``--token``.
    """
    logger.info("Initializing Caido proxy on port %s", caido_port)
    entrypoint_result = container.exec_run(
        f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
        detach=False,
    )
    if entrypoint_result.exit_code != 0:
        # Previously this result was discarded, so a broken entrypoint went unnoticed.
        logger.warning(
            "Container entrypoint exited with code %s; Caido proxy may not be running",
            entrypoint_result.exit_code,
        )

    # Fixed pause; presumably gives the proxy time to come up - TODO confirm.
    time.sleep(5)

    token_result = container.exec_run(
        "bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester"
    )
    if token_result.exit_code == 0:
        caido_token = token_result.output.decode().strip()
    else:
        logger.warning(
            "Could not read CAIDO_API_TOKEN (exit code %s); continuing with an empty token",
            token_result.exit_code,
        )
        caido_token = ""

    # Detached exec: the call returns immediately and the server keeps running in the
    # container's background.
    container.exec_run(
        f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
        f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
        f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
        f"--host 0.0.0.0 --port {tool_server_port} &'",
        detach=True,
        user="pentester",
    )

    # Fixed pause; presumably gives the tool server time to start listening - TODO confirm.
    time.sleep(5)
def _copy_local_directory_to_container(self, container: Container, local_path: str) -> None:
import tarfile
@@ -85,10 +204,10 @@ class DockerRuntime(AbstractRuntime):
try:
local_path_obj = Path(local_path).resolve()
if not local_path_obj.exists() or not local_path_obj.is_dir():
logger.warning(f"Local path does not exist or is not a directory: {local_path_obj}")
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
return
logger.info(f"Copying local directory {local_path_obj} to container {container.id}")
logger.info(f"Copying local directory {local_path_obj} to container")
tar_buffer = BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
@@ -98,18 +217,14 @@ class DockerRuntime(AbstractRuntime):
tar.add(item, arcname=arcname)
tar_buffer.seek(0)
container.put_archive("/shared_workspace", tar_buffer.getvalue())
container.put_archive("/workspace", tar_buffer.getvalue())
container.exec_run(
"chown -R pentester:pentester /shared_workspace && chmod -R 755 /shared_workspace",
"chown -R pentester:pentester /workspace && chmod -R 755 /workspace",
user="root",
)
logger.info(
f"Successfully copied {local_path_obj} to /shared_workspace in container "
f"{container.id}"
)
logger.info("Successfully copied local directory to /workspace")
except (OSError, DockerException):
logger.exception("Failed to copy local directory to container")
@@ -117,94 +232,56 @@ class DockerRuntime(AbstractRuntime):
async def create_sandbox(
self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None
) -> SandboxInfo:
sandbox = self._get_sandbox_by_agent_id(agent_id)
auth_token = existing_token or self._generate_sandbox_token()
scan_id = self._get_scan_id(agent_id)
volume_name = self._get_workspace_volume_name(scan_id)
container = self._get_or_create_scan_container(scan_id)
self._ensure_workspace_volume(volume_name)
source_copied_key = f"_source_copied_{scan_id}"
if local_source_path and not hasattr(self, source_copied_key):
self._copy_local_directory_to_container(container, local_source_path)
setattr(self, source_copied_key, True)
if not sandbox:
logger.info("Creating new Docker sandbox for agent %s", agent_id)
try:
tool_server_port = self._find_available_port()
caido_port = self._find_available_port()
volumes_config = {volume_name: {"bind": "/shared_workspace", "mode": "rw"}}
container_name = f"strix-{agent_id}"
sandbox = self.client.containers.run(
STRIX_IMAGE,
command="sleep infinity",
detach=True,
name=container_name,
hostname=container_name,
ports={
f"{tool_server_port}/tcp": tool_server_port,
f"{caido_port}/tcp": caido_port,
},
cap_add=["NET_ADMIN", "NET_RAW"],
labels={
STRIX_AGENT_LABEL: agent_id,
STRIX_SCAN_LABEL: scan_id,
},
environment={
"PYTHONUNBUFFERED": "1",
"STRIX_AGENT_ID": agent_id,
"STRIX_SANDBOX_TOKEN": auth_token,
"STRIX_TOOL_SERVER_PORT": str(tool_server_port),
"CAIDO_PORT": str(caido_port),
},
volumes=volumes_config,
tty=True,
)
logger.info(
"Created new sandbox %s for agent %s with shared workspace %s",
sandbox.id,
agent_id,
volume_name,
)
except DockerException as e:
raise RuntimeError(f"Failed to create Docker sandbox: {e}") from e
assert sandbox is not None
if sandbox.status != "running":
sandbox.start()
time.sleep(15)
if local_source_path and volume_name not in _initialized_volumes:
self._copy_local_directory_to_container(sandbox, local_source_path)
_initialized_volumes.add(volume_name)
sandbox_id = sandbox.id
if sandbox_id is None:
container_id = container.id
if container_id is None:
raise RuntimeError("Docker container ID is unexpectedly None")
tool_server_port_str = sandbox.attrs["Config"]["Env"][
next(
(
i
for i, s in enumerate(sandbox.attrs["Config"]["Env"])
if s.startswith("STRIX_TOOL_SERVER_PORT=")
),
-1,
)
].split("=")[1]
tool_server_port = int(tool_server_port_str)
token = existing_token if existing_token is not None else self._tool_server_token
api_url = await self.get_sandbox_url(sandbox_id, tool_server_port)
if self._tool_server_port is None or token is None:
raise RuntimeError("Tool server not initialized or no token available")
api_url = await self.get_sandbox_url(container_id, self._tool_server_port)
await self._register_agent_with_tool_server(api_url, agent_id, token)
return {
"workspace_id": sandbox_id,
"workspace_id": container_id,
"api_url": api_url,
"auth_token": auth_token,
"tool_server_port": tool_server_port,
"auth_token": token,
"tool_server_port": self._tool_server_port,
"agent_id": agent_id,
}
async def get_sandbox_url(self, sandbox_id: str, port: int) -> str:
async def _register_agent_with_tool_server(
self, api_url: str, agent_id: str, token: str
) -> None:
import httpx
try:
container = self.client.containers.get(sandbox_id)
async with httpx.AsyncClient() as client:
response = await client.post(
f"{api_url}/register_agent",
params={"agent_id": agent_id},
headers={"Authorization": f"Bearer {token}"},
timeout=30,
)
response.raise_for_status()
logger.info(f"Registered agent {agent_id} with tool server")
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(f"Failed to register agent {agent_id}: {e}")
async def get_sandbox_url(self, container_id: str, port: int) -> str:
try:
container = self.client.containers.get(container_id)
container.reload()
host = "localhost"
@@ -214,58 +291,25 @@ class DockerRuntime(AbstractRuntime):
host = docker_host.split("://")[1].split(":")[0]
except NotFound:
raise ValueError(f"Sandbox {sandbox_id} not found.") from None
raise ValueError(f"Container {container_id} not found.") from None
except DockerException as e:
raise RuntimeError(f"Failed to get sandbox URL for {sandbox_id}: {e}") from e
raise RuntimeError(f"Failed to get container URL for {container_id}: {e}") from e
else:
return f"http://{host}:{port}"
async def destroy_sandbox(self, sandbox_id: str) -> None:
logger.info("Destroying Docker sandbox %s", sandbox_id)
async def destroy_sandbox(self, container_id: str) -> None:
logger.info("Destroying scan container %s", container_id)
try:
container = self.client.containers.get(sandbox_id)
scan_id = None
if container.labels and STRIX_SCAN_LABEL in container.labels:
scan_id = container.labels[STRIX_SCAN_LABEL]
container = self.client.containers.get(container_id)
container.stop()
container.remove()
logger.info("Successfully destroyed sandbox %s", sandbox_id)
logger.info("Successfully destroyed container %s", container_id)
if scan_id:
await self._cleanup_workspace_if_empty(scan_id)
self._scan_container = None
self._tool_server_port = None
self._tool_server_token = None
except NotFound:
logger.warning("Sandbox %s not found for destruction.", sandbox_id)
logger.warning("Container %s not found for destruction.", container_id)
except DockerException as e:
logger.warning("Failed to destroy sandbox %s: %s", sandbox_id, e)
async def _cleanup_workspace_if_empty(self, scan_id: str) -> None:
    """Remove the scan's workspace volume once no container carries its scan label.

    Containers in any state are counted (``all=True``). If at least one still
    exists, nothing happens. Docker errors are logged and never propagated.
    """
    try:
        volume_name = self._get_workspace_volume_name(scan_id)
        label_filter = {"label": f"{STRIX_SCAN_LABEL}={scan_id}"}
        remaining = self.client.containers.list(all=True, filters=label_filter)
        if remaining:
            return

        try:
            self.client.volumes.get(volume_name).remove()
        except NotFound:
            logger.debug(f"Volume {volume_name} already removed")
        except DockerException as volume_error:
            logger.warning(f"Failed to remove volume {volume_name}: {volume_error}")
        else:
            logger.info(
                f"Cleaned up workspace volume {volume_name} for completed scan {scan_id}"
            )
            # Forget the volume so a later scan with the same name is re-initialised.
            _initialized_volumes.discard(volume_name)
    except DockerException as cleanup_error:
        logger.warning("Error during workspace cleanup for scan %s: %s", scan_id, cleanup_error)
# Public wrapper for scan workspace cleanup: awaits
# _cleanup_workspace_if_empty(scan_id) and returns nothing.
async def cleanup_scan_workspace(self, scan_id: str) -> None:
await self._cleanup_workspace_if_empty(scan_id)
logger.warning("Failed to destroy container %s: %s", container_id, e)

View File

@@ -7,6 +7,7 @@ class SandboxInfo(TypedDict):
api_url: str
auth_token: str | None
tool_server_port: int
agent_id: str
class AbstractRuntime(ABC):
@@ -17,9 +18,9 @@ class AbstractRuntime(ABC):
raise NotImplementedError
@abstractmethod
async def get_sandbox_url(self, sandbox_id: str, port: int) -> str:
async def get_sandbox_url(self, container_id: str, port: int) -> str:
raise NotImplementedError
@abstractmethod
async def destroy_sandbox(self, sandbox_id: str) -> None:
async def destroy_sandbox(self, container_id: str) -> None:
raise NotImplementedError

View File

@@ -1,7 +1,15 @@
from __future__ import annotations
import argparse
import asyncio
import logging
import os
import signal
import sys
from multiprocessing import Process, Queue
from typing import Any
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, ValidationError
@@ -11,20 +19,25 @@ SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not SANDBOX_MODE:
raise RuntimeError("Tool server should only run in sandbox mode (STRIX_SANDBOX_MODE=true)")
EXPECTED_TOKEN = os.getenv("STRIX_SANDBOX_TOKEN")
if not EXPECTED_TOKEN:
raise RuntimeError("STRIX_SANDBOX_TOKEN environment variable is required in sandbox mode")
parser = argparse.ArgumentParser(description="Start Strix tool server")
parser.add_argument("--token", required=True, help="Authentication token")
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
args = parser.parse_args()
EXPECTED_TOKEN = args.token
app = FastAPI()
logger = logging.getLogger(__name__)
security = HTTPBearer()
security_dependency = Depends(security)
agent_processes: dict[str, dict[str, Any]] = {}
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
if not credentials or credentials.scheme != "Bearer":
logger.warning("Authentication failed: Invalid or missing Bearer token scheme")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication scheme. Bearer token required.",
@@ -32,18 +45,17 @@ def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
)
if credentials.credentials != EXPECTED_TOKEN:
logger.warning("Authentication failed: Invalid token provided from remote host")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication token",
headers={"WWW-Authenticate": "Bearer"},
)
logger.debug("Authentication successful for tool execution request")
return credentials.credentials
class ToolExecutionRequest(BaseModel):
agent_id: str
tool_name: str
kwargs: dict[str, Any]
@@ -53,45 +65,141 @@ class ToolExecutionResponse(BaseModel):
error: str | None = None
def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
    """Serve tool-execution requests for one agent until a ``None`` sentinel arrives.

    Each request is a dict with ``"tool_name"`` and ``"kwargs"``. Exactly one reply
    is put on ``response_queue`` per request: ``{"result": ...}`` on success or
    ``{"error": "..."}`` on failure. The producer blocks on that reply, so every
    failure of the tool itself must be turned into an error reply rather than
    being allowed to terminate the loop.

    Args:
        _agent_id: Identifier of the agent this worker belongs to (unused here).
        request_queue: Queue the requests are read from; ``None`` stops the loop.
        response_queue: Queue the replies are written to.
    """
    # Replace all root handlers with a no-op one and silence everything below CRITICAL.
    root_logger = logging.getLogger()
    root_logger.handlers = [logging.NullHandler()]
    root_logger.setLevel(logging.CRITICAL)

    from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
    from strix.tools.registry import get_tool_by_name

    while True:
        try:
            request = request_queue.get()

            if request is None:
                break

            tool_name = request["tool_name"]
            kwargs = request["kwargs"]

            try:
                tool_func = get_tool_by_name(tool_name)
                if not tool_func:
                    response_queue.put({"error": f"Tool '{tool_name}' not found"})
                    continue

                converted_kwargs = convert_arguments(tool_func, kwargs)
                result = tool_func(**converted_kwargs)

                response_queue.put({"result": result})

            except (ArgumentConversionError, ValidationError) as e:
                response_queue.put({"error": f"Invalid arguments: {e}"})
            except Exception as e:  # noqa: BLE001
                # Worker boundary: tools can raise anything (TypeError, KeyError, OSError, ...).
                # Without a reply the caller waiting on response_queue would block forever.
                response_queue.put({"error": f"Tool execution error: {e}"})

        except (RuntimeError, ValueError, ImportError) as e:
            response_queue.put({"error": f"Worker error: {e}"})
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
    """Return the (request, response) queues of a live worker process for ``agent_id``.

    A daemon worker running ``agent_worker`` is started on first use. If the
    recorded worker is no longer alive, a new worker with fresh queues replaces
    it; otherwise requests would be queued to a process that never answers.

    Args:
        agent_id: Identifier of the agent whose worker is needed.

    Returns:
        The request queue and the response queue of the agent's worker.
    """
    existing = agent_processes.get(agent_id)
    worker_missing = existing is None
    worker_dead = not worker_missing and not existing["process"].is_alive()

    if worker_dead:
        logging.getLogger(__name__).warning(
            "Worker process for agent %s is not running; starting a new one", agent_id
        )

    if worker_missing or worker_dead:
        # Fresh queues so no stale request or reply leaks into the new worker.
        request_queue: Queue[Any] = Queue()
        response_queue: Queue[Any] = Queue()

        process = Process(
            target=agent_worker, args=(agent_id, request_queue, response_queue), daemon=True
        )
        process.start()

        agent_processes[agent_id] = {"process": process, "pid": process.pid}
        agent_queues[agent_id] = {"request": request_queue, "response": response_queue}

    queues = agent_queues[agent_id]
    return queues["request"], queues["response"]
@app.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
request: ToolExecutionRequest, credentials: HTTPAuthorizationCredentials = security_dependency
) -> ToolExecutionResponse:
verify_token(credentials)
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
from strix.tools.registry import get_tool_by_name
request_queue, response_queue = ensure_agent_process(request.agent_id)
request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs})
try:
tool_func = get_tool_by_name(request.tool_name)
if not tool_func:
return ToolExecutionResponse(error=f"Tool '{request.tool_name}' not found")
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, response_queue.get)
converted_kwargs = convert_arguments(tool_func, request.kwargs)
if "error" in response:
return ToolExecutionResponse(error=response["error"])
return ToolExecutionResponse(result=response.get("result"))
result = tool_func(**converted_kwargs)
except (RuntimeError, ValueError, OSError) as e:
return ToolExecutionResponse(error=f"Worker error: {e}")
return ToolExecutionResponse(result=result)
except (ArgumentConversionError, ValidationError) as e:
logger.warning("Invalid tool arguments: %s", e)
return ToolExecutionResponse(error=f"Invalid arguments: {e}")
except TypeError as e:
logger.warning("Tool execution type error: %s", e)
return ToolExecutionResponse(error=f"Tool execution error: {e}")
except ValueError as e:
logger.warning("Tool execution value error: %s", e)
return ToolExecutionResponse(error=f"Tool execution error: {e}")
except Exception:
logger.exception("Unexpected error during tool execution")
return ToolExecutionResponse(error="Internal server error")
@app.post("/register_agent")
async def register_agent(
    agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency
) -> dict[str, str]:
    """Authenticate the caller, make sure ``agent_id`` has a worker, and confirm it.

    The bearer token is checked first; ``verify_token`` raises on a bad token.
    """
    verify_token(credentials)
    ensure_agent_process(agent_id)
    return dict(status="registered", agent_id=agent_id)
@app.get("/health")
async def health_check() -> dict[str, str]:
async def health_check() -> dict[str, Any]:
return {
"status": "healthy",
"sandbox_mode": str(SANDBOX_MODE),
"environment": "sandbox" if SANDBOX_MODE else "main",
"auth_configured": "true" if EXPECTED_TOKEN else "false",
"active_agents": len(agent_processes),
"agents": list(agent_processes.keys()),
}
def cleanup_all_agents() -> None:
    """Stop every registered agent worker, escalating only as far as needed.

    Each worker is sent the ``None`` shutdown sentinel and given one second to
    exit. A worker still alive is terminated and given one more second, then
    killed. Pipe and OS errors are ignored as best-effort shutdown; runtime and
    value errors are logged at debug level.

    NOTE: entries are not removed from ``agent_processes`` / ``agent_queues``.
    """
    for agent_id in tuple(agent_processes):
        try:
            agent_queues[agent_id]["request"].put(None)
            worker = agent_processes[agent_id]["process"]

            worker.join(timeout=1)
            if not worker.is_alive():
                continue

            worker.terminate()
            worker.join(timeout=1)
            if worker.is_alive():
                worker.kill()

        except (BrokenPipeError, EOFError, OSError):
            # Queue or process is already gone; nothing left to clean up for this agent.
            pass
        except (RuntimeError, ValueError) as e:
            logging.getLogger(__name__).debug(f"Error during agent cleanup: {e}")
def signal_handler(_signum: int, _frame: Any) -> None:
    """Handle a termination signal: stop all agent workers, then exit with status 0.

    SIGPIPE is set to be ignored first on platforms that define it, before the
    workers' queues are touched during cleanup.
    """
    sigpipe = getattr(signal, "SIGPIPE", None)
    if sigpipe is not None:
        signal.signal(sigpipe, signal.SIG_IGN)
    cleanup_all_agents()
    sys.exit(0)
if hasattr(signal, "SIGPIPE"):
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
if __name__ == "__main__":
try:
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
finally:
cleanup_all_agents()