refactor: simplify container initialization and fix startup reliability

- Move tool server startup from Python to entrypoint script
- Hardcode Caido port (48080) in entrypoint, remove from Python
- Use /app/venv/bin/python directly instead of poetry run
- Fix env var passing through sudo with sudo -E and explicit vars
- Add Caido process monitoring and logging during startup
- Add retry logic with exponential backoff for token fetch
- Add tool server process validation before declaring ready
- Simplify docker_runtime.py (489 -> 310 lines)
- DRY up container state recovery into _recover_container_state()
- Add container creation retry logic (3 attempts)
- Fix GraphQL health check URL (/graphql/ with trailing slash)
This commit is contained in:
0xallam
2026-01-16 03:40:09 -08:00
committed by Ahmed Allam
parent c433d4ffb2
commit 61dea7010a
3 changed files with 187 additions and 316 deletions

View File

@@ -1,8 +1,11 @@
#!/bin/bash #!/bin/bash
set -e set -e
if [ -z "$CAIDO_PORT" ]; then CAIDO_PORT=48080
echo "Error: CAIDO_PORT must be set." CAIDO_LOG="/tmp/caido_startup.log"
if [ ! -f /app/certs/ca.p12 ]; then
echo "ERROR: CA certificate file /app/certs/ca.p12 not found."
exit 1 exit 1
fi fi
@@ -11,28 +14,62 @@ caido-cli --listen 127.0.0.1:${CAIDO_PORT} \
--no-logging \ --no-logging \
--no-open \ --no-open \
--import-ca-cert /app/certs/ca.p12 \ --import-ca-cert /app/certs/ca.p12 \
--import-ca-cert-pass "" > /dev/null 2>&1 & --import-ca-cert-pass "" > "$CAIDO_LOG" 2>&1 &
CAIDO_PID=$!
echo "Started Caido with PID $CAIDO_PID on port $CAIDO_PORT"
echo "Waiting for Caido API to be ready..." echo "Waiting for Caido API to be ready..."
CAIDO_READY=false
for i in {1..30}; do for i in {1..30}; do
if curl -s -o /dev/null http://localhost:${CAIDO_PORT}/graphql; then if ! kill -0 $CAIDO_PID 2>/dev/null; then
echo "Caido API is ready." echo "ERROR: Caido process died while waiting for API (iteration $i)."
echo "=== Caido log ==="
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
exit 1
fi
if curl -s -o /dev/null -w "%{http_code}" http://localhost:${CAIDO_PORT}/graphql/ | grep -qE "^(200|400)$"; then
echo "Caido API is ready (attempt $i)."
CAIDO_READY=true
break break
fi fi
sleep 1 sleep 1
done done
if [ "$CAIDO_READY" = false ]; then
echo "ERROR: Caido API did not become ready within 30 seconds."
echo "Caido process status: $(kill -0 $CAIDO_PID 2>&1 && echo 'running' || echo 'dead')"
echo "=== Caido log ==="
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
exit 1
fi
sleep 2 sleep 2
echo "Fetching API token..." echo "Fetching API token..."
TOKEN=$(curl -s -X POST \ TOKEN=""
for attempt in 1 2 3 4 5; do
RESPONSE=$(curl -sL -X POST \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{"query":"mutation LoginAsGuest { loginAsGuest { token { accessToken } } }"}' \ -d '{"query":"mutation LoginAsGuest { loginAsGuest { token { accessToken } } }"}' \
http://localhost:${CAIDO_PORT}/graphql | jq -r '.data.loginAsGuest.token.accessToken') http://localhost:${CAIDO_PORT}/graphql)
TOKEN=$(echo "$RESPONSE" | jq -r '.data.loginAsGuest.token.accessToken // empty')
if [ -n "$TOKEN" ] && [ "$TOKEN" != "null" ]; then
echo "Successfully obtained API token (attempt $attempt)."
break
fi
echo "Token fetch attempt $attempt failed: $RESPONSE"
sleep $((attempt * 2))
done
if [ -z "$TOKEN" ] || [ "$TOKEN" == "null" ]; then if [ -z "$TOKEN" ] || [ "$TOKEN" == "null" ]; then
echo "Failed to get API token from Caido." echo "ERROR: Failed to get API token from Caido after 5 attempts."
curl -s -X POST -H "Content-Type: application/json" -d '{"query":"mutation { loginAsGuest { token { accessToken } } }"}' http://localhost:${CAIDO_PORT}/graphql echo "=== Caido log ==="
cat "$CAIDO_LOG" 2>/dev/null || echo "(no log available)"
exit 1 exit 1
fi fi
@@ -40,7 +77,7 @@ export CAIDO_API_TOKEN=$TOKEN
echo "Caido API token has been set." echo "Caido API token has been set."
echo "Creating a new Caido project..." echo "Creating a new Caido project..."
CREATE_PROJECT_RESPONSE=$(curl -s -X POST \ CREATE_PROJECT_RESPONSE=$(curl -sL -X POST \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-H "Authorization: Bearer $TOKEN" \ -H "Authorization: Bearer $TOKEN" \
-d '{"query":"mutation CreateProject { createProject(input: {name: \"sandbox\", temporary: true}) { project { id } } }"}' \ -d '{"query":"mutation CreateProject { createProject(input: {name: \"sandbox\", temporary: true}) { project { id } } }"}' \
@@ -57,7 +94,7 @@ fi
echo "Caido project created with ID: $PROJECT_ID" echo "Caido project created with ID: $PROJECT_ID"
echo "Selecting Caido project..." echo "Selecting Caido project..."
SELECT_RESPONSE=$(curl -s -X POST \ SELECT_RESPONSE=$(curl -sL -X POST \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-H "Authorization: Bearer $TOKEN" \ -H "Authorization: Bearer $TOKEN" \
-d '{"query":"mutation SelectProject { selectProject(id: \"'$PROJECT_ID'\") { currentProject { project { id } } } }"}' \ -d '{"query":"mutation SelectProject { selectProject(id: \"'$PROJECT_ID'\") { currentProject { project { id } } } }"}' \
@@ -114,9 +151,33 @@ sudo -u pentester certutil -N -d sql:/home/pentester/.pki/nssdb --empty-password
sudo -u pentester certutil -A -n "Testing Root CA" -t "C,," -i /app/certs/ca.crt -d sql:/home/pentester/.pki/nssdb sudo -u pentester certutil -A -n "Testing Root CA" -t "C,," -i /app/certs/ca.crt -d sql:/home/pentester/.pki/nssdb
echo "✅ CA added to browser trust store" echo "✅ CA added to browser trust store"
echo "Container initialization complete - agents will start their own tool servers as needed" echo "Starting tool server..."
echo "✅ Shared container ready for multi-agent use" cd /app
TOOL_SERVER_TIMEOUT="${STRIX_SANDBOX_EXECUTION_TIMEOUT:-120}"
TOOL_SERVER_LOG="/tmp/tool_server.log"
sudo -E -u pentester \
PYTHONPATH=/app \
STRIX_SANDBOX_MODE=true \
TOOL_SERVER_TOKEN="$TOOL_SERVER_TOKEN" \
TOOL_SERVER_PORT="$TOOL_SERVER_PORT" \
TOOL_SERVER_TIMEOUT="$TOOL_SERVER_TIMEOUT" \
/app/venv/bin/python strix/runtime/tool_server.py \
--token="$TOOL_SERVER_TOKEN" \
--host=0.0.0.0 \
--port="$TOOL_SERVER_PORT" \
--timeout="$TOOL_SERVER_TIMEOUT" > "$TOOL_SERVER_LOG" 2>&1 &
sleep 3
if ! pgrep -f "tool_server.py" > /dev/null; then
echo "ERROR: Tool server process failed to start"
echo "=== Tool server log ==="
cat "$TOOL_SERVER_LOG" 2>/dev/null || echo "(no log)"
exit 1
fi
echo "✅ Tool server started on port $TOOL_SERVER_PORT"
echo "✅ Container ready"
cd /workspace cd /workspace
exec "$@" exec "$@"

View File

@@ -1,15 +1,13 @@
import contextlib import contextlib
import logging
import os import os
import secrets import secrets
import socket import socket
import time import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from pathlib import Path from pathlib import Path
from typing import Any, cast from typing import cast
import docker import docker
import httpx
from docker.errors import DockerException, ImageNotFound, NotFound from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container from docker.models.containers import Container
from requests.exceptions import ConnectionError as RequestsConnectionError from requests.exceptions import ConnectionError as RequestsConnectionError
@@ -22,10 +20,8 @@ from .runtime import AbstractRuntime, SandboxInfo
HOST_GATEWAY_HOSTNAME = "host.docker.internal" HOST_GATEWAY_HOSTNAME = "host.docker.internal"
DOCKER_TIMEOUT = 60 # seconds DOCKER_TIMEOUT = 60
TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request CONTAINER_TOOL_SERVER_PORT = 48081
TOOL_SERVER_HEALTH_RETRIES = 10 # number of retries for health check
logger = logging.getLogger(__name__)
class DockerRuntime(AbstractRuntime): class DockerRuntime(AbstractRuntime):
@@ -33,50 +29,20 @@ class DockerRuntime(AbstractRuntime):
try: try:
self.client = docker.from_env(timeout=DOCKER_TIMEOUT) self.client = docker.from_env(timeout=DOCKER_TIMEOUT)
except (DockerException, RequestsConnectionError, RequestsTimeout) as e: except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.exception("Failed to connect to Docker daemon")
if isinstance(e, RequestsConnectionError | RequestsTimeout):
raise SandboxInitializationError(
"Docker daemon unresponsive",
f"Connection timed out after {DOCKER_TIMEOUT} seconds. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from e
raise SandboxInitializationError( raise SandboxInitializationError(
"Docker is not available", "Docker is not available",
"Docker is not available or not configured correctly. " "Please ensure Docker Desktop is installed and running.",
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from e ) from e
self._scan_container: Container | None = None self._scan_container: Container | None = None
self._tool_server_port: int | None = None self._tool_server_port: int | None = None
self._tool_server_token: str | None = None self._tool_server_token: str | None = None
def _generate_sandbox_token(self) -> str:
return secrets.token_urlsafe(32)
def _find_available_port(self) -> int: def _find_available_port(self) -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) s.bind(("", 0))
return cast("int", s.getsockname()[1]) return cast("int", s.getsockname()[1])
def _exec_run_with_timeout(
self, container: Container, cmd: str, timeout: int = DOCKER_TIMEOUT, **kwargs: Any
) -> Any:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(container.exec_run, cmd, **kwargs)
try:
return future.result(timeout=timeout)
except FuturesTimeoutError:
logger.exception(f"exec_run timed out after {timeout}s: {cmd[:100]}...")
raise SandboxInitializationError(
"Container command timed out",
f"Command timed out after {timeout} seconds. "
"Docker may be overloaded or unresponsive. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from None
def _get_scan_id(self, agent_id: str) -> str: def _get_scan_id(self, agent_id: str) -> str:
try: try:
from strix.telemetry.tracer import get_global_tracer from strix.telemetry.tracer import get_global_tracer
@@ -84,129 +50,116 @@ class DockerRuntime(AbstractRuntime):
tracer = get_global_tracer() tracer = get_global_tracer()
if tracer and tracer.scan_config: if tracer and tracer.scan_config:
return str(tracer.scan_config.get("scan_id", "default-scan")) return str(tracer.scan_config.get("scan_id", "default-scan"))
except ImportError: except (ImportError, AttributeError):
logger.debug("Failed to import tracer, using fallback scan ID") pass
except AttributeError:
logger.debug("Tracer missing scan_config, using fallback scan ID")
return f"scan-{agent_id.split('-')[0]}" return f"scan-{agent_id.split('-')[0]}"
def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None: def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None:
def _validate_image(image: docker.models.images.Image) -> None:
if not image.id or not image.attrs:
raise ImageNotFound(f"Image {image_name} metadata incomplete")
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
image = self.client.images.get(image_name) image = self.client.images.get(image_name)
_validate_image(image) if not image.id or not image.attrs:
except ImageNotFound: raise ImageNotFound(f"Image {image_name} metadata incomplete") # noqa: TRY301
except (ImageNotFound, DockerException):
if attempt == max_retries - 1: if attempt == max_retries - 1:
logger.exception(f"Image {image_name} not found after {max_retries} attempts")
raise raise
logger.warning(f"Image {image_name} not ready, attempt {attempt + 1}/{max_retries}")
time.sleep(2**attempt)
except DockerException:
if attempt == max_retries - 1:
logger.exception(f"Failed to verify image {image_name}")
raise
logger.warning(f"Docker error verifying image, attempt {attempt + 1}/{max_retries}")
time.sleep(2**attempt) time.sleep(2**attempt)
else: else:
logger.debug(f"Image {image_name} verified as available")
return return
def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Container: def _recover_container_state(self, container: Container) -> None:
last_exception = None for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=", 1)[1]
break
port_bindings = container.attrs.get("NetworkSettings", {}).get("Ports", {})
port_key = f"{CONTAINER_TOOL_SERVER_PORT}/tcp"
if port_bindings.get(port_key):
self._tool_server_port = int(port_bindings[port_key][0]["HostPort"])
def _wait_for_tool_server(self, max_retries: int = 20, timeout: int = 5) -> None:
host = self._resolve_docker_host()
health_url = f"http://{host}:{self._tool_server_port}/health"
for attempt in range(max_retries):
try:
with httpx.Client(trust_env=False, timeout=timeout) as client:
response = client.get(health_url)
if response.status_code == 200:
data = response.json()
if data.get("status") == "healthy":
return
except (httpx.ConnectError, httpx.TimeoutException, httpx.RequestError):
pass
time.sleep(min(2**attempt * 0.5, 5))
raise SandboxInitializationError(
"Tool server failed to start",
"Container initialization timed out. Please try again.",
)
def _create_container(self, scan_id: str, max_retries: int = 2) -> Container:
container_name = f"strix-scan-{scan_id}" container_name = f"strix-scan-{scan_id}"
image_name = Config.get("strix_image") image_name = Config.get("strix_image")
if not image_name: if not image_name:
raise ValueError("STRIX_IMAGE must be configured") raise ValueError("STRIX_IMAGE must be configured")
for attempt in range(max_retries):
try:
self._verify_image_available(image_name) self._verify_image_available(image_name)
last_error: Exception | None = None
for attempt in range(max_retries + 1):
try: try:
existing_container = self.client.containers.get(container_name) with contextlib.suppress(NotFound):
logger.warning(f"Container {container_name} already exists, removing it") existing = self.client.containers.get(container_name)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
existing_container.stop(timeout=5) existing.stop(timeout=5)
existing_container.remove(force=True) existing.remove(force=True)
time.sleep(1) time.sleep(1)
except NotFound:
pass
except DockerException as e:
logger.warning(f"Error checking/removing existing container: {e}")
caido_port = self._find_available_port() self._tool_server_port = self._find_available_port()
tool_server_port = self._find_available_port() self._tool_server_token = secrets.token_urlsafe(32)
tool_server_token = self._generate_sandbox_token() execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
self._tool_server_port = tool_server_port
self._tool_server_token = tool_server_token
container = self.client.containers.run( container = self.client.containers.run(
image_name, image_name,
command="sleep infinity", command="sleep infinity",
detach=True, detach=True,
name=container_name, name=container_name,
hostname=f"strix-scan-{scan_id}", hostname=container_name,
ports={ ports={f"{CONTAINER_TOOL_SERVER_PORT}/tcp": self._tool_server_port},
f"{caido_port}/tcp": caido_port,
f"{tool_server_port}/tcp": tool_server_port,
},
cap_add=["NET_ADMIN", "NET_RAW"], cap_add=["NET_ADMIN", "NET_RAW"],
labels={"strix-scan-id": scan_id}, labels={"strix-scan-id": scan_id},
environment={ environment={
"PYTHONUNBUFFERED": "1", "PYTHONUNBUFFERED": "1",
"CAIDO_PORT": str(caido_port), "TOOL_SERVER_PORT": str(CONTAINER_TOOL_SERVER_PORT),
"TOOL_SERVER_PORT": str(tool_server_port), "TOOL_SERVER_TOKEN": self._tool_server_token,
"TOOL_SERVER_TOKEN": tool_server_token, "STRIX_SANDBOX_EXECUTION_TIMEOUT": str(execution_timeout),
"HOST_GATEWAY": HOST_GATEWAY_HOSTNAME, "HOST_GATEWAY": HOST_GATEWAY_HOSTNAME,
}, },
extra_hosts=self._get_extra_hosts(), extra_hosts={HOST_GATEWAY_HOSTNAME: "host-gateway"},
tty=True, tty=True,
) )
self._scan_container = container self._scan_container = container
logger.info("Created container %s for scan %s", container.id, scan_id) self._wait_for_tool_server()
self._initialize_container(
container, caido_port, tool_server_port, tool_server_token
)
except (DockerException, RequestsConnectionError, RequestsTimeout) as e: except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
last_exception = e last_error = e
if attempt == max_retries - 1: if attempt < max_retries:
logger.exception(f"Failed to create container after {max_retries} attempts")
break
logger.warning(f"Container creation attempt {attempt + 1}/{max_retries} failed")
self._tool_server_port = None self._tool_server_port = None
self._tool_server_token = None self._tool_server_token = None
time.sleep(2**attempt)
sleep_time = (2**attempt) + (0.1 * attempt)
time.sleep(sleep_time)
else: else:
return container return container
if isinstance(last_exception, RequestsConnectionError | RequestsTimeout):
raise SandboxInitializationError( raise SandboxInitializationError(
"Failed to create sandbox container", "Failed to create container",
f"Docker daemon unresponsive after {max_retries} attempts " f"Container creation failed after {max_retries + 1} attempts: {last_error}",
f"(timed out after {DOCKER_TIMEOUT}s). " ) from last_error
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from last_exception
raise SandboxInitializationError(
"Failed to create sandbox container",
f"Container creation failed after {max_retries} attempts: {last_exception}. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from last_exception
def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912 def _get_or_create_container(self, scan_id: str) -> Container:
container_name = f"strix-scan-{scan_id}" container_name = f"strix-scan-{scan_id}"
if self._scan_container: if self._scan_container:
@@ -223,33 +176,14 @@ class DockerRuntime(AbstractRuntime):
container = self.client.containers.get(container_name) container = self.client.containers.get(container_name)
container.reload() container.reload()
if (
"strix-scan-id" not in container.labels
or container.labels["strix-scan-id"] != scan_id
):
logger.warning(
f"Container {container_name} exists but missing/wrong label, updating"
)
if container.status != "running": if container.status != "running":
logger.info(f"Starting existing container {container_name}")
container.start() container.start()
time.sleep(2) time.sleep(2)
self._scan_container = container self._scan_container = container
self._recover_container_state(container)
for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_PORT="):
self._tool_server_port = int(env_var.split("=")[1])
elif env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=")[1]
logger.info(f"Reusing existing container {container_name}")
except NotFound: except NotFound:
pass pass
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.warning(f"Failed to get container by name {container_name}: {e}")
else: else:
return container return container
@@ -262,102 +196,14 @@ class DockerRuntime(AbstractRuntime):
if container.status != "running": if container.status != "running":
container.start() container.start()
time.sleep(2) time.sleep(2)
self._scan_container = container self._scan_container = container
self._recover_container_state(container)
for env_var in container.attrs["Config"]["Env"]:
if env_var.startswith("TOOL_SERVER_PORT="):
self._tool_server_port = int(env_var.split("=")[1])
elif env_var.startswith("TOOL_SERVER_TOKEN="):
self._tool_server_token = env_var.split("=")[1]
logger.info(f"Found existing container by label for scan {scan_id}")
return container return container
except (DockerException, RequestsConnectionError, RequestsTimeout) as e: except DockerException:
logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e) pass
logger.info("Creating new Docker container for scan %s", scan_id) return self._create_container(scan_id)
return self._create_container_with_retry(scan_id)
def _initialize_container(
self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
) -> None:
logger.info("Initializing Caido proxy on port %s", caido_port)
self._exec_run_with_timeout(
container,
f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
detach=False,
)
time.sleep(5)
result = self._exec_run_with_timeout(
container,
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'",
user="pentester",
)
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
container.exec_run(
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
f"--host 0.0.0.0 --port {tool_server_port} --timeout {execution_timeout} &'",
detach=True,
user="pentester",
)
time.sleep(2)
host = self._resolve_docker_host()
health_url = f"http://{host}:{tool_server_port}/health"
self._wait_for_tool_server_health(health_url)
def _wait_for_tool_server_health(
self,
health_url: str,
max_retries: int = TOOL_SERVER_HEALTH_RETRIES,
request_timeout: int = TOOL_SERVER_HEALTH_REQUEST_TIMEOUT,
) -> None:
import httpx
logger.info(f"Waiting for tool server health at {health_url}")
for attempt in range(max_retries):
try:
with httpx.Client(trust_env=False, timeout=request_timeout) as client:
response = client.get(health_url)
response.raise_for_status()
health_data = response.json()
if health_data.get("status") == "healthy":
logger.info(
f"Tool server is healthy after {attempt + 1} attempt(s): {health_data}"
)
return
logger.warning(f"Tool server returned unexpected status: {health_data}")
except httpx.ConnectError:
logger.debug(
f"Tool server not ready (attempt {attempt + 1}/{max_retries}): "
f"Connection refused"
)
except httpx.TimeoutException:
logger.debug(
f"Tool server not ready (attempt {attempt + 1}/{max_retries}): "
f"Request timed out"
)
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.debug(f"Tool server not ready (attempt {attempt + 1}/{max_retries}): {e}")
sleep_time = min(2**attempt * 0.5, 5)
time.sleep(sleep_time)
raise SandboxInitializationError(
"Tool server failed to start",
"Please ensure Docker Desktop is installed and running, and try running strix again.",
)
def _copy_local_directory_to_container( def _copy_local_directory_to_container(
self, container: Container, local_path: str, target_name: str | None = None self, container: Container, local_path: str, target_name: str | None = None
@@ -368,17 +214,8 @@ class DockerRuntime(AbstractRuntime):
try: try:
local_path_obj = Path(local_path).resolve() local_path_obj = Path(local_path).resolve()
if not local_path_obj.exists() or not local_path_obj.is_dir(): if not local_path_obj.exists() or not local_path_obj.is_dir():
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
return return
if target_name:
logger.info(
f"Copying local directory {local_path_obj} to container at "
f"/workspace/{target_name}"
)
else:
logger.info(f"Copying local directory {local_path_obj} to container")
tar_buffer = BytesIO() tar_buffer = BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w") as tar: with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
for item in local_path_obj.rglob("*"): for item in local_path_obj.rglob("*"):
@@ -389,16 +226,12 @@ class DockerRuntime(AbstractRuntime):
tar_buffer.seek(0) tar_buffer.seek(0)
container.put_archive("/workspace", tar_buffer.getvalue()) container.put_archive("/workspace", tar_buffer.getvalue())
container.exec_run( container.exec_run(
"chown -R pentester:pentester /workspace && chmod -R 755 /workspace", "chown -R pentester:pentester /workspace && chmod -R 755 /workspace",
user="root", user="root",
) )
logger.info("Successfully copied local directory to /workspace")
except (OSError, DockerException): except (OSError, DockerException):
logger.exception("Failed to copy local directory to container") pass
async def create_sandbox( async def create_sandbox(
self, self,
@@ -407,7 +240,7 @@ class DockerRuntime(AbstractRuntime):
local_sources: list[dict[str, str]] | None = None, local_sources: list[dict[str, str]] | None = None,
) -> SandboxInfo: ) -> SandboxInfo:
scan_id = self._get_scan_id(agent_id) scan_id = self._get_scan_id(agent_id)
container = self._get_or_create_scan_container(scan_id) container = self._get_or_create_container(scan_id)
source_copied_key = f"_source_copied_{scan_id}" source_copied_key = f"_source_copied_{scan_id}"
if local_sources and not hasattr(self, source_copied_key): if local_sources and not hasattr(self, source_copied_key):
@@ -415,40 +248,33 @@ class DockerRuntime(AbstractRuntime):
source_path = source.get("source_path") source_path = source.get("source_path")
if not source_path: if not source_path:
continue continue
target_name = (
target_name = source.get("workspace_subdir") source.get("workspace_subdir") or Path(source_path).name or f"target_{index}"
if not target_name: )
target_name = Path(source_path).name or f"target_{index}"
self._copy_local_directory_to_container(container, source_path, target_name) self._copy_local_directory_to_container(container, source_path, target_name)
setattr(self, source_copied_key, True) setattr(self, source_copied_key, True)
container_id = container.id if container.id is None:
if container_id is None:
raise RuntimeError("Docker container ID is unexpectedly None") raise RuntimeError("Docker container ID is unexpectedly None")
token = existing_token if existing_token is not None else self._tool_server_token token = existing_token or self._tool_server_token
if self._tool_server_port is None or token is None: if self._tool_server_port is None or token is None:
raise RuntimeError("Tool server not initialized or no token available") raise RuntimeError("Tool server not initialized")
api_url = await self.get_sandbox_url(container_id, self._tool_server_port) host = self._resolve_docker_host()
api_url = f"http://{host}:{self._tool_server_port}"
await self._register_agent_with_tool_server(api_url, agent_id, token) await self._register_agent(api_url, agent_id, token)
return { return {
"workspace_id": container_id, "workspace_id": container.id,
"api_url": api_url, "api_url": api_url,
"auth_token": token, "auth_token": token,
"tool_server_port": self._tool_server_port, "tool_server_port": self._tool_server_port,
"agent_id": agent_id, "agent_id": agent_id,
} }
async def _register_agent_with_tool_server( async def _register_agent(self, api_url: str, agent_id: str, token: str) -> None:
self, api_url: str, agent_id: str, token: str
) -> None:
import httpx
try: try:
async with httpx.AsyncClient(trust_env=False) as client: async with httpx.AsyncClient(trust_env=False) as client:
response = await client.post( response = await client.post(
@@ -458,54 +284,33 @@ class DockerRuntime(AbstractRuntime):
timeout=30, timeout=30,
) )
response.raise_for_status() response.raise_for_status()
logger.info(f"Registered agent {agent_id} with tool server") except httpx.RequestError:
except (httpx.RequestError, httpx.HTTPStatusError) as e: pass
logger.warning(f"Failed to register agent {agent_id}: {e}")
async def get_sandbox_url(self, container_id: str, port: int) -> str: async def get_sandbox_url(self, container_id: str, port: int) -> str:
try: try:
container = self.client.containers.get(container_id) self.client.containers.get(container_id)
container.reload() return f"http://{self._resolve_docker_host()}:{port}"
host = self._resolve_docker_host()
except NotFound: except NotFound:
raise ValueError(f"Container {container_id} not found.") from None raise ValueError(f"Container {container_id} not found.") from None
except DockerException as e:
raise RuntimeError(f"Failed to get container URL for {container_id}: {e}") from e
else:
return f"http://{host}:{port}"
def _resolve_docker_host(self) -> str: def _resolve_docker_host(self) -> str:
docker_host = os.getenv("DOCKER_HOST", "") docker_host = os.getenv("DOCKER_HOST", "")
if not docker_host: if docker_host:
return "127.0.0.1"
from urllib.parse import urlparse from urllib.parse import urlparse
parsed = urlparse(docker_host) parsed = urlparse(docker_host)
if parsed.scheme in ("tcp", "http", "https") and parsed.hostname: if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
return parsed.hostname return parsed.hostname
return "127.0.0.1" return "127.0.0.1"
def _get_extra_hosts(self) -> dict[str, str]:
return {HOST_GATEWAY_HOSTNAME: "host-gateway"}
async def destroy_sandbox(self, container_id: str) -> None: async def destroy_sandbox(self, container_id: str) -> None:
logger.info("Destroying scan container %s", container_id)
try: try:
container = self.client.containers.get(container_id) container = self.client.containers.get(container_id)
container.stop() container.stop()
container.remove() container.remove()
logger.info("Successfully destroyed container %s", container_id)
self._scan_container = None self._scan_container = None
self._tool_server_port = None self._tool_server_port = None
self._tool_server_token = None self._tool_server_token = None
except (NotFound, DockerException):
except NotFound: pass
logger.warning("Container %s not found for destruction.", container_id)
except DockerException as e:
logger.warning("Failed to destroy container %s: %s", container_id, e)

View File

@@ -16,12 +16,17 @@ if TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Callable
CAIDO_PORT = 48080 # Fixed port inside container
class ProxyManager: class ProxyManager:
def __init__(self, auth_token: str | None = None): def __init__(self, auth_token: str | None = None):
host = "127.0.0.1" host = "127.0.0.1"
port = os.getenv("CAIDO_PORT", "56789") self.base_url = f"http://{host}:{CAIDO_PORT}/graphql"
self.base_url = f"http://{host}:{port}/graphql" self.proxies = {
self.proxies = {"http": f"http://{host}:{port}", "https": f"http://{host}:{port}"} "http": f"http://{host}:{CAIDO_PORT}",
"https": f"http://{host}:{CAIDO_PORT}",
}
self.auth_token = auth_token or os.getenv("CAIDO_API_TOKEN") self.auth_token = auth_token or os.getenv("CAIDO_API_TOKEN")
self.transport = RequestsHTTPTransport( self.transport = RequestsHTTPTransport(
url=self.base_url, headers={"Authorization": f"Bearer {self.auth_token}"} url=self.base_url, headers={"Authorization": f"Bearer {self.auth_token}"}