fix(runtime): parallel tool execution and remove signal handlers

- Add ThreadPoolExecutor in agent_worker for parallel request execution
- Add request_id correlation to prevent response mismatch between concurrent requests
- Add background listener thread per agent to dispatch responses to correct futures
- Add --timeout argument for a hard per-request timeout (default: 120s; the Docker runtime passes strix_sandbox_execution_timeout from config)
- Remove signal handlers from terminal_manager, python_manager, tab_manager (use atexit only)
- Replace SIGALRM timeout in python_instance with threading-based timeout

This fixes requests getting queued behind slow operations, and the timeouts that resulted.
This commit is contained in:
0xallam
2026-01-16 00:21:02 -08:00
committed by Ahmed Allam
parent 8dc6f1dc8f
commit 693ef16060
7 changed files with 144 additions and 101 deletions

View File

@@ -168,6 +168,8 @@ RUN /app/venv/bin/pip install -r /home/pentester/tools/jwt_tool/requirements.txt
RUN echo "# Sandbox Environment" > README.md
COPY strix/__init__.py strix/
COPY strix/config/ /app/strix/config/
COPY strix/utils/ /app/strix/utils/
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/

View File

@@ -297,11 +297,12 @@ class DockerRuntime(AbstractRuntime):
)
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
container.exec_run(
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
f"--host 0.0.0.0 --port {tool_server_port} &'",
f"--host 0.0.0.0 --port {tool_server_port} --timeout {execution_timeout} &'",
detach=True,
user="pentester",
)

View File

@@ -2,12 +2,16 @@ from __future__ import annotations
import argparse
import asyncio
import contextlib
import logging
import os
import queue as stdlib_queue
import signal
import sys
import threading
from multiprocessing import Process, Queue
from typing import Any
from uuid import uuid4
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status
@@ -23,9 +27,16 @@ parser = argparse.ArgumentParser(description="Start Strix tool server")
parser.add_argument("--token", required=True, help="Authentication token")
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
parser.add_argument(
"--timeout",
type=int,
default=120,
help="Hard timeout in seconds for each request execution (default: 120)",
)
args = parser.parse_args()
EXPECTED_TOKEN = args.token
REQUEST_TIMEOUT = args.timeout
app = FastAPI()
security = HTTPBearer()
@@ -34,6 +45,8 @@ security_dependency = Depends(security)
agent_processes: dict[str, dict[str, Any]] = {}
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {}
agent_listeners: dict[str, dict[str, Any]] = {}
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
@@ -72,37 +85,79 @@ def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queu
root_logger.handlers = [null_handler]
root_logger.setLevel(logging.CRITICAL)
from concurrent.futures import ThreadPoolExecutor
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
from strix.tools.registry import get_tool_by_name
while True:
try:
request = request_queue.get()
def _execute_request(request: dict[str, Any]) -> None:
request_id = request.get("request_id", "")
tool_name = request["tool_name"]
kwargs = request["kwargs"]
if request is None:
try:
tool_func = get_tool_by_name(tool_name)
if not tool_func:
response_queue.put(
{"request_id": request_id, "error": f"Tool '{tool_name}' not found"}
)
return
converted_kwargs = convert_arguments(tool_func, kwargs)
result = tool_func(**converted_kwargs)
response_queue.put({"request_id": request_id, "result": result})
except (ArgumentConversionError, ValidationError) as e:
response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"})
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"})
except Exception as e: # noqa: BLE001
response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"})
with ThreadPoolExecutor() as executor:
while True:
try:
request = request_queue.get()
if request is None:
break
executor.submit(_execute_request, request)
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"error": f"Worker error: {e}"})
def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None:
if agent_id in agent_listeners:
return
stop_event = threading.Event()
loop = asyncio.get_event_loop()
def _listener() -> None:
while not stop_event.is_set():
try:
item = response_queue.get(timeout=0.5)
except stdlib_queue.Empty:
continue
except (BrokenPipeError, EOFError):
break
tool_name = request["tool_name"]
kwargs = request["kwargs"]
request_id = item.get("request_id")
if not request_id or agent_id not in pending_responses:
continue
try:
tool_func = get_tool_by_name(tool_name)
if not tool_func:
response_queue.put({"error": f"Tool '{tool_name}' not found"})
continue
future = pending_responses[agent_id].pop(request_id, None)
if future and not future.done():
with contextlib.suppress(RuntimeError):
loop.call_soon_threadsafe(future.set_result, item)
converted_kwargs = convert_arguments(tool_func, kwargs)
result = tool_func(**converted_kwargs)
listener_thread = threading.Thread(target=_listener, daemon=True)
listener_thread.start()
response_queue.put({"result": result})
except (ArgumentConversionError, ValidationError) as e:
response_queue.put({"error": f"Invalid arguments: {e}"})
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"error": f"Tool execution error: {e}"})
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"error": f"Worker error: {e}"})
agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event}
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
@@ -117,6 +172,9 @@ def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
agent_processes[agent_id] = {"process": process, "pid": process.pid}
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
pending_responses[agent_id] = {}
_ensure_response_listener(agent_id, response_queue)
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
@@ -127,18 +185,31 @@ async def execute_tool(
) -> ToolExecutionResponse:
verify_token(credentials)
request_queue, response_queue = ensure_agent_process(request.agent_id)
request_queue, _response_queue = ensure_agent_process(request.agent_id)
request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs})
loop = asyncio.get_event_loop()
req_id = uuid4().hex
future: asyncio.Future[Any] = loop.create_future()
pending_responses[request.agent_id][req_id] = future
request_queue.put(
{
"request_id": req_id,
"tool_name": request.tool_name,
"kwargs": request.kwargs,
}
)
try:
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, response_queue.get)
response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT)
if "error" in response:
return ToolExecutionResponse(error=response["error"])
return ToolExecutionResponse(result=response.get("result"))
except TimeoutError:
pending_responses[request.agent_id].pop(req_id, None)
return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds")
except (RuntimeError, ValueError, OSError) as e:
return ToolExecutionResponse(error=f"Worker error: {e}")
@@ -168,6 +239,9 @@ async def health_check() -> dict[str, Any]:
def cleanup_all_agents() -> None:
for agent_id in list(agent_processes.keys()):
try:
if agent_id in agent_listeners:
agent_listeners[agent_id]["stop_event"].set()
agent_queues[agent_id]["request"].put(None)
process = agent_processes[agent_id]["process"]
@@ -180,6 +254,10 @@ def cleanup_all_agents() -> None:
if process.is_alive():
process.kill()
if agent_id in agent_listeners:
listener_thread = agent_listeners[agent_id]["thread"]
listener_thread.join(timeout=0.5)
except (BrokenPipeError, EOFError, OSError):
pass
except (RuntimeError, ValueError) as e:

