fix: add timeout handling for Docker operations and improve error messages

- Add SandboxInitializationError exception for sandbox/Docker failures
- Add 60-second timeout to Docker client initialization
- Add _exec_run_with_timeout() method using ThreadPoolExecutor for exec_run calls
- Catch ConnectionError and Timeout exceptions from requests library
- Add _handle_sandbox_error() and _handle_llm_error() methods in base_agent.py
- Handle sandbox_error_details tool in TUI for displaying errors
- Increase TUI truncation limits for better error visibility
- Update all Docker error messages with helpful hint:
  'Please ensure Docker Desktop is installed and running, and try running strix again.'
This commit is contained in:
0xallam
2026-01-08 16:11:15 -08:00
committed by Ahmed Allam
parent c327ce621f
commit 740fb3ed40
5 changed files with 193 additions and 80 deletions

View File

@@ -16,6 +16,7 @@ from jinja2 import (
from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations
from .state import AgentState
@@ -145,18 +146,16 @@ class BaseAgent(metaclass=AgentMeta):
if self.state.parent_id is None and agents_graph_actions._root_agent_id is None:
agents_graph_actions._root_agent_id = self.state.agent_id
def cancel_current_execution(self) -> None:
if self._current_task and not self._current_task.done():
self._current_task.cancel()
self._current_task = None
async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
await self._initialize_sandbox_and_state(task)
from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer()
try:
await self._initialize_sandbox_and_state(task)
except SandboxInitializationError as e:
return self._handle_sandbox_error(e, tracer)
while True:
self._check_agent_messages(self.state)
@@ -232,37 +231,9 @@ class BaseAgent(metaclass=AgentMeta):
continue
except LLMRequestFailedError as e:
error_msg = str(e)
error_details = getattr(e, "details", None)
self.state.add_error(error_msg)
if self.non_interactive:
self.state.set_completed({"success": False, "error": error_msg})
if tracer:
tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
if error_details:
tracer.log_tool_execution_start(
self.state.agent_id,
"llm_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(
tracer._next_execution_id - 1, "failed", error_details
)
return {"success": False, "error": error_msg}
self.state.enter_waiting_state(llm_failed=True)
if tracer:
tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg)
if error_details:
tracer.log_tool_execution_start(
self.state.agent_id,
"llm_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(
tracer._next_execution_id - 1, "failed", error_details
)
result = self._handle_llm_error(e, tracer)
if result is not None:
return result
continue
except (RuntimeError, ValueError, TypeError) as e:
@@ -439,18 +410,6 @@ class BaseAgent(metaclass=AgentMeta):
return False
async def _handle_iteration_error(
self,
error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
tracer: Optional["Tracer"],
) -> bool:
error_msg = f"Error in iteration {self.state.iteration}: {error!s}"
logger.exception(error_msg)
self.state.add_error(error_msg)
if tracer:
tracer.update_agent_status(self.state.agent_id, "error")
return True
def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
try:
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
@@ -535,3 +494,90 @@ class BaseAgent(metaclass=AgentMeta):
logger = logging.getLogger(__name__)
logger.warning(f"Error checking agent messages: {e}")
return
def _handle_sandbox_error(
    self,
    error: SandboxInitializationError,
    tracer: Optional["Tracer"],
) -> dict[str, Any]:
    """Record a sandbox initialization failure and build the failure result.

    The error summary is appended to the agent state. In non-interactive mode
    the state is marked completed with a failure payload and the tracer status
    is ``"failed"``; otherwise the state enters its waiting state and the
    tracer status is ``"sandbox_failed"``. When the error carries details, a
    ``"sandbox_error_details"`` tool execution is logged and immediately
    marked ``"failed"``.

    Args:
        error: The sandbox failure; its ``message`` and ``details`` are read.
        tracer: Tracer to report to; nothing is reported when it is falsy.

    Returns:
        A dict with ``success`` (always ``False``), ``error`` and ``details``.
    """
    summary = str(error.message)
    extra = error.details

    self.state.add_error(summary)

    # The two modes differ only in the state transition and the status label.
    if self.non_interactive:
        self.state.set_completed({"success": False, "error": summary})
        status = "failed"
    else:
        self.state.enter_waiting_state()
        status = "sandbox_failed"

    if tracer:
        tracer.update_agent_status(self.state.agent_id, status, summary)
        if extra:
            execution_id = tracer.log_tool_execution_start(
                self.state.agent_id,
                "sandbox_error_details",
                {"error": summary, "details": extra},
            )
            tracer.update_tool_execution(execution_id, "failed", {"details": extra})

    return {"success": False, "error": summary, "details": extra}
