Open-source release for Alpha version

This commit is contained in:
Ahmed Allam
2025-08-08 20:36:44 -07:00
commit 81ac98e8b9
105 changed files with 22125 additions and 0 deletions

19
strix/runtime/__init__.py Normal file
View File

@@ -0,0 +1,19 @@
import os
from .runtime import AbstractRuntime
def get_runtime() -> AbstractRuntime:
    """Return the runtime implementation selected by STRIX_RUNTIME_BACKEND.

    Defaults to the Docker backend; raises ValueError for anything else.
    """
    backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker")
    if backend != "docker":
        raise ValueError(
            f"Unsupported runtime backend: {backend}. Only 'docker' is supported for now."
        )
    # Imported lazily so merely importing this package does not require docker.
    from .docker_runtime import DockerRuntime

    return DockerRuntime()


__all__ = ["AbstractRuntime", "get_runtime"]

View File

@@ -0,0 +1,271 @@
import logging
import os
import secrets
import socket
import time
from pathlib import Path
from typing import cast
import docker
from docker.errors import DockerException, NotFound
from docker.models.containers import Container
from .runtime import AbstractRuntime, SandboxInfo
# Container labels used to look sandboxes up by owning agent / scan.
STRIX_AGENT_LABEL = "StrixAgent_ID"
STRIX_SCAN_LABEL = "StrixScan_ID"
# Sandbox image; overridable via env for local or pinned builds.
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.4")
logger = logging.getLogger(__name__)
# Per-scan workspace volumes already seeded with local source files.
# Module-level so seeding happens at most once per process, not per sandbox.
_initialized_volumes: set[str] = set()
class DockerRuntime(AbstractRuntime):
    """Sandbox runtime backed by the local Docker daemon.

    Each agent gets one labeled container; all agents of the same scan share
    a Docker volume mounted at /shared_workspace inside their containers.
    """

    def __init__(self) -> None:
        """Connect to the Docker daemon; raise RuntimeError if unreachable."""
        try:
            self.client = docker.from_env()
        except DockerException as e:
            logger.exception("Failed to connect to Docker daemon")
            raise RuntimeError("Docker is not available or not configured correctly.") from e

    def _generate_sandbox_token(self) -> str:
        # Cryptographically strong bearer token for the in-sandbox tool server.
        return secrets.token_urlsafe(32)

    def _get_scan_id(self, agent_id: str) -> str:
        """Resolve the scan ID from the CLI tracer, with a deterministic fallback.

        The fallback derives a stable ID from the agent ID prefix so sandboxes
        created without a tracer still group into a workspace.
        """
        try:
            from strix.cli.tracer import get_global_tracer

            tracer = get_global_tracer()
            if tracer and tracer.scan_config:
                return str(tracer.scan_config.get("scan_id", "default-scan"))
        except ImportError:
            logger.debug("Failed to import tracer, using fallback scan ID")
        except AttributeError:
            logger.debug("Tracer missing scan_config, using fallback scan ID")
        return f"scan-{agent_id.split('-')[0]}"

    def _find_available_port(self) -> int:
        # Ask the OS for an ephemeral free port. NOTE(review): inherent TOCTOU
        # race between closing this socket and Docker binding the port; a lost
        # race surfaces later as a DockerException from containers.run().
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return cast("int", s.getsockname()[1])

    def _get_workspace_volume_name(self, scan_id: str) -> str:
        # One shared workspace volume per scan.
        return f"strix-workspace-{scan_id}"

    def _get_sandbox_by_agent_id(self, agent_id: str) -> Container | None:
        """Return the running container labeled with *agent_id*, or None."""
        try:
            containers = self.client.containers.list(
                filters={"label": f"{STRIX_AGENT_LABEL}={agent_id}"}
            )
            if not containers:
                return None
            if len(containers) > 1:
                logger.warning(
                    "Multiple sandboxes found for agent ID %s, using the first one.", agent_id
                )
            return cast("Container", containers[0])
        except DockerException as e:
            # Best-effort lookup: a daemon hiccup is treated as "not found".
            logger.warning("Failed to get sandbox by agent ID %s: %s", agent_id, e)
            return None

    def _ensure_workspace_volume(self, volume_name: str) -> None:
        """Create the named workspace volume if it does not already exist."""
        try:
            self.client.volumes.get(volume_name)
            logger.info("Using existing workspace volume: %s", volume_name)
        except NotFound:
            self.client.volumes.create(name=volume_name, driver="local")
            logger.info("Created new workspace volume: %s", volume_name)

    def _copy_local_directory_to_container(self, container: Container, local_path: str) -> None:
        """Tar *local_path* and extract it into /shared_workspace in the container.

        Ownership/permissions are then fixed so the sandbox user can write.
        Best-effort: all errors are logged, never raised.
        """
        import tarfile
        from io import BytesIO

        try:
            local_path_obj = Path(local_path).resolve()
            if not local_path_obj.exists() or not local_path_obj.is_dir():
                logger.warning(
                    "Local path does not exist or is not a directory: %s", local_path_obj
                )
                return
            logger.info(
                "Copying local directory %s to container %s", local_path_obj, container.id
            )
            tar_buffer = BytesIO()
            with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
                for item in local_path_obj.rglob("*"):
                    if item.is_file():
                        # Preserve paths relative to the source root.
                        tar.add(item, arcname=item.relative_to(local_path_obj))
            tar_buffer.seek(0)
            container.put_archive("/shared_workspace", tar_buffer.getvalue())
            # NOTE(review): exec_run does not invoke a shell by default —
            # confirm the "&&" chain actually executes both commands.
            container.exec_run(
                "chown -R pentester:pentester /shared_workspace && chmod -R 755 /shared_workspace",
                user="root",
            )
            logger.info(
                "Successfully copied %s to /shared_workspace in container %s",
                local_path_obj,
                container.id,
            )
        except (OSError, DockerException):
            logger.exception("Failed to copy local directory to container")

    async def create_sandbox(
        self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None
    ) -> SandboxInfo:
        """Create or reuse the sandbox container for *agent_id*.

        Reuses a running labeled container when present; otherwise creates one
        attached to the scan's shared workspace volume. Returns connection
        details (container ID, tool-server URL, auth token, port).

        Raises RuntimeError if the container cannot be created or is in an
        inconsistent state (missing ID or port environment variable).
        """
        import asyncio  # local import: only needed for the startup wait below

        sandbox = self._get_sandbox_by_agent_id(agent_id)
        auth_token = existing_token or self._generate_sandbox_token()
        scan_id = self._get_scan_id(agent_id)
        volume_name = self._get_workspace_volume_name(scan_id)
        self._ensure_workspace_volume(volume_name)
        if not sandbox:
            logger.info("Creating new Docker sandbox for agent %s", agent_id)
            try:
                tool_server_port = self._find_available_port()
                caido_port = self._find_available_port()
                volumes_config = {volume_name: {"bind": "/shared_workspace", "mode": "rw"}}
                container_name = f"strix-{agent_id}"
                sandbox = self.client.containers.run(
                    STRIX_IMAGE,
                    command="sleep infinity",
                    detach=True,
                    name=container_name,
                    hostname=container_name,
                    ports={
                        f"{tool_server_port}/tcp": tool_server_port,
                        f"{caido_port}/tcp": caido_port,
                    },
                    # Required by in-sandbox network interception tooling.
                    cap_add=["NET_ADMIN", "NET_RAW"],
                    labels={
                        STRIX_AGENT_LABEL: agent_id,
                        STRIX_SCAN_LABEL: scan_id,
                    },
                    environment={
                        "PYTHONUNBUFFERED": "1",
                        "STRIX_AGENT_ID": agent_id,
                        "STRIX_SANDBOX_TOKEN": auth_token,
                        "STRIX_TOOL_SERVER_PORT": str(tool_server_port),
                        "CAIDO_PORT": str(caido_port),
                    },
                    volumes=volumes_config,
                    tty=True,
                )
                logger.info(
                    "Created new sandbox %s for agent %s with shared workspace %s",
                    sandbox.id,
                    agent_id,
                    volume_name,
                )
            except DockerException as e:
                raise RuntimeError(f"Failed to create Docker sandbox: {e}") from e
        # Explicit check instead of `assert` (asserts are stripped under -O).
        if sandbox is None:
            raise RuntimeError("Docker sandbox is unexpectedly missing after creation")
        if sandbox.status != "running":
            sandbox.start()
            # Yield to the event loop while the container's services boot;
            # time.sleep() here would stall every other coroutine for 15s.
            await asyncio.sleep(15)
        if local_source_path and volume_name not in _initialized_volumes:
            self._copy_local_directory_to_container(sandbox, local_source_path)
            _initialized_volumes.add(volume_name)
        sandbox_id = sandbox.id
        if sandbox_id is None:
            raise RuntimeError("Docker container ID is unexpectedly None")
        # Recover the tool-server port from the container's environment so a
        # reused container reports the port it was actually created with.
        # (Previously a next(..., -1) fallback would silently parse whatever
        # happened to be the last env entry when the variable was missing.)
        env_entries = sandbox.attrs["Config"]["Env"]
        port_entry = next(
            (entry for entry in env_entries if entry.startswith("STRIX_TOOL_SERVER_PORT=")),
            None,
        )
        if port_entry is None:
            raise RuntimeError(
                f"Sandbox {sandbox_id} has no STRIX_TOOL_SERVER_PORT environment variable"
            )
        tool_server_port = int(port_entry.split("=", 1)[1])
        api_url = await self.get_sandbox_url(sandbox_id, tool_server_port)
        return {
            "workspace_id": sandbox_id,
            "api_url": api_url,
            "auth_token": auth_token,
            "tool_server_port": tool_server_port,
        }

    async def get_sandbox_url(self, sandbox_id: str, port: int) -> str:
        """Return the base URL for reaching *port* of the sandbox from the host.

        Raises ValueError if the sandbox does not exist, RuntimeError on other
        Docker errors.
        """
        try:
            # Fetch + reload to validate the container actually exists.
            container = self.client.containers.get(sandbox_id)
            container.reload()
            host = "localhost"
            if "DOCKER_HOST" in os.environ:
                # Remote daemon: target the daemon host rather than localhost.
                docker_host = os.environ["DOCKER_HOST"]
                if "://" in docker_host:
                    host = docker_host.split("://")[1].split(":")[0]
        except NotFound:
            raise ValueError(f"Sandbox {sandbox_id} not found.") from None
        except DockerException as e:
            raise RuntimeError(f"Failed to get sandbox URL for {sandbox_id}: {e}") from e
        else:
            return f"http://{host}:{port}"

    async def destroy_sandbox(self, sandbox_id: str) -> None:
        """Stop and remove the sandbox container, then reap its scan workspace
        volume if no other containers of that scan remain. Best-effort."""
        logger.info("Destroying Docker sandbox %s", sandbox_id)
        try:
            container = self.client.containers.get(sandbox_id)
            # Capture the scan ID before removal so we can clean the volume.
            scan_id = None
            if container.labels and STRIX_SCAN_LABEL in container.labels:
                scan_id = container.labels[STRIX_SCAN_LABEL]
            container.stop()
            container.remove()
            logger.info("Successfully destroyed sandbox %s", sandbox_id)
            if scan_id:
                await self._cleanup_workspace_if_empty(scan_id)
        except NotFound:
            logger.warning("Sandbox %s not found for destruction.", sandbox_id)
        except DockerException as e:
            logger.warning("Failed to destroy sandbox %s: %s", sandbox_id, e)

    async def _cleanup_workspace_if_empty(self, scan_id: str) -> None:
        """Remove the scan's workspace volume once no containers reference it."""
        try:
            volume_name = self._get_workspace_volume_name(scan_id)
            # all=True: stopped containers still hold a reference to the volume.
            containers = self.client.containers.list(
                all=True, filters={"label": f"{STRIX_SCAN_LABEL}={scan_id}"}
            )
            if not containers:
                try:
                    volume = self.client.volumes.get(volume_name)
                    volume.remove()
                    logger.info(
                        "Cleaned up workspace volume %s for completed scan %s",
                        volume_name,
                        scan_id,
                    )
                    _initialized_volumes.discard(volume_name)
                except NotFound:
                    logger.debug("Volume %s already removed", volume_name)
                except DockerException as e:
                    logger.warning("Failed to remove volume %s: %s", volume_name, e)
        except DockerException as e:
            logger.warning("Error during workspace cleanup for scan %s: %s", scan_id, e)

    async def cleanup_scan_workspace(self, scan_id: str) -> None:
        """Public entry point: release the scan workspace if no sandboxes remain."""
        await self._cleanup_workspace_if_empty(scan_id)

