fix(runtime): parallel tool execution and remove signal handlers
- Add ThreadPoolExecutor in agent_worker for parallel request execution - Add request_id correlation to prevent response mismatch between concurrent requests - Add background listener thread per agent to dispatch responses to correct futures - Add --timeout argument for hard request timeout (default: 120s from config) - Remove signal handlers from terminal_manager, python_manager, tab_manager (use atexit only) - Replace SIGALRM timeout in python_instance with threading-based timeout This fixes requests getting queued behind slow operations and timeouts.
This commit is contained in:
@@ -168,6 +168,8 @@ RUN /app/venv/bin/pip install -r /home/pentester/tools/jwt_tool/requirements.txt
|
||||
RUN echo "# Sandbox Environment" > README.md
|
||||
|
||||
COPY strix/__init__.py strix/
|
||||
COPY strix/config/ /app/strix/config/
|
||||
COPY strix/utils/ /app/strix/utils/
|
||||
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
|
||||
|
||||
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/
|
||||
|
||||
@@ -297,11 +297,12 @@ class DockerRuntime(AbstractRuntime):
|
||||
)
|
||||
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
|
||||
|
||||
execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
|
||||
container.exec_run(
|
||||
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
|
||||
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
|
||||
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
|
||||
f"--host 0.0.0.0 --port {tool_server_port} &'",
|
||||
f"--host 0.0.0.0 --port {tool_server_port} --timeout {execution_timeout} &'",
|
||||
detach=True,
|
||||
user="pentester",
|
||||
)
|
||||
|
||||
@@ -2,12 +2,16 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import queue as stdlib_queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from multiprocessing import Process, Queue
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
@@ -23,9 +27,16 @@ parser = argparse.ArgumentParser(description="Start Strix tool server")
|
||||
parser.add_argument("--token", required=True, help="Authentication token")
|
||||
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
|
||||
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=120,
|
||||
help="Hard timeout in seconds for each request execution (default: 120)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
EXPECTED_TOKEN = args.token
|
||||
REQUEST_TIMEOUT = args.timeout
|
||||
|
||||
app = FastAPI()
|
||||
security = HTTPBearer()
|
||||
@@ -34,6 +45,8 @@ security_dependency = Depends(security)
|
||||
|
||||
agent_processes: dict[str, dict[str, Any]] = {}
|
||||
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
|
||||
pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {}
|
||||
agent_listeners: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
|
||||
@@ -72,9 +85,37 @@ def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queu
|
||||
root_logger.handlers = [null_handler]
|
||||
root_logger.setLevel(logging.CRITICAL)
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
|
||||
from strix.tools.registry import get_tool_by_name
|
||||
|
||||
def _execute_request(request: dict[str, Any]) -> None:
|
||||
request_id = request.get("request_id", "")
|
||||
tool_name = request["tool_name"]
|
||||
kwargs = request["kwargs"]
|
||||
|
||||
try:
|
||||
tool_func = get_tool_by_name(tool_name)
|
||||
if not tool_func:
|
||||
response_queue.put(
|
||||
{"request_id": request_id, "error": f"Tool '{tool_name}' not found"}
|
||||
)
|
||||
return
|
||||
|
||||
converted_kwargs = convert_arguments(tool_func, kwargs)
|
||||
result = tool_func(**converted_kwargs)
|
||||
|
||||
response_queue.put({"request_id": request_id, "result": result})
|
||||
|
||||
except (ArgumentConversionError, ValidationError) as e:
|
||||
response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"})
|
||||
except (RuntimeError, ValueError, ImportError) as e:
|
||||
response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"})
|
||||
except Exception as e: # noqa: BLE001
|
||||
response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"})
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
while True:
|
||||
try:
|
||||
request = request_queue.get()
|
||||
@@ -82,29 +123,43 @@ def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queu
|
||||
if request is None:
|
||||
break
|
||||
|
||||
tool_name = request["tool_name"]
|
||||
kwargs = request["kwargs"]
|
||||
|
||||
try:
|
||||
tool_func = get_tool_by_name(tool_name)
|
||||
if not tool_func:
|
||||
response_queue.put({"error": f"Tool '{tool_name}' not found"})
|
||||
continue
|
||||
|
||||
converted_kwargs = convert_arguments(tool_func, kwargs)
|
||||
result = tool_func(**converted_kwargs)
|
||||
|
||||
response_queue.put({"result": result})
|
||||
|
||||
except (ArgumentConversionError, ValidationError) as e:
|
||||
response_queue.put({"error": f"Invalid arguments: {e}"})
|
||||
except (RuntimeError, ValueError, ImportError) as e:
|
||||
response_queue.put({"error": f"Tool execution error: {e}"})
|
||||
executor.submit(_execute_request, request)
|
||||
|
||||
except (RuntimeError, ValueError, ImportError) as e:
|
||||
response_queue.put({"error": f"Worker error: {e}"})
|
||||
|
||||
|
||||
def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None:
|
||||
if agent_id in agent_listeners:
|
||||
return
|
||||
|
||||
stop_event = threading.Event()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _listener() -> None:
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
item = response_queue.get(timeout=0.5)
|
||||
except stdlib_queue.Empty:
|
||||
continue
|
||||
except (BrokenPipeError, EOFError):
|
||||
break
|
||||
|
||||
request_id = item.get("request_id")
|
||||
if not request_id or agent_id not in pending_responses:
|
||||
continue
|
||||
|
||||
future = pending_responses[agent_id].pop(request_id, None)
|
||||
if future and not future.done():
|
||||
with contextlib.suppress(RuntimeError):
|
||||
loop.call_soon_threadsafe(future.set_result, item)
|
||||
|
||||
listener_thread = threading.Thread(target=_listener, daemon=True)
|
||||
listener_thread.start()
|
||||
|
||||
agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event}
|
||||
|
||||
|
||||
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
|
||||
if agent_id not in agent_processes:
|
||||
request_queue: Queue[Any] = Queue()
|
||||
@@ -117,6 +172,9 @@ def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
|
||||
|
||||
agent_processes[agent_id] = {"process": process, "pid": process.pid}
|
||||
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
|
||||
pending_responses[agent_id] = {}
|
||||
|
||||
_ensure_response_listener(agent_id, response_queue)
|
||||
|
||||
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
|
||||
|
||||
@@ -127,18 +185,31 @@ async def execute_tool(
|
||||
) -> ToolExecutionResponse:
|
||||
verify_token(credentials)
|
||||
|
||||
request_queue, response_queue = ensure_agent_process(request.agent_id)
|
||||
request_queue, _response_queue = ensure_agent_process(request.agent_id)
|
||||
|
||||
request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs})
|
||||
loop = asyncio.get_event_loop()
|
||||
req_id = uuid4().hex
|
||||
future: asyncio.Future[Any] = loop.create_future()
|
||||
pending_responses[request.agent_id][req_id] = future
|
||||
|
||||
request_queue.put(
|
||||
{
|
||||
"request_id": req_id,
|
||||
"tool_name": request.tool_name,
|
||||
"kwargs": request.kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, response_queue.get)
|
||||
response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
if "error" in response:
|
||||
return ToolExecutionResponse(error=response["error"])
|
||||
return ToolExecutionResponse(result=response.get("result"))
|
||||
|
||||
except TimeoutError:
|
||||
pending_responses[request.agent_id].pop(req_id, None)
|
||||
return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds")
|
||||
except (RuntimeError, ValueError, OSError) as e:
|
||||
return ToolExecutionResponse(error=f"Worker error: {e}")
|
||||
|
||||
@@ -168,6 +239,9 @@ async def health_check() -> dict[str, Any]:
|
||||
def cleanup_all_agents() -> None:
|
||||
for agent_id in list(agent_processes.keys()):
|
||||
try:
|
||||
if agent_id in agent_listeners:
|
||||
agent_listeners[agent_id]["stop_event"].set()
|
||||
|
||||
agent_queues[agent_id]["request"].put(None)
|
||||
process = agent_processes[agent_id]["process"]
|
||||
|
||||
@@ -180,6 +254,10 @@ def cleanup_all_agents() -> None:
|
||||
if process.is_alive():
|
||||
process.kill()
|
||||
|
||||
if agent_id in agent_listeners:
|
||||
listener_thread = agent_listeners[agent_id]["thread"]
|
||||
listener_thread.join(timeout=0.5)
|
||||
|
||||
except (BrokenPipeError, EOFError, OSError):
|
||||
pass
|
||||
except (RuntimeError, ValueError) as e:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
@@ -324,16 +322,6 @@ class BrowserTabManager:
|
||||
def _register_cleanup_handlers(self) -> None:
|
||||
atexit.register(self.close_all)
|
||||
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
||||
|
||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
||||
self.close_all()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
_browser_tab_manager = BrowserTabManager()
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import io
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
@@ -57,28 +56,6 @@ class PythonInstance:
|
||||
}
|
||||
return None
|
||||
|
||||
def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]:
|
||||
stdout_capture = io.StringIO()
|
||||
stderr_capture = io.StringIO()
|
||||
|
||||
def timeout_handler(signum: int, frame: Any) -> None:
|
||||
raise TimeoutError(f"Code execution timed out after {timeout} seconds")
|
||||
|
||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(timeout)
|
||||
|
||||
sys.stdout = stdout_capture
|
||||
sys.stderr = stderr_capture
|
||||
|
||||
return old_handler, stdout_capture, stderr_capture
|
||||
|
||||
def _cleanup_execution_environment(
|
||||
self, old_handler: Any, old_stdout: Any, old_stderr: Any
|
||||
) -> None:
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
|
||||
if len(content) > max_length:
|
||||
return content[:max_length] + suffix
|
||||
@@ -142,27 +119,48 @@ class PythonInstance:
|
||||
return session_error
|
||||
|
||||
with self._execution_lock:
|
||||
result_container: dict[str, Any] = {}
|
||||
stdout_capture = io.StringIO()
|
||||
stderr_capture = io.StringIO()
|
||||
|
||||
old_stdout, old_stderr = sys.stdout, sys.stderr
|
||||
|
||||
def _run_code() -> None:
|
||||
try:
|
||||
old_handler, stdout_capture, stderr_capture = self._setup_execution_environment(
|
||||
timeout
|
||||
)
|
||||
|
||||
try:
|
||||
sys.stdout = stdout_capture
|
||||
sys.stderr = stderr_capture
|
||||
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
|
||||
signal.alarm(0)
|
||||
result_container["execution_result"] = execution_result
|
||||
result_container["stdout"] = stdout_capture.getvalue()
|
||||
result_container["stderr"] = stderr_capture.getvalue()
|
||||
except (KeyboardInterrupt, SystemExit) as e:
|
||||
result_container["error"] = e
|
||||
except Exception as e: # noqa: BLE001
|
||||
result_container["error"] = e
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
return self._format_execution_result(
|
||||
execution_result, stdout_capture.getvalue(), stderr_capture.getvalue()
|
||||
exec_thread = threading.Thread(target=_run_code, daemon=True)
|
||||
exec_thread.start()
|
||||
exec_thread.join(timeout=timeout)
|
||||
|
||||
if exec_thread.is_alive():
|
||||
return self._handle_execution_error(
|
||||
TimeoutError(f"Code execution timed out after {timeout} seconds")
|
||||
)
|
||||
|
||||
except (TimeoutError, KeyboardInterrupt, SystemExit) as e:
|
||||
signal.alarm(0)
|
||||
return self._handle_execution_error(e)
|
||||
if "error" in result_container:
|
||||
return self._handle_execution_error(result_container["error"])
|
||||
|
||||
finally:
|
||||
self._cleanup_execution_environment(old_handler, old_stdout, old_stderr)
|
||||
if "execution_result" in result_container:
|
||||
return self._format_execution_result(
|
||||
result_container["execution_result"],
|
||||
result_container.get("stdout", ""),
|
||||
result_container.get("stderr", ""),
|
||||
)
|
||||
|
||||
return self._handle_execution_error(RuntimeError("Unknown execution error"))
|
||||
|
||||
def close(self) -> None:
|
||||
self.is_running = False
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
@@ -113,16 +111,6 @@ class PythonSessionManager:
|
||||
def _register_cleanup_handlers(self) -> None:
|
||||
atexit.register(self.close_all_sessions)
|
||||
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
||||
|
||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
||||
self.close_all_sessions()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
_python_session_manager = PythonSessionManager()
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
@@ -133,16 +131,6 @@ class TerminalManager:
|
||||
def _register_cleanup_handlers(self) -> None:
|
||||
atexit.register(self.close_all_sessions)
|
||||
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
||||
|
||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
||||
self.close_all_sessions()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
_terminal_manager = TerminalManager()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user