diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index 2f53804..5b8a836 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import logging import os @@ -22,6 +23,8 @@ from .runtime import AbstractRuntime, SandboxInfo STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10") HOST_GATEWAY_HOSTNAME = "host.docker.internal" DOCKER_TIMEOUT = 60 # seconds +TOOL_SERVER_HEALTH_TIMEOUT = 30 # seconds to wait for tool server to be healthy +TOOL_SERVER_HEALTH_RETRIES = 10 # number of retries for health check logger = logging.getLogger(__name__) @@ -300,7 +303,66 @@ class DockerRuntime(AbstractRuntime): user="pentester", ) - time.sleep(5) + time.sleep(2) + + self._wait_for_tool_server_health(tool_server_port) + + def _wait_for_tool_server_health( + self, + port: int, + max_retries: int = TOOL_SERVER_HEALTH_RETRIES, + timeout: int = TOOL_SERVER_HEALTH_TIMEOUT, + ) -> None: + import httpx + + host = self._resolve_docker_host() + health_url = f"http://{host}:{port}/health" + + logger.info(f"Waiting for tool server health at {health_url}") + + last_error: Exception | None = None + for attempt in range(max_retries): + try: + with httpx.Client(trust_env=False, timeout=timeout / max_retries) as client: + response = client.get(health_url) + response.raise_for_status() + health_data = response.json() + + if health_data.get("status") == "healthy": + logger.info( + f"Tool server is healthy after {attempt + 1} attempt(s): {health_data}" + ) + return + + logger.warning(f"Tool server returned unexpected status: {health_data}") + + except httpx.ConnectError as e: + last_error = e + logger.debug( + f"Tool server not ready (attempt {attempt + 1}/{max_retries}): " + f"Connection refused" + ) + except httpx.TimeoutException as e: + last_error = e + logger.debug( + f"Tool server not ready (attempt {attempt + 1}/{max_retries}): " + f"Request timed out" + ) + except (httpx.RequestError, httpx.HTTPStatusError) as e: + last_error = e + logger.debug(f"Tool server not ready (attempt {attempt + 1}/{max_retries}): {e}") + + sleep_time = min(2**attempt * 0.5, 5) + time.sleep(sleep_time) + + error_detail = str(last_error) if last_error else "Unknown error" + raise SandboxInitializationError( + "Tool server failed to start", + f"Could not connect to tool server at {health_url} after {max_retries} attempts. " + f"Last error: {error_detail}. " + "Please ensure Docker Desktop is installed and running, " + "and try running strix again.", + ) def _copy_local_directory_to_container( self, container: Container, local_path: str, target_name: str | None = None @@ -377,6 +439,7 @@ class DockerRuntime(AbstractRuntime): api_url = await self.get_sandbox_url(container_id, self._tool_server_port) + await self._verify_tool_server_health(api_url) await self._register_agent_with_tool_server(api_url, agent_id, token) return { @@ -387,6 +450,60 @@ class DockerRuntime(AbstractRuntime): "agent_id": agent_id, } + async def _verify_tool_server_health( + self, + api_url: str, + max_retries: int = 3, + timeout: int = 10, + ) -> None: + import httpx + + health_url = f"{api_url}/health" + last_error: Exception | None = None + + for attempt in range(max_retries): + try: + async with httpx.AsyncClient(trust_env=False, timeout=timeout) as client: + response = await client.get(health_url) + response.raise_for_status() + health_data = response.json() + + if health_data.get("status") == "healthy": + logger.debug(f"Tool server health verified: {health_data}") + return + + logger.warning(f"Tool server returned unexpected status: {health_data}") + + except httpx.ConnectError as e: + last_error = e + logger.debug( + f"Tool server health check failed (attempt {attempt + 1}/{max_retries}): " + f"Connection refused" + ) + except httpx.TimeoutException as e: + last_error = e + logger.debug( + f"Tool server health check failed (attempt {attempt + 1}/{max_retries}): " + f"Request timed out" + ) + except (httpx.RequestError, httpx.HTTPStatusError) as e: + last_error = e + logger.debug( + f"Tool server health check failed (attempt {attempt + 1}/{max_retries}): {e}" + ) + + if attempt < max_retries - 1: + await asyncio.sleep(min(2**attempt, 4)) + + error_detail = str(last_error) if last_error else "Unknown error" + raise SandboxInitializationError( + "Tool server is not responding", + f"Could not connect to tool server at {health_url}. " + f"Last error: {error_detail}. " + "Please ensure Docker Desktop is installed and running, " + "and try running strix again.", + ) + async def _register_agent_with_tool_server( self, api_url: str, agent_id: str, token: str ) -> None: