fix: add timeout handling for Docker operations and improve error messages

- Add SandboxInitializationError exception for sandbox/Docker failures
- Add 60-second timeout to Docker client initialization
- Add _exec_run_with_timeout() method using ThreadPoolExecutor for exec_run calls
- Catch ConnectionError and Timeout exceptions from requests library
- Add _handle_sandbox_error() and _handle_llm_error() methods in base_agent.py
- Handle sandbox_error_details tool in TUI for displaying errors
- Increase TUI truncation limits for better error visibility
- Update all Docker error messages with helpful hint:
  'Please ensure Docker Desktop is installed and running, and try running strix again.'
This commit is contained in:
0xallam
2026-01-08 16:11:15 -08:00
committed by Ahmed Allam
parent c327ce621f
commit 740fb3ed40
5 changed files with 193 additions and 80 deletions

View File

@@ -16,6 +16,7 @@ from jinja2 import (
from strix.llm import LLM, LLMConfig, LLMRequestFailedError from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations from strix.tools import process_tool_invocations
from .state import AgentState from .state import AgentState
@@ -145,18 +146,16 @@ class BaseAgent(metaclass=AgentMeta):
if self.state.parent_id is None and agents_graph_actions._root_agent_id is None: if self.state.parent_id is None and agents_graph_actions._root_agent_id is None:
agents_graph_actions._root_agent_id = self.state.agent_id agents_graph_actions._root_agent_id = self.state.agent_id
def cancel_current_execution(self) -> None:
if self._current_task and not self._current_task.done():
self._current_task.cancel()
self._current_task = None
async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915 async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
await self._initialize_sandbox_and_state(task)
from strix.telemetry.tracer import get_global_tracer from strix.telemetry.tracer import get_global_tracer
tracer = get_global_tracer() tracer = get_global_tracer()
try:
await self._initialize_sandbox_and_state(task)
except SandboxInitializationError as e:
return self._handle_sandbox_error(e, tracer)
while True: while True:
self._check_agent_messages(self.state) self._check_agent_messages(self.state)
@@ -232,37 +231,9 @@ class BaseAgent(metaclass=AgentMeta):
continue continue
except LLMRequestFailedError as e: except LLMRequestFailedError as e:
error_msg = str(e) result = self._handle_llm_error(e, tracer)
error_details = getattr(e, "details", None) if result is not None:
self.state.add_error(error_msg) return result
if self.non_interactive:
self.state.set_completed({"success": False, "error": error_msg})
if tracer:
tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
if error_details:
tracer.log_tool_execution_start(
self.state.agent_id,
"llm_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(
tracer._next_execution_id - 1, "failed", error_details
)
return {"success": False, "error": error_msg}
self.state.enter_waiting_state(llm_failed=True)
if tracer:
tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg)
if error_details:
tracer.log_tool_execution_start(
self.state.agent_id,
"llm_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(
tracer._next_execution_id - 1, "failed", error_details
)
continue continue
except (RuntimeError, ValueError, TypeError) as e: except (RuntimeError, ValueError, TypeError) as e:
@@ -439,18 +410,6 @@ class BaseAgent(metaclass=AgentMeta):
return False return False
async def _handle_iteration_error(
self,
error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
tracer: Optional["Tracer"],
) -> bool:
error_msg = f"Error in iteration {self.state.iteration}: {error!s}"
logger.exception(error_msg)
self.state.add_error(error_msg)
if tracer:
tracer.update_agent_status(self.state.agent_id, "error")
return True
def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912 def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
try: try:
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
@@ -535,3 +494,90 @@ class BaseAgent(metaclass=AgentMeta):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.warning(f"Error checking agent messages: {e}") logger.warning(f"Error checking agent messages: {e}")
return return
def _handle_sandbox_error(
self,
error: SandboxInitializationError,
tracer: Optional["Tracer"],
) -> dict[str, Any]:
error_msg = str(error.message)
error_details = error.details
self.state.add_error(error_msg)
if self.non_interactive:
self.state.set_completed({"success": False, "error": error_msg})
if tracer:
tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
if error_details:
exec_id = tracer.log_tool_execution_start(
self.state.agent_id,
"sandbox_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
return {"success": False, "error": error_msg, "details": error_details}
self.state.enter_waiting_state()
if tracer:
tracer.update_agent_status(self.state.agent_id, "sandbox_failed", error_msg)
if error_details:
exec_id = tracer.log_tool_execution_start(
self.state.agent_id,
"sandbox_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
return {"success": False, "error": error_msg, "details": error_details}
def _handle_llm_error(
self,
error: LLMRequestFailedError,
tracer: Optional["Tracer"],
) -> dict[str, Any] | None:
error_msg = str(error)
error_details = getattr(error, "details", None)
self.state.add_error(error_msg)
if self.non_interactive:
self.state.set_completed({"success": False, "error": error_msg})
if tracer:
tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
if error_details:
exec_id = tracer.log_tool_execution_start(
self.state.agent_id,
"llm_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
return {"success": False, "error": error_msg}
self.state.enter_waiting_state(llm_failed=True)
if tracer:
tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg)
if error_details:
exec_id = tracer.log_tool_execution_start(
self.state.agent_id,
"llm_error_details",
{"error": error_msg, "details": error_details},
)
tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
return None
async def _handle_iteration_error(
self,
error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
tracer: Optional["Tracer"],
) -> bool:
error_msg = f"Error in iteration {self.state.iteration}: {error!s}"
logger.exception(error_msg)
self.state.add_error(error_msg)
if tracer:
tracer.update_agent_status(self.state.agent_id, "error")
return True
def cancel_current_execution(self) -> None:
if self._current_task and not self._current_task.done():
self._current_task.cancel()
self._current_task = None

View File

@@ -1629,15 +1629,8 @@ class StrixTUIApp(App): # type: ignore[misc]
text = Text() text = Text()
if tool_name == "llm_error_details": if tool_name in ("llm_error_details", "sandbox_error_details"):
text.append("✗ LLM Request Failed", style="red") return self._render_error_details(text, tool_name, args)
if args.get("details"):
details = str(args["details"])
if len(details) > 300:
details = details[:297] + "..."
text.append("\nDetails: ", style="dim")
text.append(details)
return text
text.append("→ Using tool ") text.append("→ Using tool ")
text.append(tool_name, style="bold blue") text.append(tool_name, style="bold blue")
@@ -1653,10 +1646,10 @@ class StrixTUIApp(App): # type: ignore[misc]
text.append(icon, style=style) text.append(icon, style=style)
if args: if args:
for k, v in list(args.items())[:2]: for k, v in list(args.items())[:5]:
str_v = str(v) str_v = str(v)
if len(str_v) > 80: if len(str_v) > 500:
str_v = str_v[:77] + "..." str_v = str_v[:497] + "..."
text.append("\n ") text.append("\n ")
text.append(k, style="dim") text.append(k, style="dim")
text.append(": ") text.append(": ")
@@ -1664,14 +1657,29 @@ class StrixTUIApp(App): # type: ignore[misc]
if status in ["completed", "failed", "error"] and result: if status in ["completed", "failed", "error"] and result:
result_str = str(result) result_str = str(result)
if len(result_str) > 150: if len(result_str) > 1000:
result_str = result_str[:147] + "..." result_str = result_str[:997] + "..."
text.append("\n") text.append("\n")
text.append("Result: ", style="bold") text.append("Result: ", style="bold")
text.append(result_str) text.append(result_str)
return text return text
def _render_error_details(self, text: Any, tool_name: str, args: dict[str, Any]) -> Any:
if tool_name == "llm_error_details":
text.append("✗ LLM Request Failed", style="red")
else:
text.append("✗ Sandbox Initialization Failed", style="red")
if args.get("error"):
text.append(f"\n{args['error']}", style="bold red")
if args.get("details"):
details = str(args["details"])
if len(details) > 1000:
details = details[:997] + "..."
text.append("\nDetails: ", style="dim")
text.append(details)
return text
@on(Tree.NodeHighlighted) # type: ignore[misc] @on(Tree.NodeHighlighted) # type: ignore[misc]
def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None: def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None:
if len(self.screen_stack) > 1 or self.show_splash: if len(self.screen_stack) > 1 or self.show_splash:

View File

@@ -722,9 +722,10 @@ def check_docker_connection() -> Any:
error_text.append("DOCKER NOT AVAILABLE", style="bold red") error_text.append("DOCKER NOT AVAILABLE", style="bold red")
error_text.append("\n\n", style="white") error_text.append("\n\n", style="white")
error_text.append("Cannot connect to Docker daemon.\n", style="white") error_text.append("Cannot connect to Docker daemon.\n", style="white")
error_text.append("Please ensure Docker is installed and running.\n\n", style="white") error_text.append(
error_text.append("Try running: ", style="dim white") "Please ensure Docker Desktop is installed and running, and try running strix again.\n",
error_text.append("sudo systemctl start docker", style="dim cyan") style="white",
)
panel = Panel( panel = Panel(
error_text, error_text,

View File

@@ -3,6 +3,15 @@ import os
from .runtime import AbstractRuntime from .runtime import AbstractRuntime
class SandboxInitializationError(Exception):
"""Raised when sandbox initialization fails (e.g., Docker issues)."""
def __init__(self, message: str, details: str | None = None):
super().__init__(message)
self.message = message
self.details = details
def get_runtime() -> AbstractRuntime: def get_runtime() -> AbstractRuntime:
runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker") runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker")
@@ -16,4 +25,4 @@ def get_runtime() -> AbstractRuntime:
) )
__all__ = ["AbstractRuntime", "get_runtime"] __all__ = ["AbstractRuntime", "SandboxInitializationError", "get_runtime"]

View File

@@ -4,28 +4,46 @@ import os
import secrets import secrets
import socket import socket
import time import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from pathlib import Path from pathlib import Path
from typing import cast from typing import Any, cast
import docker import docker
from docker.errors import DockerException, ImageNotFound, NotFound from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container from docker.models.containers import Container
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import Timeout as RequestsTimeout
from . import SandboxInitializationError
from .runtime import AbstractRuntime, SandboxInfo from .runtime import AbstractRuntime, SandboxInfo
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10") STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
HOST_GATEWAY_HOSTNAME = "host.docker.internal" HOST_GATEWAY_HOSTNAME = "host.docker.internal"
DOCKER_TIMEOUT = 60 # seconds
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DockerRuntime(AbstractRuntime): class DockerRuntime(AbstractRuntime):
def __init__(self) -> None: def __init__(self) -> None:
try: try:
self.client = docker.from_env() self.client = docker.from_env(timeout=DOCKER_TIMEOUT)
except DockerException as e: except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.exception("Failed to connect to Docker daemon") logger.exception("Failed to connect to Docker daemon")
raise RuntimeError("Docker is not available or not configured correctly.") from e if isinstance(e, RequestsConnectionError | RequestsTimeout):
raise SandboxInitializationError(
"Docker daemon unresponsive",
f"Connection timed out after {DOCKER_TIMEOUT} seconds. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from e
raise SandboxInitializationError(
"Docker is not available",
"Docker is not available or not configured correctly. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from e
self._scan_container: Container | None = None self._scan_container: Container | None = None
self._tool_server_port: int | None = None self._tool_server_port: int | None = None
@@ -39,6 +57,23 @@ class DockerRuntime(AbstractRuntime):
s.bind(("", 0)) s.bind(("", 0))
return cast("int", s.getsockname()[1]) return cast("int", s.getsockname()[1])
def _exec_run_with_timeout(
self, container: Container, cmd: str, timeout: int = DOCKER_TIMEOUT, **kwargs: Any
) -> Any:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(container.exec_run, cmd, **kwargs)
try:
return future.result(timeout=timeout)
except FuturesTimeoutError:
logger.exception(f"exec_run timed out after {timeout}s: {cmd[:100]}...")
raise SandboxInitializationError(
"Container command timed out",
f"Command timed out after {timeout} seconds. "
"Docker may be overloaded or unresponsive. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from None
def _get_scan_id(self, agent_id: str) -> str: def _get_scan_id(self, agent_id: str) -> str:
try: try:
from strix.telemetry.tracer import get_global_tracer from strix.telemetry.tracer import get_global_tracer
@@ -134,7 +169,7 @@ class DockerRuntime(AbstractRuntime):
self._initialize_container( self._initialize_container(
container, caido_port, tool_server_port, tool_server_token container, caido_port, tool_server_port, tool_server_token
) )
except DockerException as e: except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
last_exception = e last_exception = e
if attempt == max_retries - 1: if attempt == max_retries - 1:
logger.exception(f"Failed to create container after {max_retries} attempts") logger.exception(f"Failed to create container after {max_retries} attempts")
@@ -150,8 +185,19 @@ class DockerRuntime(AbstractRuntime):
else: else:
return container return container
raise RuntimeError( if isinstance(last_exception, RequestsConnectionError | RequestsTimeout):
f"Failed to create Docker container after {max_retries} attempts: {last_exception}" raise SandboxInitializationError(
"Failed to create sandbox container",
f"Docker daemon unresponsive after {max_retries} attempts "
f"(timed out after {DOCKER_TIMEOUT}s). "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from last_exception
raise SandboxInitializationError(
"Failed to create sandbox container",
f"Container creation failed after {max_retries} attempts: {last_exception}. "
"Please ensure Docker Desktop is installed and running, "
"and try running strix again.",
) from last_exception ) from last_exception
def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912 def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912
@@ -196,7 +242,7 @@ class DockerRuntime(AbstractRuntime):
except NotFound: except NotFound:
pass pass
except DockerException as e: except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.warning(f"Failed to get container by name {container_name}: {e}") logger.warning(f"Failed to get container by name {container_name}: {e}")
else: else:
return container return container
@@ -220,7 +266,7 @@ class DockerRuntime(AbstractRuntime):
logger.info(f"Found existing container by label for scan {scan_id}") logger.info(f"Found existing container by label for scan {scan_id}")
return container return container
except DockerException as e: except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e) logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e)
logger.info("Creating new Docker container for scan %s", scan_id) logger.info("Creating new Docker container for scan %s", scan_id)
@@ -230,15 +276,18 @@ class DockerRuntime(AbstractRuntime):
self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
) -> None: ) -> None:
logger.info("Initializing Caido proxy on port %s", caido_port) logger.info("Initializing Caido proxy on port %s", caido_port)
result = container.exec_run( self._exec_run_with_timeout(
container,
f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'", f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
detach=False, detach=False,
) )
time.sleep(5) time.sleep(5)
result = container.exec_run( result = self._exec_run_with_timeout(
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester" container,
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'",
user="pentester",
) )
caido_token = result.output.decode().strip() if result.exit_code == 0 else "" caido_token = result.output.decode().strip() if result.exit_code == 0 else ""