25
strix/runtime/runtime.py Normal file
View File

@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from typing import TypedDict
class SandboxInfo(TypedDict):
    """Connection details returned when a sandbox is provisioned."""

    # Identifier of the provisioned workspace (container ID for Docker).
    workspace_id: str
    # Base URL of the in-sandbox tool server, reachable from the host.
    api_url: str
    # Bearer token for the tool server; None when auth is not configured.
    auth_token: str | None
    # Host-side port the tool server listens on.
    tool_server_port: int
class AbstractRuntime(ABC):
    """Interface for sandbox runtime backends.

    Implementations provision isolated sandboxes for agents and expose each
    sandbox's tool server to the host.
    """

    @abstractmethod
    async def create_sandbox(
        self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None
    ) -> SandboxInfo:
        """Create (or reuse) a sandbox for *agent_id* and return its connection info."""
        raise NotImplementedError

    @abstractmethod
    async def get_sandbox_url(self, sandbox_id: str, port: int) -> str:
        """Return the base URL for reaching *port* inside sandbox *sandbox_id*."""
        raise NotImplementedError

    @abstractmethod
    async def destroy_sandbox(self, sandbox_id: str) -> None:
        """Tear down the sandbox identified by *sandbox_id*."""
        raise NotImplementedError

    async def cleanup_scan_workspace(self, scan_id: str) -> None:
        """Release any per-scan shared resources. Default: no-op.

        Deliberately non-abstract so existing backends stay valid; backends
        with shared scan state (e.g. a shared workspace volume) override this.
        """
        return None

