fix(runtime): parallel tool execution and remove signal handlers

- Add ThreadPoolExecutor in agent_worker for parallel request execution
- Add request_id correlation to prevent response mismatch between concurrent requests
- Add background listener thread per agent to dispatch responses to correct futures
- Add --timeout argument for hard request timeout (default: 120s from config)
- Remove signal handlers from terminal_manager, python_manager, tab_manager (use atexit only)
- Replace SIGALRM timeout in python_instance with threading-based timeout

This fixes requests getting queued behind slow operations and timeouts.
This commit is contained in:
0xallam
2026-01-16 00:21:02 -08:00
committed by Ahmed Allam
parent 8dc6f1dc8f
commit 693ef16060
7 changed files with 144 additions and 101 deletions

View File

@@ -168,6 +168,8 @@ RUN /app/venv/bin/pip install -r /home/pentester/tools/jwt_tool/requirements.txt
RUN echo "# Sandbox Environment" > README.md RUN echo "# Sandbox Environment" > README.md
COPY strix/__init__.py strix/ COPY strix/__init__.py strix/
COPY strix/config/ /app/strix/config/
COPY strix/utils/ /app/strix/utils/
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/ COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/ COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/

View File

@@ -297,11 +297,12 @@ class DockerRuntime(AbstractRuntime):
) )
caido_token = result.output.decode().strip() if result.exit_code == 0 else "" caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
container.exec_run( container.exec_run(
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && " f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} " f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} " f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
f"--host 0.0.0.0 --port {tool_server_port} &'", f"--host 0.0.0.0 --port {tool_server_port} --timeout {execution_timeout} &'",
detach=True, detach=True,
user="pentester", user="pentester",
) )

View File