View File

@@ -1,7 +1,5 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
@@ -324,16 +322,6 @@ class BrowserTabManager:
def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all()
sys.exit(0)
_browser_tab_manager = BrowserTabManager()

View File

@@ -1,5 +1,4 @@
import io
import signal
import sys
import threading
from typing import Any
@@ -57,28 +56,6 @@ class PythonInstance:
}
return None
def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]:
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
def timeout_handler(signum: int, frame: Any) -> None:
raise TimeoutError(f"Code execution timed out after {timeout} seconds")
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout)
sys.stdout = stdout_capture
sys.stderr = stderr_capture
return old_handler, stdout_capture, stderr_capture
def _cleanup_execution_environment(
self, old_handler: Any, old_stdout: Any, old_stderr: Any
) -> None:
signal.signal(signal.SIGALRM, old_handler)
sys.stdout = old_stdout
sys.stderr = old_stderr
def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
if len(content) > max_length:
return content[:max_length] + suffix
@@ -142,27 +119,48 @@ class PythonInstance:
return session_error
with self._execution_lock:
result_container: dict[str, Any] = {}
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
old_stdout, old_stderr = sys.stdout, sys.stderr
try:
old_handler, stdout_capture, stderr_capture = self._setup_execution_environment(
timeout
def _run_code() -> None:
try:
sys.stdout = stdout_capture
sys.stderr = stderr_capture
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
result_container["execution_result"] = execution_result
result_container["stdout"] = stdout_capture.getvalue()
result_container["stderr"] = stderr_capture.getvalue()
except (KeyboardInterrupt, SystemExit) as e:
result_container["error"] = e
except Exception as e: # noqa: BLE001
result_container["error"] = e
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
exec_thread = threading.Thread(target=_run_code, daemon=True)
exec_thread.start()
exec_thread.join(timeout=timeout)
if exec_thread.is_alive():
return self._handle_execution_error(
TimeoutError(f"Code execution timed out after {timeout} seconds")
)
try:
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
signal.alarm(0)
if "error" in result_container:
return self._handle_execution_error(result_container["error"])
return self._format_execution_result(
execution_result, stdout_capture.getvalue(), stderr_capture.getvalue()
)
if "execution_result" in result_container:
return self._format_execution_result(
result_container["execution_result"],
result_container.get("stdout", ""),
result_container.get("stderr", ""),
)
except (TimeoutError, KeyboardInterrupt, SystemExit) as e:
signal.alarm(0)
return self._handle_execution_error(e)
finally:
self._cleanup_execution_environment(old_handler, old_stdout, old_stderr)
return self._handle_execution_error(RuntimeError("Unknown execution error"))
def close(self) -> None:
self.is_running = False

View File

@@ -1,7 +1,5 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
@@ -113,16 +111,6 @@ class PythonSessionManager:
def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all_sessions)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all_sessions()
sys.exit(0)
_python_session_manager = PythonSessionManager()

View File

@@ -1,7 +1,5 @@
import atexit
import contextlib
import signal
import sys
import threading
from typing import Any
@@ -133,16 +131,6 @@ class TerminalManager:
def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all_sessions)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all_sessions()
sys.exit(0)
_terminal_manager = TerminalManager()