fix: add timeout handling for Docker operations and improve error messages
- Add SandboxInitializationError exception for sandbox/Docker failures - Add 60-second timeout to Docker client initialization - Add _exec_run_with_timeout() method using ThreadPoolExecutor for exec_run calls - Catch ConnectionError and Timeout exceptions from requests library - Add _handle_sandbox_error() and _handle_llm_error() methods in base_agent.py - Handle sandbox_error_details tool in TUI for displaying errors - Increase TUI truncation limits for better error visibility - Update all Docker error messages with helpful hint: 'Please ensure Docker Desktop is installed and running, and try running strix again.'
This commit is contained in:
@@ -16,6 +16,7 @@ from jinja2 import (
|
||||
|
||||
from strix.llm import LLM, LLMConfig, LLMRequestFailedError
|
||||
from strix.llm.utils import clean_content
|
||||
from strix.runtime import SandboxInitializationError
|
||||
from strix.tools import process_tool_invocations
|
||||
|
||||
from .state import AgentState
|
||||
@@ -145,18 +146,16 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
if self.state.parent_id is None and agents_graph_actions._root_agent_id is None:
|
||||
agents_graph_actions._root_agent_id = self.state.agent_id
|
||||
|
||||
def cancel_current_execution(self) -> None:
|
||||
if self._current_task and not self._current_task.done():
|
||||
self._current_task.cancel()
|
||||
self._current_task = None
|
||||
|
||||
async def agent_loop(self, task: str) -> dict[str, Any]: # noqa: PLR0912, PLR0915
|
||||
await self._initialize_sandbox_and_state(task)
|
||||
|
||||
from strix.telemetry.tracer import get_global_tracer
|
||||
|
||||
tracer = get_global_tracer()
|
||||
|
||||
try:
|
||||
await self._initialize_sandbox_and_state(task)
|
||||
except SandboxInitializationError as e:
|
||||
return self._handle_sandbox_error(e, tracer)
|
||||
|
||||
while True:
|
||||
self._check_agent_messages(self.state)
|
||||
|
||||
@@ -232,37 +231,9 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
continue
|
||||
|
||||
except LLMRequestFailedError as e:
|
||||
error_msg = str(e)
|
||||
error_details = getattr(e, "details", None)
|
||||
self.state.add_error(error_msg)
|
||||
|
||||
if self.non_interactive:
|
||||
self.state.set_completed({"success": False, "error": error_msg})
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
|
||||
if error_details:
|
||||
tracer.log_tool_execution_start(
|
||||
self.state.agent_id,
|
||||
"llm_error_details",
|
||||
{"error": error_msg, "details": error_details},
|
||||
)
|
||||
tracer.update_tool_execution(
|
||||
tracer._next_execution_id - 1, "failed", error_details
|
||||
)
|
||||
return {"success": False, "error": error_msg}
|
||||
|
||||
self.state.enter_waiting_state(llm_failed=True)
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg)
|
||||
if error_details:
|
||||
tracer.log_tool_execution_start(
|
||||
self.state.agent_id,
|
||||
"llm_error_details",
|
||||
{"error": error_msg, "details": error_details},
|
||||
)
|
||||
tracer.update_tool_execution(
|
||||
tracer._next_execution_id - 1, "failed", error_details
|
||||
)
|
||||
result = self._handle_llm_error(e, tracer)
|
||||
if result is not None:
|
||||
return result
|
||||
continue
|
||||
|
||||
except (RuntimeError, ValueError, TypeError) as e:
|
||||
@@ -439,18 +410,6 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
|
||||
return False
|
||||
|
||||
async def _handle_iteration_error(
|
||||
self,
|
||||
error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
|
||||
tracer: Optional["Tracer"],
|
||||
) -> bool:
|
||||
error_msg = f"Error in iteration {self.state.iteration}: {error!s}"
|
||||
logger.exception(error_msg)
|
||||
self.state.add_error(error_msg)
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "error")
|
||||
return True
|
||||
|
||||
def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
|
||||
try:
|
||||
from strix.tools.agents_graph.agents_graph_actions import _agent_graph, _agent_messages
|
||||
@@ -535,3 +494,90 @@ class BaseAgent(metaclass=AgentMeta):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"Error checking agent messages: {e}")
|
||||
return
|
||||
|
||||
def _handle_sandbox_error(
|
||||
self,
|
||||
error: SandboxInitializationError,
|
||||
tracer: Optional["Tracer"],
|
||||
) -> dict[str, Any]:
|
||||
error_msg = str(error.message)
|
||||
error_details = error.details
|
||||
self.state.add_error(error_msg)
|
||||
|
||||
if self.non_interactive:
|
||||
self.state.set_completed({"success": False, "error": error_msg})
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
|
||||
if error_details:
|
||||
exec_id = tracer.log_tool_execution_start(
|
||||
self.state.agent_id,
|
||||
"sandbox_error_details",
|
||||
{"error": error_msg, "details": error_details},
|
||||
)
|
||||
tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
|
||||
return {"success": False, "error": error_msg, "details": error_details}
|
||||
|
||||
self.state.enter_waiting_state()
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "sandbox_failed", error_msg)
|
||||
if error_details:
|
||||
exec_id = tracer.log_tool_execution_start(
|
||||
self.state.agent_id,
|
||||
"sandbox_error_details",
|
||||
{"error": error_msg, "details": error_details},
|
||||
)
|
||||
tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
|
||||
|
||||
return {"success": False, "error": error_msg, "details": error_details}
|
||||
|
||||
def _handle_llm_error(
|
||||
self,
|
||||
error: LLMRequestFailedError,
|
||||
tracer: Optional["Tracer"],
|
||||
) -> dict[str, Any] | None:
|
||||
error_msg = str(error)
|
||||
error_details = getattr(error, "details", None)
|
||||
self.state.add_error(error_msg)
|
||||
|
||||
if self.non_interactive:
|
||||
self.state.set_completed({"success": False, "error": error_msg})
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "failed", error_msg)
|
||||
if error_details:
|
||||
exec_id = tracer.log_tool_execution_start(
|
||||
self.state.agent_id,
|
||||
"llm_error_details",
|
||||
{"error": error_msg, "details": error_details},
|
||||
)
|
||||
tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
|
||||
return {"success": False, "error": error_msg}
|
||||
|
||||
self.state.enter_waiting_state(llm_failed=True)
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "llm_failed", error_msg)
|
||||
if error_details:
|
||||
exec_id = tracer.log_tool_execution_start(
|
||||
self.state.agent_id,
|
||||
"llm_error_details",
|
||||
{"error": error_msg, "details": error_details},
|
||||
)
|
||||
tracer.update_tool_execution(exec_id, "failed", {"details": error_details})
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_iteration_error(
|
||||
self,
|
||||
error: RuntimeError | ValueError | TypeError | asyncio.CancelledError,
|
||||
tracer: Optional["Tracer"],
|
||||
) -> bool:
|
||||
error_msg = f"Error in iteration {self.state.iteration}: {error!s}"
|
||||
logger.exception(error_msg)
|
||||
self.state.add_error(error_msg)
|
||||
if tracer:
|
||||
tracer.update_agent_status(self.state.agent_id, "error")
|
||||
return True
|
||||
|
||||
def cancel_current_execution(self) -> None:
|
||||
if self._current_task and not self._current_task.done():
|
||||
self._current_task.cancel()
|
||||
self._current_task = None
|
||||
|
||||
@@ -1629,15 +1629,8 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
|
||||
text = Text()
|
||||
|
||||
if tool_name == "llm_error_details":
|
||||
text.append("✗ LLM Request Failed", style="red")
|
||||
if args.get("details"):
|
||||
details = str(args["details"])
|
||||
if len(details) > 300:
|
||||
details = details[:297] + "..."
|
||||
text.append("\nDetails: ", style="dim")
|
||||
text.append(details)
|
||||
return text
|
||||
if tool_name in ("llm_error_details", "sandbox_error_details"):
|
||||
return self._render_error_details(text, tool_name, args)
|
||||
|
||||
text.append("→ Using tool ")
|
||||
text.append(tool_name, style="bold blue")
|
||||
@@ -1653,10 +1646,10 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
text.append(icon, style=style)
|
||||
|
||||
if args:
|
||||
for k, v in list(args.items())[:2]:
|
||||
for k, v in list(args.items())[:5]:
|
||||
str_v = str(v)
|
||||
if len(str_v) > 80:
|
||||
str_v = str_v[:77] + "..."
|
||||
if len(str_v) > 500:
|
||||
str_v = str_v[:497] + "..."
|
||||
text.append("\n ")
|
||||
text.append(k, style="dim")
|
||||
text.append(": ")
|
||||
@@ -1664,14 +1657,29 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
|
||||
if status in ["completed", "failed", "error"] and result:
|
||||
result_str = str(result)
|
||||
if len(result_str) > 150:
|
||||
result_str = result_str[:147] + "..."
|
||||
if len(result_str) > 1000:
|
||||
result_str = result_str[:997] + "..."
|
||||
text.append("\n")
|
||||
text.append("Result: ", style="bold")
|
||||
text.append(result_str)
|
||||
|
||||
return text
|
||||
|
||||
def _render_error_details(self, text: Any, tool_name: str, args: dict[str, Any]) -> Any:
|
||||
if tool_name == "llm_error_details":
|
||||
text.append("✗ LLM Request Failed", style="red")
|
||||
else:
|
||||
text.append("✗ Sandbox Initialization Failed", style="red")
|
||||
if args.get("error"):
|
||||
text.append(f"\n{args['error']}", style="bold red")
|
||||
if args.get("details"):
|
||||
details = str(args["details"])
|
||||
if len(details) > 1000:
|
||||
details = details[:997] + "..."
|
||||
text.append("\nDetails: ", style="dim")
|
||||
text.append(details)
|
||||
return text
|
||||
|
||||
@on(Tree.NodeHighlighted) # type: ignore[misc]
|
||||
def handle_tree_highlight(self, event: Tree.NodeHighlighted) -> None:
|
||||
if len(self.screen_stack) > 1 or self.show_splash:
|
||||
|
||||
@@ -722,9 +722,10 @@ def check_docker_connection() -> Any:
|
||||
error_text.append("DOCKER NOT AVAILABLE", style="bold red")
|
||||
error_text.append("\n\n", style="white")
|
||||
error_text.append("Cannot connect to Docker daemon.\n", style="white")
|
||||
error_text.append("Please ensure Docker is installed and running.\n\n", style="white")
|
||||
error_text.append("Try running: ", style="dim white")
|
||||
error_text.append("sudo systemctl start docker", style="dim cyan")
|
||||
error_text.append(
|
||||
"Please ensure Docker Desktop is installed and running, and try running strix again.\n",
|
||||
style="white",
|
||||
)
|
||||
|
||||
panel = Panel(
|
||||
error_text,
|
||||
|
||||
@@ -3,6 +3,15 @@ import os
|
||||
from .runtime import AbstractRuntime
|
||||
|
||||
|
||||
class SandboxInitializationError(Exception):
|
||||
"""Raised when sandbox initialization fails (e.g., Docker issues)."""
|
||||
|
||||
def __init__(self, message: str, details: str | None = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.details = details
|
||||
|
||||
|
||||
def get_runtime() -> AbstractRuntime:
|
||||
runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker")
|
||||
|
||||
@@ -16,4 +25,4 @@ def get_runtime() -> AbstractRuntime:
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["AbstractRuntime", "get_runtime"]
|
||||
__all__ = ["AbstractRuntime", "SandboxInitializationError", "get_runtime"]
|
||||
|
||||
@@ -4,28 +4,46 @@ import os
|
||||
import secrets
|
||||
import socket
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import docker
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from docker.models.containers import Container
|
||||
from requests.exceptions import ConnectionError as RequestsConnectionError
|
||||
from requests.exceptions import Timeout as RequestsTimeout
|
||||
|
||||
from . import SandboxInitializationError
|
||||
from .runtime import AbstractRuntime, SandboxInfo
|
||||
|
||||
|
||||
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
|
||||
HOST_GATEWAY_HOSTNAME = "host.docker.internal"
|
||||
DOCKER_TIMEOUT = 60 # seconds
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DockerRuntime(AbstractRuntime):
|
||||
def __init__(self) -> None:
|
||||
try:
|
||||
self.client = docker.from_env()
|
||||
except DockerException as e:
|
||||
self.client = docker.from_env(timeout=DOCKER_TIMEOUT)
|
||||
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
|
||||
logger.exception("Failed to connect to Docker daemon")
|
||||
raise RuntimeError("Docker is not available or not configured correctly.") from e
|
||||
if isinstance(e, RequestsConnectionError | RequestsTimeout):
|
||||
raise SandboxInitializationError(
|
||||
"Docker daemon unresponsive",
|
||||
f"Connection timed out after {DOCKER_TIMEOUT} seconds. "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from e
|
||||
raise SandboxInitializationError(
|
||||
"Docker is not available",
|
||||
"Docker is not available or not configured correctly. "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from e
|
||||
|
||||
self._scan_container: Container | None = None
|
||||
self._tool_server_port: int | None = None
|
||||
@@ -39,6 +57,23 @@ class DockerRuntime(AbstractRuntime):
|
||||
s.bind(("", 0))
|
||||
return cast("int", s.getsockname()[1])
|
||||
|
||||
def _exec_run_with_timeout(
|
||||
self, container: Container, cmd: str, timeout: int = DOCKER_TIMEOUT, **kwargs: Any
|
||||
) -> Any:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(container.exec_run, cmd, **kwargs)
|
||||
try:
|
||||
return future.result(timeout=timeout)
|
||||
except FuturesTimeoutError:
|
||||
logger.exception(f"exec_run timed out after {timeout}s: {cmd[:100]}...")
|
||||
raise SandboxInitializationError(
|
||||
"Container command timed out",
|
||||
f"Command timed out after {timeout} seconds. "
|
||||
"Docker may be overloaded or unresponsive. "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from None
|
||||
|
||||
def _get_scan_id(self, agent_id: str) -> str:
|
||||
try:
|
||||
from strix.telemetry.tracer import get_global_tracer
|
||||
@@ -134,7 +169,7 @@ class DockerRuntime(AbstractRuntime):
|
||||
self._initialize_container(
|
||||
container, caido_port, tool_server_port, tool_server_token
|
||||
)
|
||||
except DockerException as e:
|
||||
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
|
||||
last_exception = e
|
||||
if attempt == max_retries - 1:
|
||||
logger.exception(f"Failed to create container after {max_retries} attempts")
|
||||
@@ -150,8 +185,19 @@ class DockerRuntime(AbstractRuntime):
|
||||
else:
|
||||
return container
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to create Docker container after {max_retries} attempts: {last_exception}"
|
||||
if isinstance(last_exception, RequestsConnectionError | RequestsTimeout):
|
||||
raise SandboxInitializationError(
|
||||
"Failed to create sandbox container",
|
||||
f"Docker daemon unresponsive after {max_retries} attempts "
|
||||
f"(timed out after {DOCKER_TIMEOUT}s). "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from last_exception
|
||||
raise SandboxInitializationError(
|
||||
"Failed to create sandbox container",
|
||||
f"Container creation failed after {max_retries} attempts: {last_exception}. "
|
||||
"Please ensure Docker Desktop is installed and running, "
|
||||
"and try running strix again.",
|
||||
) from last_exception
|
||||
|
||||
def _get_or_create_scan_container(self, scan_id: str) -> Container: # noqa: PLR0912
|
||||
@@ -196,7 +242,7 @@ class DockerRuntime(AbstractRuntime):
|
||||
|
||||
except NotFound:
|
||||
pass
|
||||
except DockerException as e:
|
||||
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
|
||||
logger.warning(f"Failed to get container by name {container_name}: {e}")
|
||||
else:
|
||||
return container
|
||||
@@ -220,7 +266,7 @@ class DockerRuntime(AbstractRuntime):
|
||||
|
||||
logger.info(f"Found existing container by label for scan {scan_id}")
|
||||
return container
|
||||
except DockerException as e:
|
||||
except (DockerException, RequestsConnectionError, RequestsTimeout) as e:
|
||||
logger.warning("Failed to find existing container by label for scan %s: %s", scan_id, e)
|
||||
|
||||
logger.info("Creating new Docker container for scan %s", scan_id)
|
||||
@@ -230,15 +276,18 @@ class DockerRuntime(AbstractRuntime):
|
||||
self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
|
||||
) -> None:
|
||||
logger.info("Initializing Caido proxy on port %s", caido_port)
|
||||
result = container.exec_run(
|
||||
self._exec_run_with_timeout(
|
||||
container,
|
||||
f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
|
||||
detach=False,
|
||||
)
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
result = container.exec_run(
|
||||
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester"
|
||||
result = self._exec_run_with_timeout(
|
||||
container,
|
||||
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'",
|
||||
user="pentester",
|
||||
)
|
||||
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user