@@ -2,12 +2,16 @@ from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import contextlib
import logging import logging
import os import os
import queue as stdlib_queue
import signal import signal
import sys import sys
import threading
from multiprocessing import Process, Queue from multiprocessing import Process, Queue
from typing import Any from typing import Any
from uuid import uuid4
import uvicorn import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
@@ -23,9 +27,16 @@ parser = argparse.ArgumentParser(description="Start Strix tool server")
parser.add_argument("--token", required=True, help="Authentication token") parser.add_argument("--token", required=True, help="Authentication token")
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
parser.add_argument("--port", type=int, required=True, help="Port to bind to") parser.add_argument("--port", type=int, required=True, help="Port to bind to")
parser.add_argument(
"--timeout",
type=int,
default=120,
help="Hard timeout in seconds for each request execution (default: 120)",
)
args = parser.parse_args() args = parser.parse_args()
EXPECTED_TOKEN = args.token EXPECTED_TOKEN = args.token
REQUEST_TIMEOUT = args.timeout
app = FastAPI() app = FastAPI()
security = HTTPBearer() security = HTTPBearer()
@@ -34,6 +45,8 @@ security_dependency = Depends(security)
agent_processes: dict[str, dict[str, Any]] = {} agent_processes: dict[str, dict[str, Any]] = {}
agent_queues: dict[str, dict[str, Queue[Any]]] = {} agent_queues: dict[str, dict[str, Queue[Any]]] = {}
pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {}
agent_listeners: dict[str, dict[str, Any]] = {}
def verify_token(credentials: HTTPAuthorizationCredentials) -> str: def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
@@ -72,9 +85,37 @@ def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queu
root_logger.handlers = [null_handler] root_logger.handlers = [null_handler]
root_logger.setLevel(logging.CRITICAL) root_logger.setLevel(logging.CRITICAL)
from concurrent.futures import ThreadPoolExecutor
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
from strix.tools.registry import get_tool_by_name from strix.tools.registry import get_tool_by_name
def _execute_request(request: dict[str, Any]) -> None:
request_id = request.get("request_id", "")
tool_name = request["tool_name"]
kwargs = request["kwargs"]
try:
tool_func = get_tool_by_name(tool_name)
if not tool_func:
response_queue.put(
{"request_id": request_id, "error": f"Tool '{tool_name}' not found"}
)
return
converted_kwargs = convert_arguments(tool_func, kwargs)
result = tool_func(**converted_kwargs)
response_queue.put({"request_id": request_id, "result": result})
except (ArgumentConversionError, ValidationError) as e:
response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"})
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"})
except Exception as e: # noqa: BLE001
response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"})
with ThreadPoolExecutor() as executor:
while True: while True:
try: try:
request = request_queue.get() request = request_queue.get()
@@ -82,29 +123,43 @@ def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queu
if request is None: if request is None:
break break
tool_name = request["tool_name"] executor.submit(_execute_request, request)
kwargs = request["kwargs"]
try:
tool_func = get_tool_by_name(tool_name)
if not tool_func:
response_queue.put({"error": f"Tool '{tool_name}' not found"})
continue
converted_kwargs = convert_arguments(tool_func, kwargs)
result = tool_func(**converted_kwargs)
response_queue.put({"result": result})
except (ArgumentConversionError, ValidationError) as e:
response_queue.put({"error": f"Invalid arguments: {e}"})
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"error": f"Tool execution error: {e}"})
except (RuntimeError, ValueError, ImportError) as e: except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"error": f"Worker error: {e}"}) response_queue.put({"error": f"Worker error: {e}"})
def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None:
if agent_id in agent_listeners:
return
stop_event = threading.Event()
loop = asyncio.get_event_loop()
def _listener() -> None:
while not stop_event.is_set():
try:
item = response_queue.get(timeout=0.5)
except stdlib_queue.Empty:
continue
except (BrokenPipeError, EOFError):
break
request_id = item.get("request_id")
if not request_id or agent_id not in pending_responses:
continue
future = pending_responses[agent_id].pop(request_id, None)
if future and not future.done():
with contextlib.suppress(RuntimeError):
loop.call_soon_threadsafe(future.set_result, item)
listener_thread = threading.Thread(target=_listener, daemon=True)
listener_thread.start()
agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event}
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]: def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
if agent_id not in agent_processes: if agent_id not in agent_processes:
request_queue: Queue[Any] = Queue() request_queue: Queue[Any] = Queue()
@@ -117,6 +172,9 @@ def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
agent_processes[agent_id] = {"process": process, "pid": process.pid} agent_processes[agent_id] = {"process": process, "pid": process.pid}
agent_queues[agent_id] = {"request": request_queue, "response": response_queue} agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
pending_responses[agent_id] = {}
_ensure_response_listener(agent_id, response_queue)
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"] return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
@@ -127,18 +185,31 @@ async def execute_tool(
) -> ToolExecutionResponse: ) -> ToolExecutionResponse:
verify_token(credentials) verify_token(credentials)
request_queue, response_queue = ensure_agent_process(request.agent_id) request_queue, _response_queue = ensure_agent_process(request.agent_id)
request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs}) loop = asyncio.get_event_loop()
req_id = uuid4().hex
future: asyncio.Future[Any] = loop.create_future()
pending_responses[request.agent_id][req_id] = future
request_queue.put(
{
"request_id": req_id,
"tool_name": request.tool_name,
"kwargs": request.kwargs,
}
)
try: try:
loop = asyncio.get_event_loop() response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT)
response = await loop.run_in_executor(None, response_queue.get)
if "error" in response: if "error" in response:
return ToolExecutionResponse(error=response["error"]) return ToolExecutionResponse(error=response["error"])
return ToolExecutionResponse(result=response.get("result")) return ToolExecutionResponse(result=response.get("result"))
except TimeoutError:
pending_responses[request.agent_id].pop(req_id, None)
return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds")
except (RuntimeError, ValueError, OSError) as e: except (RuntimeError, ValueError, OSError) as e:
return ToolExecutionResponse(error=f"Worker error: {e}") return ToolExecutionResponse(error=f"Worker error: {e}")
@@ -168,6 +239,9 @@ async def health_check() -> dict[str, Any]:
def cleanup_all_agents() -> None: def cleanup_all_agents() -> None:
for agent_id in list(agent_processes.keys()): for agent_id in list(agent_processes.keys()):
try: try:
if agent_id in agent_listeners:
agent_listeners[agent_id]["stop_event"].set()
agent_queues[agent_id]["request"].put(None) agent_queues[agent_id]["request"].put(None)
process = agent_processes[agent_id]["process"] process = agent_processes[agent_id]["process"]
@@ -180,6 +254,10 @@ def cleanup_all_agents() -> None:
if process.is_alive(): if process.is_alive():
process.kill() process.kill()
if agent_id in agent_listeners:
listener_thread = agent_listeners[agent_id]["thread"]
listener_thread.join(timeout=0.5)
except (BrokenPipeError, EOFError, OSError): except (BrokenPipeError, EOFError, OSError):
pass pass
except (RuntimeError, ValueError) as e: except (RuntimeError, ValueError) as e:

View File

@@ -1,7 +1,5 @@
import atexit import atexit
import contextlib import contextlib
import signal
import sys
import threading import threading
from typing import Any from typing import Any
@@ -324,16 +322,6 @@ class BrowserTabManager:
def _register_cleanup_handlers(self) -> None: def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all) atexit.register(self.close_all)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all()
sys.exit(0)
_browser_tab_manager = BrowserTabManager() _browser_tab_manager = BrowserTabManager()

View File

@@ -1,5 +1,4 @@
import io import io
import signal
import sys import sys
import threading import threading
from typing import Any from typing import Any
@@ -57,28 +56,6 @@ class PythonInstance:
} }
return None return None
def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]:
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
def timeout_handler(signum: int, frame: Any) -> None:
raise TimeoutError(f"Code execution timed out after {timeout} seconds")
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout)
sys.stdout = stdout_capture
sys.stderr = stderr_capture
return old_handler, stdout_capture, stderr_capture
def _cleanup_execution_environment(
self, old_handler: Any, old_stdout: Any, old_stderr: Any
) -> None:
signal.signal(signal.SIGALRM, old_handler)
sys.stdout = old_stdout
sys.stderr = old_stderr
def _truncate_output(self, content: str, max_length: int, suffix: str) -> str: def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
if len(content) > max_length: if len(content) > max_length:
return content[:max_length] + suffix return content[:max_length] + suffix
@@ -142,27 +119,48 @@ class PythonInstance:
return session_error return session_error
with self._execution_lock: with self._execution_lock:
result_container: dict[str, Any] = {}
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
old_stdout, old_stderr = sys.stdout, sys.stderr old_stdout, old_stderr = sys.stdout, sys.stderr
def _run_code() -> None:
try: try:
old_handler, stdout_capture, stderr_capture = self._setup_execution_environment( sys.stdout = stdout_capture
timeout sys.stderr = stderr_capture
)
try:
execution_result = self.shell.run_cell(code, silent=False, store_history=True) execution_result = self.shell.run_cell(code, silent=False, store_history=True)
signal.alarm(0) result_container["execution_result"] = execution_result
result_container["stdout"] = stdout_capture.getvalue()
result_container["stderr"] = stderr_capture.getvalue()
except (KeyboardInterrupt, SystemExit) as e:
result_container["error"] = e
except Exception as e: # noqa: BLE001
result_container["error"] = e
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
return self._format_execution_result( exec_thread = threading.Thread(target=_run_code, daemon=True)
execution_result, stdout_capture.getvalue(), stderr_capture.getvalue() exec_thread.start()
exec_thread.join(timeout=timeout)
if exec_thread.is_alive():
return self._handle_execution_error(
TimeoutError(f"Code execution timed out after {timeout} seconds")
) )
except (TimeoutError, KeyboardInterrupt, SystemExit) as e: if "error" in result_container:
signal.alarm(0) return self._handle_execution_error(result_container["error"])
return self._handle_execution_error(e)
finally: if "execution_result" in result_container:
self._cleanup_execution_environment(old_handler, old_stdout, old_stderr) return self._format_execution_result(
result_container["execution_result"],
result_container.get("stdout", ""),
result_container.get("stderr", ""),
)
return self._handle_execution_error(RuntimeError("Unknown execution error"))
def close(self) -> None: def close(self) -> None:
self.is_running = False self.is_running = False

View File

@@ -1,7 +1,5 @@
import atexit import atexit
import contextlib import contextlib
import signal
import sys
import threading import threading
from typing import Any from typing import Any
@@ -113,16 +111,6 @@ class PythonSessionManager:
def _register_cleanup_handlers(self) -> None: def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all_sessions) atexit.register(self.close_all_sessions)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all_sessions()
sys.exit(0)
_python_session_manager = PythonSessionManager() _python_session_manager = PythonSessionManager()

View File

@@ -1,7 +1,5 @@
import atexit import atexit
import contextlib import contextlib
import signal
import sys
import threading import threading
from typing import Any from typing import Any
@@ -133,16 +131,6 @@ class TerminalManager:
def _register_cleanup_handlers(self) -> None: def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all_sessions) atexit.register(self.close_all_sessions)
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, self._signal_handler)
def _signal_handler(self, _signum: int, _frame: Any) -> None:
self.close_all_sessions()
sys.exit(0)
_terminal_manager = TerminalManager() _terminal_manager = TerminalManager()