Files
strix/strix/runtime/docker_runtime.py
0xallam 8413987fcd feat: remove docker container on shutdown
Add automatic cleanup of Docker containers when the application exits.
Uses a singleton runtime pattern and spawns a detached subprocess for
cleanup to ensure fast exit without blocking the UI.
2026-01-19 18:26:41 -08:00

338 lines
12 KiB
Python

import contextlib
import os
import secrets
import socket
import time
from pathlib import Path
from typing import cast
import docker
import httpx
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import Timeout as RequestsTimeout
from strix.config import Config
from . import SandboxInitializationError
from .runtime import AbstractRuntime, SandboxInfo
HOST_GATEWAY_HOSTNAME = "host.docker.internal"
DOCKER_TIMEOUT = 60
CONTAINER_TOOL_SERVER_PORT = 48081
class DockerRuntime(AbstractRuntime):
def __init__(self) -> None:
try:
self.client = docker.from_env(timeout=DOCKER_TIMEOUT)
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
raise SandboxInitializationError(
"Docker is not available",
"Please ensure Docker Desktop is installed and running.",
) from e
self._scan_container: Container | None = None
self._tool_server_port: int | None = None
self._tool_server_token: str | None = None
def _find_available_port(self) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return cast("int", s.getsockname()[1])
def _get_scan_id(self, agent_id: str) -> str:
try:
from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer()
if tracer and tracer.scan_config:
return str(tracer.scan_config.get("scan_id", "default-scan"))
except (ImportError, AttributeError):
pass
return f"scan-{agent_id.split('-')[0]}"
def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None:
for attempt in range(max_retries):
try:
image = self.client.images.get(image_name)
if not image.id or not image.attrs:
raise ImageNotFound(f"Image {image_name} metadata incomplete") # noqa: TRY301
except (ImageNotFound, DockerException):
if attempt == max_retries - 1:
raise
time.sleep(2**attempt)
else:
return
def _recover_container_state(self, container: Container) -> None:
for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=", 1)[1]
break
port_bindings = container.attrs.get("NetworkSettings", {}).get("Ports", {})
port_key = f"{CONTAINER_TOOL_SERVER_PORT}/tcp"
if port_bindings.get(port_key):
self._tool_server_port = int(port_bindings[port_key][0]["HostPort"])
def _wait_for_tool_server(self, max_retries: int = 30, timeout: int = 5) -> None:
host = self._resolve_docker_host()
health_url = f"http://{host}:{self._tool_server_port}/health"
time.sleep(5)
for attempt in range(max_retries):
try:
with httpx.Client(trust_env=False, timeout=timeout) as client:
response = client.get(health_url)
if response.status_code == 200:
data = response.json()
if data.get("status") == "healthy":
return
except (httpx.ConnectError, httpx.TimeoutException, httpx.RequestError):
pass
time.sleep(min(2**attempt * 0.5, 5))
raise SandboxInitializationError(
"Tool server failed to start",
"Container initialization timed out. Please try again.",
)
def _create_container(self, scan_id: str, max_retries: int = 2) -> Container:
container_name = f"strix-scan-{scan_id}"
image_name = Config.get("strix_image")
if not image_name:
raise ValueError("STRIX_IMAGE must be configured")
self._verify_image_available(image_name)
last_error: Exception | None = None
for attempt in range(max_retries + 1):
try:
with contextlib.suppress(NotFound):
existing = self.client.containers.get(container_name)
with contextlib.suppress(Exception):
existing.stop(timeout=5)
existing.remove(force=True)
time.sleep(1)
self._tool_server_port = self._find_available_port()
self._tool_server_token = secrets.token_urlsafe(32)
execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
container = self.client.containers.run(
image_name,
command="sleep infinity",
detach=True,
name=container_name,
hostname=container_name,
ports={f"{CONTAINER_TOOL_SERVER_PORT}/tcp": self._tool_server_port},
cap_add=["NET_ADMIN", "NET_RAW"],
labels={"strix-scan-id": scan_id},
environment={
"PYTHONUNBUFFERED": "1",
"TOOL_SERVER_PORT": str(CONTAINER_TOOL_SERVER_PORT),
"TOOL_SERVER_TOKEN": self._tool_server_token,
"STRIX_SANDBOX_EXECUTION_TIMEOUT": str(execution_timeout),
"HOST_GATEWAY": HOST_GATEWAY_HOSTNAME,
},
extra_hosts={HOST_GATEWAY_HOSTNAME: "host-gateway"},
tty=True,
)
self._scan_container = container
self._wait_for_tool_server()
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
last_error = e
if attempt < max_retries:
self._tool_server_port = None
self._tool_server_token = None
time.sleep(2**attempt)
else:
return container
raise SandboxInitializationError(
"Failed to create container",
f"Container creation failed after {max_retries + 1} attempts: {last_error}",
) from last_error
def _get_or_create_container(self, scan_id: str) -> Container:
container_name = f"strix-scan-{scan_id}"
if self._scan_container:
try:
self._scan_container.reload()
if self._scan_container.status == "running":
return self._scan_container
except NotFound:
self._scan_container = None
self._tool_server_port = None
self._tool_server_token = None
try:
container = self.client.containers.get(container_name)
container.reload()
if container.status != "running":
container.start()
time.sleep(2)
self._scan_container = container
self._recover_container_state(container)
except NotFound:
pass
else:
return container
try:
containers = self.client.containers.list(
all=True, filters={"label": f"strix-scan-id={scan_id}"}
)
if containers:
container = containers[0]
if container.status != "running":
container.start()
time.sleep(2)
self._scan_container = container
self._recover_container_state(container)
return container
except DockerException:
pass
return self._create_container(scan_id)
def _copy_local_directory_to_container(
self, container: Container, local_path: str, target_name: str | None = None
) -> None:
import tarfile
from io import BytesIO
try:
local_path_obj = Path(local_path).resolve()
if not local_path_obj.exists() or not local_path_obj.is_dir():
return
tar_buffer = BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
for item in local_path_obj.rglob("*"):
if item.is_file():
rel_path = item.relative_to(local_path_obj)
arcname = Path(target_name) / rel_path if target_name else rel_path
tar.add(item, arcname=arcname)
tar_buffer.seek(0)
container.put_archive("/workspace", tar_buffer.getvalue())
container.exec_run(
"chown -R pentester:pentester /workspace && chmod -R 755 /workspace",
user="root",
)
except (OSError, DockerException):
pass
async def create_sandbox(
self,
agent_id: str,
existing_token: str | None = None,
local_sources: list[dict[str, str]] | None = None,
) -> SandboxInfo:
scan_id = self._get_scan_id(agent_id)
container = self._get_or_create_container(scan_id)
source_copied_key = f"_source_copied_{scan_id}"
if local_sources and not hasattr(self, source_copied_key):
for index, source in enumerate(local_sources, start=1):
source_path = source.get("source_path")
if not source_path:
continue
target_name = (
source.get("workspace_subdir") or Path(source_path).name or f"target_{index}"
)
self._copy_local_directory_to_container(container, source_path, target_name)
setattr(self, source_copied_key, True)
if container.id is None:
raise RuntimeError("Docker container ID is unexpectedly None")
token = existing_token or self._tool_server_token
if self._tool_server_port is None or token is None:
raise RuntimeError("Tool server not initialized")
host = self._resolve_docker_host()
api_url = f"http://{host}:{self._tool_server_port}"
await self._register_agent(api_url, agent_id, token)
return {
"workspace_id": container.id,
"api_url": api_url,
"auth_token": token,
"tool_server_port": self._tool_server_port,
"agent_id": agent_id,
}
async def _register_agent(self, api_url: str, agent_id: str, token: str) -> None:
try:
async with httpx.AsyncClient(trust_env=False) as client:
response = await client.post(
f"{api_url}/register_agent",
params={"agent_id": agent_id},
headers={"Authorization": f"Bearer {token}"},
timeout=30,
)
response.raise_for_status()
except httpx.RequestError:
pass
async def get_sandbox_url(self, container_id: str, port: int) -> str:
try:
self.client.containers.get(container_id)
return f"http://{self._resolve_docker_host()}:{port}"
except NotFound:
raise ValueError(f"Container {container_id} not found.") from None
def _resolve_docker_host(self) -> str:
docker_host = os.getenv("DOCKER_HOST", "")
if docker_host:
from urllib.parse import urlparse
parsed = urlparse(docker_host)
if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
return parsed.hostname
return "127.0.0.1"
async def destroy_sandbox(self, container_id: str) -> None:
try:
container = self.client.containers.get(container_id)
container.stop()
container.remove()
self._scan_container = None
self._tool_server_port = None
self._tool_server_token = None
except (NotFound, DockerException):
pass
def cleanup(self) -> None:
if self._scan_container is not None:
container_name = self._scan_container.name
self._scan_container = None
self._tool_server_port = None
self._tool_server_token = None
if container_name is None:
return
import subprocess
subprocess.Popen( # noqa: S603
["docker", "rm", "-f", container_name], # noqa: S607
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
)