View File

@@ -0,0 +1,97 @@
import logging
import os
from typing import Any
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, ValidationError
# Refuse to start outside a sandbox: this server executes arbitrary registered
# tools and must never run directly on the host.
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not SANDBOX_MODE:
    raise RuntimeError("Tool server should only run in sandbox mode (STRIX_SANDBOX_MODE=true)")
# Shared secret injected by the runtime when the container is created; all
# /execute requests must present it as a Bearer token.
EXPECTED_TOKEN = os.getenv("STRIX_SANDBOX_TOKEN")
if not EXPECTED_TOKEN:
    raise RuntimeError("STRIX_SANDBOX_TOKEN environment variable is required in sandbox mode")
app = FastAPI()
logger = logging.getLogger(__name__)
security = HTTPBearer()
# Single module-level Depends instance, reused as a default parameter value.
security_dependency = Depends(security)
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
    """Validate the request's Bearer token against the sandbox shared secret.

    Returns the raw token on success; raises HTTP 401 on a missing/invalid
    scheme or a token mismatch.
    """
    import secrets  # local import, used only for constant-time comparison

    if not credentials or credentials.scheme != "Bearer":
        logger.warning("Authentication failed: Invalid or missing Bearer token scheme")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Bearer token required.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # Constant-time comparison so an attacker cannot recover the token via a
    # timing side channel (plain `!=` short-circuits on the first differing
    # byte). Encoding first avoids compare_digest's TypeError on non-ASCII str.
    if not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), EXPECTED_TOKEN.encode("utf-8")
    ):
        logger.warning("Authentication failed: Invalid token provided from remote host")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    logger.debug("Authentication successful for tool execution request")
    return credentials.credentials