def _handle_llm_error(
    self,
    error: LLMRequestFailedError,
    tracer: Optional["Tracer"],
) -> dict[str, Any] | None:
    """Record a failed LLM request and decide whether the agent loop should stop.

    The stringified error is appended to the agent state. In non-interactive
    mode the state is marked completed with a failure payload and the tracer
    status is ``"failed"``; otherwise the state enters its waiting state with
    ``llm_failed=True`` and the tracer status is ``"llm_failed"``. When the
    error exposes a truthy ``details`` attribute, an ``"llm_error_details"``
    tool execution is logged and immediately marked ``"failed"``.

    Args:
        error: The LLM failure; ``details`` is read with ``getattr`` and may
            be absent.
        tracer: Tracer to report to; nothing is reported when it is falsy.

    Returns:
        A failure dict (``success`` and ``error``) in non-interactive mode,
        otherwise ``None``.
    """
    summary = str(error)
    extra = getattr(error, "details", None)

    self.state.add_error(summary)

    outcome: dict[str, Any] | None
    if self.non_interactive:
        self.state.set_completed({"success": False, "error": summary})
        status = "failed"
        outcome = {"success": False, "error": summary}
    else:
        self.state.enter_waiting_state(llm_failed=True)
        status = "llm_failed"
        outcome = None

    if tracer:
        tracer.update_agent_status(self.state.agent_id, status, summary)
        if extra:
            execution_id = tracer.log_tool_execution_start(
                self.state.agent_id,
                "llm_error_details",
                {"error": summary, "details": extra},
            )
            tracer.update_tool_execution(execution_id, "failed", {"details": extra})

    return outcome
async def _handle_iteration_error(
    self,
    error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
    tracer: Optional["Tracer"],
) -> bool:
    """Log an error raised during an agent iteration and record it on the state.

    Args:
        error: The exception raised while running the current iteration.
        tracer: Tracer whose agent status is set to ``"error"``; skipped when
            it is falsy.

    Returns:
        Always ``True``.
    """
    description = f"Error in iteration {self.state.iteration}: {str(error)}"
    logger.exception(description)
    self.state.add_error(description)

    if not tracer:
        return True

    tracer.update_agent_status(self.state.agent_id, "error")
    return True
# Cancel the task tracked in ``self._current_task`` when one exists and has not finished.
# NOTE(review): indentation is not visible in this view, so it is unclear whether the
# reset of ``self._current_task`` to None below is inside the guard or unconditional --
# confirm against the real file before relying on either reading.
def cancel_current_execution(self) -> None:
# Guard: a missing or already-finished task is not cancelled.
if self._current_task and not self._current_task.done():
self._current_task.cancel()
# Drop the reference to the task.
self._current_task = None

View File

