From 693ef160605a3a0e03f931caa63fa9e19be8d1a0 Mon Sep 17 00:00:00 2001 From: 0xallam Date: Fri, 16 Jan 2026 00:21:02 -0800 Subject: [PATCH] fix(runtime): parallel tool execution and remove signal handlers - Add ThreadPoolExecutor in agent_worker for parallel request execution - Add request_id correlation to prevent response mismatch between concurrent requests - Add background listener thread per agent to dispatch responses to correct futures - Add --timeout argument for hard request timeout (default: 120s from config) - Remove signal handlers from terminal_manager, python_manager, tab_manager (use atexit only) - Replace SIGALRM timeout in python_instance with threading-based timeout This fixes requests getting queued behind slow operations and timeouts. --- containers/Dockerfile | 2 + strix/runtime/docker_runtime.py | 3 +- strix/runtime/tool_server.py | 130 ++++++++++++++++++----- strix/tools/browser/tab_manager.py | 12 --- strix/tools/python/python_instance.py | 74 +++++++------ strix/tools/python/python_manager.py | 12 --- strix/tools/terminal/terminal_manager.py | 12 --- 7 files changed, 144 insertions(+), 101 deletions(-) diff --git a/containers/Dockerfile b/containers/Dockerfile index c8b90de..40d9573 100644 --- a/containers/Dockerfile +++ b/containers/Dockerfile @@ -168,6 +168,8 @@ RUN /app/venv/bin/pip install -r /home/pentester/tools/jwt_tool/requirements.txt RUN echo "# Sandbox Environment" > README.md COPY strix/__init__.py strix/ +COPY strix/config/ /app/strix/config/ +COPY strix/utils/ /app/strix/utils/ COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/ COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/ diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index accc1b3..ecf7fda 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -297,11 +297,12 @@ class DockerRuntime(AbstractRuntime): ) caido_token = result.output.decode().strip() if result.exit_code == 0 else "" + execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120" container.exec_run( f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && " f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} " f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} " - f"--host 0.0.0.0 --port {tool_server_port} &'", + f"--host 0.0.0.0 --port {tool_server_port} --timeout {execution_timeout} &'", detach=True, user="pentester", ) diff --git a/strix/runtime/tool_server.py b/strix/runtime/tool_server.py index 6461f8c..5feebc8 100644 --- a/strix/runtime/tool_server.py +++ b/strix/runtime/tool_server.py @@ -2,12 +2,16 @@ from __future__ import annotations import argparse import asyncio +import contextlib import logging import os +import queue as stdlib_queue import signal import sys +import threading from multiprocessing import Process, Queue from typing import Any +from uuid import uuid4 import uvicorn from fastapi import Depends, FastAPI, HTTPException, status @@ -23,9 +27,16 @@ parser = argparse.ArgumentParser(description="Start Strix tool server") parser.add_argument("--token", required=True, help="Authentication token") parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec parser.add_argument("--port", type=int, required=True, help="Port to bind to") +parser.add_argument( + "--timeout", + type=int, + default=120, + help="Hard timeout in seconds for each request execution (default: 120)", +) args = parser.parse_args() EXPECTED_TOKEN = args.token +REQUEST_TIMEOUT = args.timeout app = FastAPI() security = HTTPBearer() @@ -34,6 +45,8 @@ security_dependency = Depends(security) agent_processes: dict[str, dict[str, Any]] = {} agent_queues: dict[str, dict[str, Queue[Any]]] = {} +pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {} +agent_listeners: dict[str, dict[str, Any]] = {} def verify_token(credentials: HTTPAuthorizationCredentials) -> str: @@ -72,37 +85,79 @@ def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queu root_logger.handlers = [null_handler] root_logger.setLevel(logging.CRITICAL) + from concurrent.futures import ThreadPoolExecutor + from strix.tools.argument_parser import ArgumentConversionError, convert_arguments from strix.tools.registry import get_tool_by_name - while True: - try: - request = request_queue.get() + def _execute_request(request: dict[str, Any]) -> None: + request_id = request.get("request_id", "") + tool_name = request["tool_name"] + kwargs = request["kwargs"] - if request is None: + try: + tool_func = get_tool_by_name(tool_name) + if not tool_func: + response_queue.put( + {"request_id": request_id, "error": f"Tool '{tool_name}' not found"} + ) + return + + converted_kwargs = convert_arguments(tool_func, kwargs) + result = tool_func(**converted_kwargs) + + response_queue.put({"request_id": request_id, "result": result}) + + except (ArgumentConversionError, ValidationError) as e: + response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"}) + except (RuntimeError, ValueError, ImportError) as e: + response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"}) + except Exception as e: # noqa: BLE001 + response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"}) + + with ThreadPoolExecutor() as executor: + while True: + try: + request = request_queue.get() + + if request is None: + break + + executor.submit(_execute_request, request) + + except (RuntimeError, ValueError, ImportError) as e: + response_queue.put({"error": f"Worker error: {e}"}) + + +def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None: + if agent_id in agent_listeners: + return + + stop_event = threading.Event() + loop = asyncio.get_event_loop() + + def _listener() -> None: + while not stop_event.is_set(): + try: + item = response_queue.get(timeout=0.5) + except stdlib_queue.Empty: + continue + except (BrokenPipeError, EOFError): break - tool_name = request["tool_name"] - kwargs = request["kwargs"] + request_id = item.get("request_id") + if not request_id or agent_id not in pending_responses: + continue - try: - tool_func = get_tool_by_name(tool_name) - if not tool_func: - response_queue.put({"error": f"Tool '{tool_name}' not found"}) - continue + future = pending_responses[agent_id].pop(request_id, None) + if future and not future.done(): + with contextlib.suppress(RuntimeError): + loop.call_soon_threadsafe(future.set_result, item) - converted_kwargs = convert_arguments(tool_func, kwargs) - result = tool_func(**converted_kwargs) + listener_thread = threading.Thread(target=_listener, daemon=True) + listener_thread.start() - response_queue.put({"result": result}) - - except (ArgumentConversionError, ValidationError) as e: - response_queue.put({"error": f"Invalid arguments: {e}"}) - except (RuntimeError, ValueError, ImportError) as e: - response_queue.put({"error": f"Tool execution error: {e}"}) - - except (RuntimeError, ValueError, ImportError) as e: - response_queue.put({"error": f"Worker error: {e}"}) + agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event} def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]: @@ -117,6 +172,9 @@ def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]: agent_processes[agent_id] = {"process": process, "pid": process.pid} agent_queues[agent_id] = {"request": request_queue, "response": response_queue} + pending_responses[agent_id] = {} + + _ensure_response_listener(agent_id, response_queue) return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"] @@ -127,18 +185,31 @@ async def execute_tool( ) -> ToolExecutionResponse: verify_token(credentials) - request_queue, response_queue = ensure_agent_process(request.agent_id) + request_queue, _response_queue = ensure_agent_process(request.agent_id) - request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs}) + loop = asyncio.get_event_loop() + req_id = uuid4().hex + future: asyncio.Future[Any] = loop.create_future() + pending_responses[request.agent_id][req_id] = future + + request_queue.put( + { + "request_id": req_id, + "tool_name": request.tool_name, + "kwargs": request.kwargs, + } + ) try: - loop = asyncio.get_event_loop() - response = await loop.run_in_executor(None, response_queue.get) + response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT) if "error" in response: return ToolExecutionResponse(error=response["error"]) return ToolExecutionResponse(result=response.get("result")) + except TimeoutError: + pending_responses[request.agent_id].pop(req_id, None) + return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds") except (RuntimeError, ValueError, OSError) as e: return ToolExecutionResponse(error=f"Worker error: {e}") @@ -168,6 +239,9 @@ async def health_check() -> dict[str, Any]: def cleanup_all_agents() -> None: for agent_id in list(agent_processes.keys()): try: + if agent_id in agent_listeners: + agent_listeners[agent_id]["stop_event"].set() + agent_queues[agent_id]["request"].put(None) process = agent_processes[agent_id]["process"] @@ -180,6 +254,10 @@ def cleanup_all_agents() -> None: if process.is_alive(): process.kill() + if agent_id in agent_listeners: + listener_thread = agent_listeners[agent_id]["thread"] + listener_thread.join(timeout=0.5) + except (BrokenPipeError, EOFError, OSError): pass except (RuntimeError, ValueError) as e: diff --git a/strix/tools/browser/tab_manager.py b/strix/tools/browser/tab_manager.py index 3b4b674..a77dda2 100644 --- a/strix/tools/browser/tab_manager.py +++ b/strix/tools/browser/tab_manager.py @@ -1,7 +1,5 @@ import atexit import contextlib -import signal -import sys import threading from typing import Any @@ -324,16 +322,6 @@ class BrowserTabManager: def _register_cleanup_handlers(self) -> None: atexit.register(self.close_all) - signal.signal(signal.SIGTERM, self._signal_handler) - signal.signal(signal.SIGINT, self._signal_handler) - - if hasattr(signal, "SIGHUP"): - signal.signal(signal.SIGHUP, self._signal_handler) - - def _signal_handler(self, _signum: int, _frame: Any) -> None: - self.close_all() - sys.exit(0) - _browser_tab_manager = BrowserTabManager() diff --git a/strix/tools/python/python_instance.py b/strix/tools/python/python_instance.py index f1a5c21..0bf3118 100644 --- a/strix/tools/python/python_instance.py +++ b/strix/tools/python/python_instance.py @@ -1,5 +1,4 @@ import io -import signal import sys import threading from typing import Any @@ -57,28 +56,6 @@ class PythonInstance: } return None - def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]: - stdout_capture = io.StringIO() - stderr_capture = io.StringIO() - - def timeout_handler(signum: int, frame: Any) -> None: - raise TimeoutError(f"Code execution timed out after {timeout} seconds") - - old_handler = signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(timeout) - - sys.stdout = stdout_capture - sys.stderr = stderr_capture - - return old_handler, stdout_capture, stderr_capture - - def _cleanup_execution_environment( - self, old_handler: Any, old_stdout: Any, old_stderr: Any - ) -> None: - signal.signal(signal.SIGALRM, old_handler) - sys.stdout = old_stdout - sys.stderr = old_stderr - def _truncate_output(self, content: str, max_length: int, suffix: str) -> str: if len(content) > max_length: return content[:max_length] + suffix @@ -142,27 +119,48 @@ class PythonInstance: return session_error with self._execution_lock: + result_container: dict[str, Any] = {} + stdout_capture = io.StringIO() + stderr_capture = io.StringIO() + old_stdout, old_stderr = sys.stdout, sys.stderr - try: - old_handler, stdout_capture, stderr_capture = self._setup_execution_environment( - timeout + def _run_code() -> None: + try: + sys.stdout = stdout_capture + sys.stderr = stderr_capture + execution_result = self.shell.run_cell(code, silent=False, store_history=True) + result_container["execution_result"] = execution_result + result_container["stdout"] = stdout_capture.getvalue() + result_container["stderr"] = stderr_capture.getvalue() + except (KeyboardInterrupt, SystemExit) as e: + result_container["error"] = e + except Exception as e: # noqa: BLE001 + result_container["error"] = e + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr + + exec_thread = threading.Thread(target=_run_code, daemon=True) + exec_thread.start() + exec_thread.join(timeout=timeout) + + if exec_thread.is_alive(): + return self._handle_execution_error( + TimeoutError(f"Code execution timed out after {timeout} seconds") ) - try: - execution_result = self.shell.run_cell(code, silent=False, store_history=True) - signal.alarm(0) + if "error" in result_container: + return self._handle_execution_error(result_container["error"]) - return self._format_execution_result( - execution_result, stdout_capture.getvalue(), stderr_capture.getvalue() - ) + if "execution_result" in result_container: + return self._format_execution_result( + result_container["execution_result"], + result_container.get("stdout", ""), + result_container.get("stderr", ""), + ) - except (TimeoutError, KeyboardInterrupt, SystemExit) as e: - signal.alarm(0) - return self._handle_execution_error(e) - - finally: - self._cleanup_execution_environment(old_handler, old_stdout, old_stderr) + return self._handle_execution_error(RuntimeError("Unknown execution error")) def close(self) -> None: self.is_running = False diff --git a/strix/tools/python/python_manager.py b/strix/tools/python/python_manager.py index 576e3a5..73376ab 100644 --- a/strix/tools/python/python_manager.py +++ b/strix/tools/python/python_manager.py @@ -1,7 +1,5 @@ import atexit import contextlib -import signal -import sys import threading from typing import Any @@ -113,16 +111,6 @@ class PythonSessionManager: def _register_cleanup_handlers(self) -> None: atexit.register(self.close_all_sessions) - signal.signal(signal.SIGTERM, self._signal_handler) - signal.signal(signal.SIGINT, self._signal_handler) - - if hasattr(signal, "SIGHUP"): - signal.signal(signal.SIGHUP, self._signal_handler) - - def _signal_handler(self, _signum: int, _frame: Any) -> None: - self.close_all_sessions() - sys.exit(0) - _python_session_manager = PythonSessionManager() diff --git a/strix/tools/terminal/terminal_manager.py b/strix/tools/terminal/terminal_manager.py index 95014f0..320dd18 100644 --- a/strix/tools/terminal/terminal_manager.py +++ b/strix/tools/terminal/terminal_manager.py @@ -1,7 +1,5 @@ import atexit import contextlib -import signal -import sys import threading from typing import Any @@ -133,16 +131,6 @@ class TerminalManager: def _register_cleanup_handlers(self) -> None: atexit.register(self.close_all_sessions) - signal.signal(signal.SIGTERM, self._signal_handler) - signal.signal(signal.SIGINT, self._signal_handler) - - if hasattr(signal, "SIGHUP"): - signal.signal(signal.SIGHUP, self._signal_handler) - - def _signal_handler(self, _signum: int, _frame: Any) -> None: - self.close_all_sessions() - sys.exit(0) - _terminal_manager = TerminalManager()