fix(runtime): parallel tool execution and remove signal handlers
- Add ThreadPoolExecutor in agent_worker for parallel request execution
- Add request_id correlation to prevent response mismatch between concurrent requests
- Add background listener thread per agent to dispatch responses to correct futures
- Add --timeout argument for hard request timeout (default: 120s from config)
- Remove signal handlers from terminal_manager, python_manager, tab_manager (use atexit only)
- Replace SIGALRM timeout in python_instance with threading-based timeout

This fixes requests getting queued behind slow operations and timeouts.
This commit is contained in:
@@ -168,6 +168,8 @@ RUN /app/venv/bin/pip install -r /home/pentester/tools/jwt_tool/requirements.txt
|
|||||||
RUN echo "# Sandbox Environment" > README.md
|
RUN echo "# Sandbox Environment" > README.md
|
||||||
|
|
||||||
COPY strix/__init__.py strix/
|
COPY strix/__init__.py strix/
|
||||||
|
COPY strix/config/ /app/strix/config/
|
||||||
|
COPY strix/utils/ /app/strix/utils/
|
||||||
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
|
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
|
||||||
|
|
||||||
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/
|
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/
|
||||||
|
|||||||
@@ -297,11 +297,12 @@ class DockerRuntime(AbstractRuntime):
|
|||||||
)
|
)
|
||||||
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
|
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
|
||||||
|
|
||||||
|
execution_timeout = Config.get("strix_sandbox_execution_timeout") or "120"
|
||||||
container.exec_run(
|
container.exec_run(
|
||||||
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
|
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
|
||||||
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
|
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
|
||||||
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
|
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
|
||||||
f"--host 0.0.0.0 --port {tool_server_port} &'",
|
f"--host 0.0.0.0 --port {tool_server_port} --timeout {execution_timeout} &'",
|
||||||
detach=True,
|
detach=True,
|
||||||
user="pentester",
|
user="pentester",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,12 +2,16 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import queue as stdlib_queue
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from multiprocessing import Process, Queue
|
from multiprocessing import Process, Queue
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
from fastapi import Depends, FastAPI, HTTPException, status
|
||||||
@@ -23,9 +27,16 @@ parser = argparse.ArgumentParser(description="Start Strix tool server")
|
|||||||
parser.add_argument("--token", required=True, help="Authentication token")
|
parser.add_argument("--token", required=True, help="Authentication token")
|
||||||
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
|
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
|
||||||
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
|
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=int,
|
||||||
|
default=120,
|
||||||
|
help="Hard timeout in seconds for each request execution (default: 120)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
EXPECTED_TOKEN = args.token
|
EXPECTED_TOKEN = args.token
|
||||||
|
REQUEST_TIMEOUT = args.timeout
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
@@ -34,6 +45,8 @@ security_dependency = Depends(security)
|
|||||||
|
|
||||||
agent_processes: dict[str, dict[str, Any]] = {}
|
agent_processes: dict[str, dict[str, Any]] = {}
|
||||||
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
|
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
|
||||||
|
pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {}
|
||||||
|
agent_listeners: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
|
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
|
||||||
@@ -72,37 +85,79 @@ def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queu
|
|||||||
root_logger.handlers = [null_handler]
|
root_logger.handlers = [null_handler]
|
||||||
root_logger.setLevel(logging.CRITICAL)
|
root_logger.setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
|
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
|
||||||
from strix.tools.registry import get_tool_by_name
|
from strix.tools.registry import get_tool_by_name
|
||||||
|
|
||||||
while True:
|
def _execute_request(request: dict[str, Any]) -> None:
|
||||||
try:
|
request_id = request.get("request_id", "")
|
||||||
request = request_queue.get()
|
tool_name = request["tool_name"]
|
||||||
|
kwargs = request["kwargs"]
|
||||||
|
|
||||||
if request is None:
|
try:
|
||||||
|
tool_func = get_tool_by_name(tool_name)
|
||||||
|
if not tool_func:
|
||||||
|
response_queue.put(
|
||||||
|
{"request_id": request_id, "error": f"Tool '{tool_name}' not found"}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
converted_kwargs = convert_arguments(tool_func, kwargs)
|
||||||
|
result = tool_func(**converted_kwargs)
|
||||||
|
|
||||||
|
response_queue.put({"request_id": request_id, "result": result})
|
||||||
|
|
||||||
|
except (ArgumentConversionError, ValidationError) as e:
|
||||||
|
response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"})
|
||||||
|
except (RuntimeError, ValueError, ImportError) as e:
|
||||||
|
response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"})
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"})
|
||||||
|
|
||||||
|
with ThreadPoolExecutor() as executor:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
request = request_queue.get()
|
||||||
|
|
||||||
|
if request is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
executor.submit(_execute_request, request)
|
||||||
|
|
||||||
|
except (RuntimeError, ValueError, ImportError) as e:
|
||||||
|
response_queue.put({"error": f"Worker error: {e}"})
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None:
|
||||||
|
if agent_id in agent_listeners:
|
||||||
|
return
|
||||||
|
|
||||||
|
stop_event = threading.Event()
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
def _listener() -> None:
|
||||||
|
while not stop_event.is_set():
|
||||||
|
try:
|
||||||
|
item = response_queue.get(timeout=0.5)
|
||||||
|
except stdlib_queue.Empty:
|
||||||
|
continue
|
||||||
|
except (BrokenPipeError, EOFError):
|
||||||
break
|
break
|
||||||
|
|
||||||
tool_name = request["tool_name"]
|
request_id = item.get("request_id")
|
||||||
kwargs = request["kwargs"]
|
if not request_id or agent_id not in pending_responses:
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
future = pending_responses[agent_id].pop(request_id, None)
|
||||||
tool_func = get_tool_by_name(tool_name)
|
if future and not future.done():
|
||||||
if not tool_func:
|
with contextlib.suppress(RuntimeError):
|
||||||
response_queue.put({"error": f"Tool '{tool_name}' not found"})
|
loop.call_soon_threadsafe(future.set_result, item)
|
||||||
continue
|
|
||||||
|
|
||||||
converted_kwargs = convert_arguments(tool_func, kwargs)
|
listener_thread = threading.Thread(target=_listener, daemon=True)
|
||||||
result = tool_func(**converted_kwargs)
|
listener_thread.start()
|
||||||
|
|
||||||
response_queue.put({"result": result})
|
agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event}
|
||||||
|
|
||||||
except (ArgumentConversionError, ValidationError) as e:
|
|
||||||
response_queue.put({"error": f"Invalid arguments: {e}"})
|
|
||||||
except (RuntimeError, ValueError, ImportError) as e:
|
|
||||||
response_queue.put({"error": f"Tool execution error: {e}"})
|
|
||||||
|
|
||||||
except (RuntimeError, ValueError, ImportError) as e:
|
|
||||||
response_queue.put({"error": f"Worker error: {e}"})
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
|
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
|
||||||
@@ -117,6 +172,9 @@ def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
|
|||||||
|
|
||||||
agent_processes[agent_id] = {"process": process, "pid": process.pid}
|
agent_processes[agent_id] = {"process": process, "pid": process.pid}
|
||||||
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
|
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
|
||||||
|
pending_responses[agent_id] = {}
|
||||||
|
|
||||||
|
_ensure_response_listener(agent_id, response_queue)
|
||||||
|
|
||||||
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
|
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
|
||||||
|
|
||||||
@@ -127,18 +185,31 @@ async def execute_tool(
|
|||||||
) -> ToolExecutionResponse:
|
) -> ToolExecutionResponse:
|
||||||
verify_token(credentials)
|
verify_token(credentials)
|
||||||
|
|
||||||
request_queue, response_queue = ensure_agent_process(request.agent_id)
|
request_queue, _response_queue = ensure_agent_process(request.agent_id)
|
||||||
|
|
||||||
request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs})
|
loop = asyncio.get_event_loop()
|
||||||
|
req_id = uuid4().hex
|
||||||
|
future: asyncio.Future[Any] = loop.create_future()
|
||||||
|
pending_responses[request.agent_id][req_id] = future
|
||||||
|
|
||||||
|
request_queue.put(
|
||||||
|
{
|
||||||
|
"request_id": req_id,
|
||||||
|
"tool_name": request.tool_name,
|
||||||
|
"kwargs": request.kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT)
|
||||||
response = await loop.run_in_executor(None, response_queue.get)
|
|
||||||
|
|
||||||
if "error" in response:
|
if "error" in response:
|
||||||
return ToolExecutionResponse(error=response["error"])
|
return ToolExecutionResponse(error=response["error"])
|
||||||
return ToolExecutionResponse(result=response.get("result"))
|
return ToolExecutionResponse(result=response.get("result"))
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
pending_responses[request.agent_id].pop(req_id, None)
|
||||||
|
return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds")
|
||||||
except (RuntimeError, ValueError, OSError) as e:
|
except (RuntimeError, ValueError, OSError) as e:
|
||||||
return ToolExecutionResponse(error=f"Worker error: {e}")
|
return ToolExecutionResponse(error=f"Worker error: {e}")
|
||||||
|
|
||||||
@@ -168,6 +239,9 @@ async def health_check() -> dict[str, Any]:
|
|||||||
def cleanup_all_agents() -> None:
|
def cleanup_all_agents() -> None:
|
||||||
for agent_id in list(agent_processes.keys()):
|
for agent_id in list(agent_processes.keys()):
|
||||||
try:
|
try:
|
||||||
|
if agent_id in agent_listeners:
|
||||||
|
agent_listeners[agent_id]["stop_event"].set()
|
||||||
|
|
||||||
agent_queues[agent_id]["request"].put(None)
|
agent_queues[agent_id]["request"].put(None)
|
||||||
process = agent_processes[agent_id]["process"]
|
process = agent_processes[agent_id]["process"]
|
||||||
|
|
||||||
@@ -180,6 +254,10 @@ def cleanup_all_agents() -> None:
|
|||||||
if process.is_alive():
|
if process.is_alive():
|
||||||
process.kill()
|
process.kill()
|
||||||
|
|
||||||
|
if agent_id in agent_listeners:
|
||||||
|
listener_thread = agent_listeners[agent_id]["thread"]
|
||||||
|
listener_thread.join(timeout=0.5)
|
||||||
|
|
||||||
except (BrokenPipeError, EOFError, OSError):
|
except (BrokenPipeError, EOFError, OSError):
|
||||||
pass
|
pass
|
||||||
except (RuntimeError, ValueError) as e:
|
except (RuntimeError, ValueError) as e:
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import atexit
|
import atexit
|
||||||
import contextlib
|
import contextlib
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -324,16 +322,6 @@ class BrowserTabManager:
|
|||||||
def _register_cleanup_handlers(self) -> None:
|
def _register_cleanup_handlers(self) -> None:
|
||||||
atexit.register(self.close_all)
|
atexit.register(self.close_all)
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, self._signal_handler)
|
|
||||||
|
|
||||||
if hasattr(signal, "SIGHUP"):
|
|
||||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
|
||||||
|
|
||||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
|
||||||
self.close_all()
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
_browser_tab_manager = BrowserTabManager()
|
_browser_tab_manager = BrowserTabManager()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import io
|
import io
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -57,28 +56,6 @@ class PythonInstance:
|
|||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _setup_execution_environment(self, timeout: int) -> tuple[Any, io.StringIO, io.StringIO]:
|
|
||||||
stdout_capture = io.StringIO()
|
|
||||||
stderr_capture = io.StringIO()
|
|
||||||
|
|
||||||
def timeout_handler(signum: int, frame: Any) -> None:
|
|
||||||
raise TimeoutError(f"Code execution timed out after {timeout} seconds")
|
|
||||||
|
|
||||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
|
||||||
signal.alarm(timeout)
|
|
||||||
|
|
||||||
sys.stdout = stdout_capture
|
|
||||||
sys.stderr = stderr_capture
|
|
||||||
|
|
||||||
return old_handler, stdout_capture, stderr_capture
|
|
||||||
|
|
||||||
def _cleanup_execution_environment(
|
|
||||||
self, old_handler: Any, old_stdout: Any, old_stderr: Any
|
|
||||||
) -> None:
|
|
||||||
signal.signal(signal.SIGALRM, old_handler)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
sys.stderr = old_stderr
|
|
||||||
|
|
||||||
def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
|
def _truncate_output(self, content: str, max_length: int, suffix: str) -> str:
|
||||||
if len(content) > max_length:
|
if len(content) > max_length:
|
||||||
return content[:max_length] + suffix
|
return content[:max_length] + suffix
|
||||||
@@ -142,27 +119,48 @@ class PythonInstance:
|
|||||||
return session_error
|
return session_error
|
||||||
|
|
||||||
with self._execution_lock:
|
with self._execution_lock:
|
||||||
|
result_container: dict[str, Any] = {}
|
||||||
|
stdout_capture = io.StringIO()
|
||||||
|
stderr_capture = io.StringIO()
|
||||||
|
|
||||||
old_stdout, old_stderr = sys.stdout, sys.stderr
|
old_stdout, old_stderr = sys.stdout, sys.stderr
|
||||||
|
|
||||||
try:
|
def _run_code() -> None:
|
||||||
old_handler, stdout_capture, stderr_capture = self._setup_execution_environment(
|
try:
|
||||||
timeout
|
sys.stdout = stdout_capture
|
||||||
|
sys.stderr = stderr_capture
|
||||||
|
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
|
||||||
|
result_container["execution_result"] = execution_result
|
||||||
|
result_container["stdout"] = stdout_capture.getvalue()
|
||||||
|
result_container["stderr"] = stderr_capture.getvalue()
|
||||||
|
except (KeyboardInterrupt, SystemExit) as e:
|
||||||
|
result_container["error"] = e
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
result_container["error"] = e
|
||||||
|
finally:
|
||||||
|
sys.stdout = old_stdout
|
||||||
|
sys.stderr = old_stderr
|
||||||
|
|
||||||
|
exec_thread = threading.Thread(target=_run_code, daemon=True)
|
||||||
|
exec_thread.start()
|
||||||
|
exec_thread.join(timeout=timeout)
|
||||||
|
|
||||||
|
if exec_thread.is_alive():
|
||||||
|
return self._handle_execution_error(
|
||||||
|
TimeoutError(f"Code execution timed out after {timeout} seconds")
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
if "error" in result_container:
|
||||||
execution_result = self.shell.run_cell(code, silent=False, store_history=True)
|
return self._handle_execution_error(result_container["error"])
|
||||||
signal.alarm(0)
|
|
||||||
|
|
||||||
return self._format_execution_result(
|
if "execution_result" in result_container:
|
||||||
execution_result, stdout_capture.getvalue(), stderr_capture.getvalue()
|
return self._format_execution_result(
|
||||||
)
|
result_container["execution_result"],
|
||||||
|
result_container.get("stdout", ""),
|
||||||
|
result_container.get("stderr", ""),
|
||||||
|
)
|
||||||
|
|
||||||
except (TimeoutError, KeyboardInterrupt, SystemExit) as e:
|
return self._handle_execution_error(RuntimeError("Unknown execution error"))
|
||||||
signal.alarm(0)
|
|
||||||
return self._handle_execution_error(e)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
self._cleanup_execution_environment(old_handler, old_stdout, old_stderr)
|
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import atexit
|
import atexit
|
||||||
import contextlib
|
import contextlib
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -113,16 +111,6 @@ class PythonSessionManager:
|
|||||||
def _register_cleanup_handlers(self) -> None:
|
def _register_cleanup_handlers(self) -> None:
|
||||||
atexit.register(self.close_all_sessions)
|
atexit.register(self.close_all_sessions)
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, self._signal_handler)
|
|
||||||
|
|
||||||
if hasattr(signal, "SIGHUP"):
|
|
||||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
|
||||||
|
|
||||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
|
||||||
self.close_all_sessions()
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
_python_session_manager = PythonSessionManager()
|
_python_session_manager = PythonSessionManager()
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import atexit
|
import atexit
|
||||||
import contextlib
|
import contextlib
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -133,16 +131,6 @@ class TerminalManager:
|
|||||||
def _register_cleanup_handlers(self) -> None:
|
def _register_cleanup_handlers(self) -> None:
|
||||||
atexit.register(self.close_all_sessions)
|
atexit.register(self.close_all_sessions)
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, self._signal_handler)
|
|
||||||
|
|
||||||
if hasattr(signal, "SIGHUP"):
|
|
||||||
signal.signal(signal.SIGHUP, self._signal_handler)
|
|
||||||
|
|
||||||
def _signal_handler(self, _signum: int, _frame: Any) -> None:
|
|
||||||
self.close_all_sessions()
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
_terminal_manager = TerminalManager()
|
_terminal_manager = TerminalManager()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user