@@ -1629,15 +1629,8 @@ class StrixTUIApp(App): # type: ignore[misc]
text = Text()
if tool_name == "llm_error_details":
text.append("✗ LLM Request Failed", style="red")
if args.get("details"):
details = str(args["details"])
if len(details) > 300:
details = details[:297] + "..."
text.append("\nDetails: ", style="dim")
text.append(details)
return text
if tool_name in ("llm_error_details", "sandbox_error_details"):
return self._render_error_details(text, tool_name, args)
text.append("→ Using tool ")
text.append(tool_name, style="bold blue")
@@ -1653,10 +1646,10 @@ class StrixTUIApp(App): # type: ignore[misc]
text.append(icon, style=style)
if args:
for k, v in list(args.items())[:2]:
for k, v in list(args.items())[:5]:
str_v = str(v)
if len(str_v) > 80:
str_v = str_v[:77] + "..."
if len(str_v) > 500:
str_v = str_v[:497] + "..."
text.append("\n ")
text.append(k, style="dim")
text.append(": ")
@@ -1664,14 +1657,29 @@ class StrixTUIApp(App): # type: ignore[misc]
if status in ["completed", "failed", "error"] and result:
result_str = str(result)
if len(result_str) > 150:
result_str = result_str[:147] + "..."
if len(result_str) > 1000:
result_str = result_str[:997] + "..."
text.append("\n")
text.append("Result: ", style="bold")
text.append(result_str)
return text
def _render_error_details(self, text: Any, tool_name: str, args: dict[str, Any]) -> Any:
if tool_name == "llm_error_details":
text.append("✗ LLM Request Failed", style="red")
else:
text.append("✗ Sandbox Initialization Failed", style="red")
if args.get("error"):
text.append(f"\n{args['error']}", style="bold red")
if args.get("details"):
details = str(args["details"])
if len(details) > 1000:
details = details[:997] + "..."
text.append("\nDetails: ", style="dim")
text.append(details)
return text
@on(Tree.NodeHighlighted) # type: ignore[misc]
def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None:
if len(self.screen_stack) > 1 or self.show_splash:

View File

@@ -722,9 +722,10 @@ def check_docker_connection() -> Any:
error_text.append("DOCKER NOT AVAILABLE", style="bold red")
error_text.append("\n\n", style="white")
error_text.append("Cannot connect to Docker daemon.\n", style="white")
error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
error_text.append("Try running: ", style="dim white")
error_text.append("sudo systemctl start docker", style="dim cyan")
error_text.append(
"Please ensure Docker Desktop is installed and running, and try running strix again.\n",
style="white",
)
panel = Panel(
error_text,

View File

@@ -3,6 +3,15 @@ import os
from .runtime import AbstractRuntime
class SandboxInitializationError(Exception):
"""Raised when sandbox initialization fails (e.g., Docker issues)."""
def __init__(self, message: str, details: str | None = None):
super().__init__(message)
self.message = message
self.details = details
def get_runtime() -> AbstractRuntime:
runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker")
@@ -16,4 +25,4 @@ def get_runtime() -> AbstractRuntime:
)
__all__ = ["AbstractRuntime", "get_runtime"]
__all__ = ["AbstractRuntime", "SandboxInitializationError", "get_runtime"]

View File