class ToolExecutionRequest(BaseModel):
    """Body of POST /execute: a registry tool name plus its raw keyword args."""

    # Name looked up via strix.tools.registry.get_tool_by_name.
    tool_name: str
    # Raw arguments; converted to the tool's parameter types before the call.
    kwargs: dict[str, Any]
class ToolExecutionResponse(BaseModel):
    """Response envelope: `result` is set on success, `error` on failure."""

    result: Any | None = None
    error: str | None = None
@app.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
    request: ToolExecutionRequest, credentials: HTTPAuthorizationCredentials = security_dependency
) -> ToolExecutionResponse:
    """Authenticate the caller, then look up and run the requested tool.

    Execution failures are reported in the response body (the `error` field)
    rather than as HTTP error statuses; only auth failures raise.
    """
    verify_token(credentials)

    # Imported lazily so the server process can start before the tool
    # registry (and everything it pulls in) is loaded.
    from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
    from strix.tools.registry import get_tool_by_name

    try:
        tool = get_tool_by_name(request.tool_name)
        if not tool:
            return ToolExecutionResponse(error=f"Tool '{request.tool_name}' not found")
        prepared_kwargs = convert_arguments(tool, request.kwargs)
        return ToolExecutionResponse(result=tool(**prepared_kwargs))
    except (ArgumentConversionError, ValidationError) as exc:
        logger.warning("Invalid tool arguments: %s", exc)
        return ToolExecutionResponse(error=f"Invalid arguments: {exc}")
    except TypeError as exc:
        logger.warning("Tool execution type error: %s", exc)
        return ToolExecutionResponse(error=f"Tool execution error: {exc}")
    except ValueError as exc:
        logger.warning("Tool execution value error: %s", exc)
        return ToolExecutionResponse(error=f"Tool execution error: {exc}")
    except Exception:
        # Boundary handler: never leak internals to the caller.
        logger.exception("Unexpected error during tool execution")
        return ToolExecutionResponse(error="Internal server error")
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Unauthenticated liveness probe reporting sandbox/auth configuration."""
    info: dict[str, str] = {"status": "healthy"}
    info["sandbox_mode"] = str(SANDBOX_MODE)
    info["environment"] = "sandbox" if SANDBOX_MODE else "main"
    info["auth_configured"] = "true" if EXPECTED_TOKEN else "false"
    return info