From 740fb3ed407f97850caa83332b0ba57a255e481b Mon Sep 17 00:00:00 2001 From: 0xallam Date: Thu, 8 Jan 2026 16:11:15 -0800 Subject: [PATCH] fix: add timeout handling for Docker operations and improve error messages - Add SandboxInitializationError exception for sandbox/Docker failures - Add 60-second timeout to Docker client initialization - Add _exec_run_with_timeout() method using ThreadPoolExecutor for exec_run calls - Catch ConnectionError and Timeout exceptions from requests library - Add _handle_sandbox_error() and _handle_llm_error() methods in base_agent.py - Handle sandbox_error_details tool in TUI for displaying errors - Increase TUI truncation limits for better error visibility - Update all Docker error messages with helpful hint: 'Please ensure Docker Desktop is installed and running, and try running strix again.' --- strix/agents/base_agent.py | 146 +++++++++++++++++++++----------- strix/interface/tui.py | 36 +++++--- strix/interface/utils.py | 7 +- strix/runtime/__init__.py | 11 ++- strix/runtime/docker_runtime.py | 73 +++++++++++++--- 5 files changed, 193 insertions(+), 80 deletions(-) diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 2cfd37b..171a48e 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -16,6 +16,7 @@ from jinja2 import ( from strix.llm import LLM, LLMConfig, LLMRequestFailedError from strix.llm.utils import clean_content +from strix.runtime import SandboxInitializationError from strix.tools import process_tool_invocations from .state import AgentState @@ -145,18 +146,16 @@ class BaseAgent(metaclass=AgentMeta): if self.state.parent_id is None and agents_graph_actions._root_agent_id is None: agents_graph_actions._root_agent_id = self.state.agent_id - def cancel_current_execution(self) -> None: - if self._current_task and not self._current_task.done(): - self._current_task.cancel() - self._current_task = None - async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915 - await self._initialize_sandbox_and_state(task) - from strix.telemetry.tracer import get_global_tracer tracer = get_global_tracer() + try: + await self._initialize_sandbox_and_state(task) + except SandboxInitializationError as e: + return self._handle_sandbox_error(e, tracer) + while True: self._check_agent_messages(self.state) @@ -232,37 +231,9 @@ class BaseAgent(metaclass=AgentMeta): continue except LLMRequestFailedError as e: - error_msg = str(e) - error_details = getattr(e, "details", None) - self.state.add_error(error_msg) - - if self.non_interactive: - self.state.set_completed({"success": False, "error": error_msg}) - if tracer: - tracer.update_agent_status(self.state.agent_id, "failed", error_msg) - if error_details: - tracer.log_tool_execution_start( - self.state.agent_id, - "llm_error_details", - {"error": error_msg, "details": error_details}, - ) - tracer.update_tool_execution( - tracer._next_execution_id - 1, "failed", error_details - ) - return {"success": False, "error": error_msg} - - self.state.enter_waiting_state(llm_failed=True) - if tracer: - tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg) - if error_details: - tracer.log_tool_execution_start( - self.state.agent_id, - "llm_error_details", - {"error": error_msg, "details": error_details}, - ) - tracer.update_tool_execution( - tracer._next_execution_id - 1, "failed", error_details - ) + result = self._handle_llm_error(e, tracer) + if result is not None: + return result continue except (RuntimeError, ValueError, TypeError) as e: @@ -439,18 +410,6 @@ class BaseAgent(metaclass=AgentMeta): return False - async def _handle_iteration_error( - self, - error: RuntimeError | ValueError | TypeError | asyncio.CancelledError, - tracer: Optional["Tracer"], - ) -> bool: - error_msg = f"Error in iteration {self.state.iteration}: {error!s}" - logger.exception(error_msg) - self.state.add_error(error_msg) - if tracer: - tracer.update_agent_status(self.state.agent_id, "error") - return True - def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912 try: from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages @@ -535,3 +494,90 @@ class BaseAgent(metaclass=AgentMeta): logger = logging.getLogger(__name__) logger.warning(f"Error checking agent messages: {e}") return + + def _handle_sandbox_error( + self, + error: SandboxInitializationError, + tracer: Optional["Tracer"], + ) -> dict[str, Any]: + error_msg = str(error.message) + error_details = error.details + self.state.add_error(error_msg) + + if self.non_interactive: + self.state.set_completed({"success": False, "error": error_msg}) + if tracer: + tracer.update_agent_status(self.state.agent_id, "failed", error_msg) + if error_details: + exec_id = tracer.log_tool_execution_start( + self.state.agent_id, + "sandbox_error_details", + {"error": error_msg, "details": error_details}, + ) + tracer.update_tool_execution(exec_id, "failed", {"details": error_details}) + return {"success": False, "error": error_msg, "details": error_details} + + self.state.enter_waiting_state() + if tracer: + tracer.update_agent_status(self.state.agent_id, "sandbox_failed", error_msg) + if error_details: + exec_id = tracer.log_tool_execution_start( + self.state.agent_id, + "sandbox_error_details", + {"error": error_msg, "details": error_details}, + ) + tracer.update_tool_execution(exec_id, "failed", {"details": error_details}) + + return {"success": False, "error": error_msg, "details": error_details} + + def _handle_llm_error( + self, + error: LLMRequestFailedError, + tracer: Optional["Tracer"], + ) -> dict[str, Any] | None: + error_msg = str(error) + error_details = getattr(error, "details", None) + self.state.add_error(error_msg) + + if self.non_interactive: + self.state.set_completed({"success": False, "error": error_msg}) + if tracer: + tracer.update_agent_status(self.state.agent_id, "failed", error_msg) + if error_details: + exec_id = tracer.log_tool_execution_start( + self.state.agent_id, + "llm_error_details", + {"error": error_msg, "details": error_details}, + ) + tracer.update_tool_execution(exec_id, "failed", {"details": error_details}) + return {"success": False, "error": error_msg} + + self.state.enter_waiting_state(llm_failed=True) + if tracer: + tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg) + if error_details: + exec_id = tracer.log_tool_execution_start( + self.state.agent_id, + "llm_error_details", + {"error": error_msg, "details": error_details}, + ) + tracer.update_tool_execution(exec_id, "failed", {"details": error_details}) + + return None + + async def _handle_iteration_error( + self, + error: RuntimeError | ValueError | TypeError | asyncio.CancelledError, + tracer: Optional["Tracer"], + ) -> bool: + error_msg = f"Error in iteration {self.state.iteration}: {error!s}" + logger.exception(error_msg) + self.state.add_error(error_msg) + if tracer: + tracer.update_agent_status(self.state.agent_id, "error") + return True + + def cancel_current_execution(self) -> None: + if self._current_task and not self._current_task.done(): + self._current_task.cancel() + self._current_task = None diff --git a/strix/interface/tui.py b/strix/interface/tui.py index 27ed2d5..39ddefb 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -1629,15 +1629,8 @@ class StrixTUIApp(App): # type: ignore[misc] text = Text() - if tool_name == "llm_error_details": - text.append("✗ LLM Request Failed", style="red") - if args.get("details"): - details = str(args["details"]) - if len(details) > 300: - details = details[:297] + "..." - text.append("\nDetails: ", style="dim") - text.append(details) - return text + if tool_name in ("llm_error_details", "sandbox_error_details"): + return self._render_error_details(text, tool_name, args) text.append("→ Using tool ") text.append(tool_name, style="bold blue") @@ -1653,10 +1646,10 @@ class StrixTUIApp(App): # type: ignore[misc] text.append(icon, style=style) if args: - for k, v in list(args.items())[:2]: + for k, v in list(args.items())[:5]: str_v = str(v) - if len(str_v) > 80: - str_v = str_v[:77] + "..." + if len(str_v) > 500: + str_v = str_v[:497] + "..." text.append("\n ") text.append(k, style="dim") text.append(": ") @@ -1664,14 +1657,29 @@ class StrixTUIApp(App): # type: ignore[misc] if status in ["completed", "failed", "error"] and result: result_str = str(result) - if len(result_str) > 150: - result_str = result_str[:147] + "..." + if len(result_str) > 1000: + result_str = result_str[:997] + "..." text.append("\n") text.append("Result: ", style="bold") text.append(result_str) return text + def _render_error_details(self, text: Any, tool_name: str, args: dict[str, Any]) -> Any: + if tool_name == "llm_error_details": + text.append("✗ LLM Request Failed", style="red") + else: + text.append("✗ Sandbox Initialization Failed", style="red") + if args.get("error"): + text.append(f"\n{args['error']}", style="bold red") + if args.get("details"): + details = str(args["details"]) + if len(details) > 1000: + details = details[:997] + "..." + text.append("\nDetails: ", style="dim") + text.append(details) + return text + @on(Tree.NodeHighlighted) # type: ignore[misc] def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None: if len(self.screen_stack) > 1 or self.show_splash: diff --git a/strix/interface/utils.py b/strix/interface/utils.py index e188267..f0d5e61 100644 --- a/strix/interface/utils.py +++ b/strix/interface/utils.py @@ -722,9 +722,10 @@ def check_docker_connection() -> Any: error_text.append("DOCKER NOT AVAILABLE", style="bold red") error_text.append("\n\n", style="white") error_text.append("Cannot connect to Docker daemon.\n", style="white") - error_text.append("Please ensure Docker is installed and running.\n\n", style="white") - error_text.append("Try running: ", style="dim white") - error_text.append("sudo systemctl start docker", style="dim cyan") + error_text.append( + "Please ensure Docker Desktop is installed and running, and try running strix again.\n", + style="white", + ) panel = Panel( error_text, diff --git a/strix/runtime/__init__.py b/strix/runtime/__init__.py index 92e9e2e..49b83d9 100644 --- a/strix/runtime/__init__.py +++ b/strix/runtime/__init__.py @@ -3,6 +3,15 @@ import os from .runtime import AbstractRuntime +class SandboxInitializationError(Exception): + """Raised when sandbox initialization fails (e.g., Docker issues).""" + + def __init__(self, message: str, details: str | None = None): + super().__init__(message) + self.message = message + self.details = details + + def get_runtime() -> AbstractRuntime: runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker") @@ -16,4 +25,4 @@ def get_runtime() -> AbstractRuntime: ) -__all__ = ["AbstractRuntime", "get_runtime"] +__all__ = ["AbstractRuntime", "SandboxInitializationError", "get_runtime"] diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index 5abaa71..2f53804 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -4,28 +4,46 @@ import os import secrets import socket import time +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FuturesTimeoutError from pathlib import Path -from typing import cast +from typing import Any, cast import docker from docker.errors import DockerException, ImageNotFound, NotFound from docker.models.containers import Container +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import Timeout as RequestsTimeout +from . import SandboxInitializationError from .runtime import AbstractRuntime, SandboxInfo STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10") HOST_GATEWAY_HOSTNAME = "host.docker.internal" +DOCKER_TIMEOUT = 60 # seconds logger = logging.getLogger(__name__) class DockerRuntime(AbstractRuntime): def __init__(self) -> None: try: - self.client = docker.from_env() - except DockerException as e: + self.client = docker.from_env(timeout=DOCKER_TIMEOUT) + except (DockerException, RequestsConnectionError, RequestsTimeout) as e: logger.exception("Failed to connect to Docker daemon") - raise RuntimeError("Docker is not available or not configured correctly.") from e + if isinstance(e, RequestsConnectionError | RequestsTimeout): + raise SandboxInitializationError( + "Docker daemon unresponsive", + f"Connection timed out after {DOCKER_TIMEOUT} seconds. " + "Please ensure Docker Desktop is installed and running, " + "and try running strix again.", + ) from e + raise SandboxInitializationError( + "Docker is not available", + "Docker is not available or not configured correctly. " + "Please ensure Docker Desktop is installed and running, " + "and try running strix again.", + ) from e self._scan_container: Container | None = None self._tool_server_port: int | None = None @@ -39,6 +57,23 @@ class DockerRuntime(AbstractRuntime): s.bind(("", 0)) return cast("int", s.getsockname()[1]) + def _exec_run_with_timeout( + self, container: Container, cmd: str, timeout: int = DOCKER_TIMEOUT, **kwargs: Any + ) -> Any: + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(container.exec_run, cmd, **kwargs) + try: + return future.result(timeout=timeout) + except FuturesTimeoutError: + logger.exception(f"exec_run timed out after {timeout}s: {cmd[:100]}...") + raise SandboxInitializationError( + "Container command timed out", + f"Command timed out after {timeout} seconds. " + "Docker may be overloaded or unresponsive. " + "Please ensure Docker Desktop is installed and running, " + "and try running strix again.", + ) from None + def _get_scan_id(self, agent_id: str) -> str: try: from strix.telemetry.tracer import get_global_tracer @@ -134,7 +169,7 @@ class DockerRuntime(AbstractRuntime): self._initialize_container( container, caido_port, tool_server_port, tool_server_token ) - except DockerException as e: + except (DockerException, RequestsConnectionError, RequestsTimeout) as e: last_exception = e if attempt == max_retries - 1: logger.exception(f"Failed to create container after {max_retries} attempts") @@ -150,8 +185,19 @@ class DockerRuntime(AbstractRuntime): else: return container - raise RuntimeError( - f"Failed to create Docker container after {max_retries} attempts: {last_exception}" + if isinstance(last_exception, RequestsConnectionError | RequestsTimeout): + raise SandboxInitializationError( + "Failed to create sandbox container", + f"Docker daemon unresponsive after {max_retries} attempts " + f"(timed out after {DOCKER_TIMEOUT}s). " + "Please ensure Docker Desktop is installed and running, " + "and try running strix again.", + ) from last_exception + raise SandboxInitializationError( + "Failed to create sandbox container", + f"Container creation failed after {max_retries} attempts: {last_exception}. " + "Please ensure Docker Desktop is installed and running, " + "and try running strix again.", ) from last_exception def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912 @@ -196,7 +242,7 @@ class DockerRuntime(AbstractRuntime): except NotFound: pass - except DockerException as e: + except (DockerException, RequestsConnectionError, RequestsTimeout) as e: logger.warning(f"Failed to get container by name {container_name}: {e}") else: return container @@ -220,7 +266,7 @@ class DockerRuntime(AbstractRuntime): logger.info(f"Found existing container by label for scan {scan_id}") return container - except DockerException as e: + except (DockerException, RequestsConnectionError, RequestsTimeout) as e: logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e) logger.info("Creating new Docker container for scan %s", scan_id) @@ -230,15 +276,18 @@ class DockerRuntime(AbstractRuntime): self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str ) -> None: logger.info("Initializing Caido proxy on port %s", caido_port) - result = container.exec_run( + self._exec_run_with_timeout( + container, f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'", detach=False, ) time.sleep(5) - result = container.exec_run( - "bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester" + result = self._exec_run_with_timeout( + container, + "bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", + user="pentester", ) caido_token = result.output.decode().strip() if result.exit_code == 0 else ""