@@ -4,28 +4,46 @@ import os
import secrets
import socket
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from pathlib import Path
from typing import cast
from typing import Any, cast
import docker
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import Timeout as RequestsTimeout
from . import SandboxInitializationError
from .runtime import AbstractRuntime, SandboxInfo
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
HOST_GATEWAY_HOSTNAME = "host.docker.internal"
DOCKER_TIMEOUT = 60 # seconds
logger = logging.getLogger(__name__)
class DockerRuntime(AbstractRuntime):
def __init__(self) -> None:
try:
self.client = docker.from_env()
except DockerException as e:
self.client = docker.from_env(timeout=DOCKER_TIMEOUT)
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.exception("Failed to connect to Docker daemon")
raise RuntimeError("Docker is not available or not configured correctly.") from e
if isinstance(e, RequestsConnectionError | RequestsTimeout):
raise SandboxInitializationError(
"Docker daemon unresponsive",
f"Connection timed out after {DOCKER_TIMEOUT} seconds. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from e
raise SandboxInitializationError(
"Docker is not available",
"Docker is not available or not configured correctly. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from e
self._scan_container: Container | None = None
self._tool_server_port: int | None = None
@@ -39,6 +57,23 @@ class DockerRuntime(AbstractRuntime):
s.bind(("", 0))
return cast("int", s.getsockname()[1])
def _exec_run_with_timeout(
    self, container: Container, cmd: str, timeout: int = DOCKER_TIMEOUT, **kwargs: Any
) -> Any:
    """Run ``container.exec_run(cmd, **kwargs)`` and give up after ``timeout`` seconds.

    The call runs on a single worker thread so that the caller can stop
    waiting for it. The executor is deliberately NOT used as a context
    manager: leaving a ``with ThreadPoolExecutor(...)`` block calls
    ``shutdown(wait=True)``, which would block until the hung ``exec_run``
    finally returns and so defeat the timeout.

    Args:
        container: Container to execute the command in.
        cmd: Command string handed to ``exec_run``.
        timeout: Seconds to wait for the result; defaults to ``DOCKER_TIMEOUT``.
        **kwargs: Forwarded unchanged to ``exec_run``.

    Returns:
        Whatever ``exec_run`` returns.

    Raises:
        SandboxInitializationError: If no result arrives within ``timeout``
            seconds. Exceptions raised by ``exec_run`` itself propagate.
    """
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(container.exec_run, cmd, **kwargs)
        try:
            return future.result(timeout=timeout)
        except FuturesTimeoutError:
            logger.exception(f"exec_run timed out after {timeout}s: {cmd[:100]}...")
            raise SandboxInitializationError(
                "Container command timed out",
                f"Command timed out after {timeout} seconds. "
                "Docker may be overloaded or unresponsive. "
                "Please ensure Docker Desktop is installed and running, "
                "and try running strix again.",
            ) from None
    finally:
        # Never wait for the worker here: a running thread cannot be killed,
        # so after a timeout it is left to finish in the background while
        # the caller gets control back immediately.
        executor.shutdown(wait=False)
def _get_scan_id(self, agent_id: str) -> str:
try:
from strix.telemetry.tracer import get_global_tracer
@@ -134,7 +169,7 @@ class DockerRuntime(AbstractRuntime):
self._initialize_container(
container, caido_port, tool_server_port, tool_server_token
)
except DockerException as e:
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
last_exception = e
if attempt == max_retries - 1:
logger.exception(f"Failed to create container after {max_retries} attempts")
@@ -150,8 +185,19 @@ class DockerRuntime(AbstractRuntime):
else:
return container
raise RuntimeError(
f"Failed to create Docker container after {max_retries} attempts: {last_exception}"
if isinstance(last_exception, RequestsConnectionError | RequestsTimeout):
raise SandboxInitializationError(
"Failed to create sandbox container",
f"Docker daemon unresponsive after {max_retries} attempts "
f"(timed out after {DOCKER_TIMEOUT}s). "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from last_exception
raise SandboxInitializationError(
"Failed to create sandbox container",
f"Container creation failed after {max_retries} attempts: {last_exception}. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from last_exception
def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912
@@ -196,7 +242,7 @@ class DockerRuntime(AbstractRuntime):
except NotFound:
pass
except DockerException as e:
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.warning(f"Failed to get container by name {container_name}: {e}")
else:
return container
@@ -220,7 +266,7 @@ class DockerRuntime(AbstractRuntime):
logger.info(f"Found existing container by label for scan {scan_id}")
return container
except DockerException as e:
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e)
logger.info("Creating new Docker container for scan %s", scan_id)
@@ -230,15 +276,18 @@ class DockerRuntime(AbstractRuntime):
self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
) -> None:
logger.info("Initializing Caido proxy on port %s", caido_port)
result = container.exec_run(
self._exec_run_with_timeout(
container,
f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
detach=False,
)
time.sleep(5)
result = container.exec_run(
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester"
result = self._exec_run_with_timeout(
container,
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'",
user="pentester",
)
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""