Files
strix/strix/runtime/tool_server.py
0xallam 693ef16060 fix(runtime): parallel tool execution and remove signal handlers
- Add ThreadPoolExecutor in agent_worker for parallel request execution
- Add request_id correlation to prevent response mismatch between concurrent requests
- Add background listener thread per agent to dispatch responses to correct futures
- Add --timeout argument for hard request timeout (default: 120s from config)
- Remove signal handlers from terminal_manager, python_manager, tab_manager (use atexit only)
- Replace SIGALRM timeout in python_instance with threading-based timeout

This fixes requests getting queued behind slow operations and timeouts.
2026-01-16 01:11:02 -08:00

284 lines
9.1 KiB
Python

from __future__ import annotations
import argparse
import asyncio
import contextlib
import logging
import os
import queue as stdlib_queue
import signal
import sys
import threading
from multiprocessing import Process, Queue
from typing import Any
from uuid import uuid4
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, ValidationError
# Refuse to load unless the STRIX_SANDBOX_MODE environment variable is "true"
# (compared case-insensitively); any other value aborts with RuntimeError.
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not SANDBOX_MODE:
    raise RuntimeError("Tool server should only run in sandbox mode (STRIX_SANDBOX_MODE=true)")

# Command-line options are parsed at module level, i.e. as soon as this file
# is executed or imported. --token and --port are mandatory.
parser = argparse.ArgumentParser(description="Start Strix tool server")
parser.add_argument("--token", required=True, help="Authentication token")
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")  # nosec
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
parser.add_argument(
    "--timeout",
    type=int,
    default=120,
    help="Hard timeout in seconds for each request execution (default: 120)",
)
args = parser.parse_args()

EXPECTED_TOKEN = args.token  # token that verify_token compares bearer credentials against
REQUEST_TIMEOUT = args.timeout  # seconds execute_tool waits for a worker response

app = FastAPI()
security = HTTPBearer()
security_dependency = Depends(security)  # shared default for the endpoints' credentials parameter

# Per-agent registries, each keyed by agent_id.
agent_processes: dict[str, dict[str, Any]] = {}  # {"process": Process, "pid": int}
agent_queues: dict[str, dict[str, Queue[Any]]] = {}  # {"request": Queue, "response": Queue}
pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {}  # request_id -> awaiting future
agent_listeners: dict[str, dict[str, Any]] = {}  # {"thread": Thread, "stop_event": Event}
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
    """Validate the bearer credentials of a request and return the token.

    Args:
        credentials: Parsed ``Authorization`` header as produced by the
            ``HTTPBearer`` security dependency.

    Returns:
        The presented token, when it equals ``EXPECTED_TOKEN``.

    Raises:
        HTTPException: 401 when the credentials are missing, the scheme is
            not ``Bearer``, or the token does not match.
    """
    import hmac  # function-scope import keeps this block self-contained

    if not credentials or credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Bearer token required.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Compare in constant time so the response latency does not reveal how
    # many leading characters of a guessed token were correct. The values are
    # encoded first because compare_digest only accepts ASCII-only str;
    # "surrogatepass" keeps encoding from failing on unusual input.
    presented = credentials.credentials.encode("utf-8", "surrogatepass")
    expected = EXPECTED_TOKEN.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(presented, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return credentials.credentials
# Request body accepted by POST /execute.
class ToolExecutionRequest(BaseModel):
    agent_id: str  # selects the per-agent worker process (created on first use)
    tool_name: str  # name the worker looks up with get_tool_by_name
    kwargs: dict[str, Any]  # raw tool arguments; the worker passes them through convert_arguments
# Response body returned by POST /execute; execute_tool fills in either
# `result` or `error`, never both.
class ToolExecutionResponse(BaseModel):
    result: Any | None = None  # value returned by the tool on success
    error: str | None = None  # error message (tool failure, timeout or worker error)
def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
    """Run the request loop of a per-agent worker process.

    Request dicts are read from ``request_queue`` until a ``None`` sentinel
    arrives. Each request is handed to a thread pool so that a slow tool does
    not hold up the requests queued behind it, and its outcome is posted to
    ``response_queue`` as a dict carrying the request's ``request_id`` plus
    either ``result`` or ``error``.

    Args:
        _agent_id: Identifier of the owning agent (not used by the worker).
        request_queue: Source of request dicts with the keys ``request_id``,
            ``tool_name`` and ``kwargs``; ``None`` stops the loop.
        response_queue: Sink for the response dicts.
    """
    # Silence logging in the worker: the root logger keeps only a NullHandler.
    null_handler = logging.NullHandler()
    root_logger = logging.getLogger()
    root_logger.handlers = [null_handler]
    root_logger.setLevel(logging.CRITICAL)

    # Imported when the worker starts rather than at module import time.
    from concurrent.futures import ThreadPoolExecutor

    from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
    from strix.tools.registry import get_tool_by_name

    def _execute_request(request: dict[str, Any]) -> None:
        """Execute one tool request and post its response."""
        request_id = request.get("request_id", "")

        # Validate the envelope before doing anything else. An exception that
        # escapes this function is captured by the pool's future, which nobody
        # inspects, so the caller would receive no response and wait for its
        # timeout instead of getting an error.
        try:
            tool_name = request["tool_name"]
            kwargs = request["kwargs"]
        except KeyError as e:
            response_queue.put(
                {"request_id": request_id, "error": f"Malformed request: missing field {e}"}
            )
            return

        try:
            tool_func = get_tool_by_name(tool_name)
            if not tool_func:
                response_queue.put(
                    {"request_id": request_id, "error": f"Tool '{tool_name}' not found"}
                )
                return

            converted_kwargs = convert_arguments(tool_func, kwargs)
            result = tool_func(**converted_kwargs)
            response_queue.put({"request_id": request_id, "result": result})
        except (ArgumentConversionError, ValidationError) as e:
            response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"})
        except (RuntimeError, ValueError, ImportError) as e:
            response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"})
        except Exception as e:  # noqa: BLE001
            # Last resort so that any tool failure still produces a response.
            response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"})

    # Leaving the ``with`` block waits for the tasks that were already submitted.
    with ThreadPoolExecutor() as executor:
        while True:
            try:
                request = request_queue.get()

                if request is None:
                    break

                executor.submit(_execute_request, request)
            except (RuntimeError, ValueError, ImportError) as e:
                # No request_id is attached to this message.
                response_queue.put({"error": f"Worker error: {e}"})
def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None:
    """Start, at most once per agent, the thread that dispatches worker responses.

    The thread polls ``response_queue`` and hands every response that carries
    a ``request_id`` to the matching future in ``pending_responses[agent_id]``.
    The thread and its stop event are recorded in ``agent_listeners``.

    Args:
        agent_id: Agent whose responses are dispatched.
        response_queue: Queue the agent's worker process writes responses to.
    """
    if agent_id in agent_listeners:
        return

    stop_event = threading.Event()
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Not invoked from a running loop: fall back to the previous lookup.
        loop = asyncio.get_event_loop()

    def _resolve(future: asyncio.Future[Any], item: Any) -> None:
        # Executed on the event-loop thread. The done() check must happen here
        # rather than in the listener thread: the future can be cancelled (for
        # example by the request timeout) after the listener looked at it but
        # before this callback runs, and set_result() on a finished future
        # raises InvalidStateError.
        if not future.done():
            future.set_result(item)

    def _listener() -> None:
        while not stop_event.is_set():
            try:
                # Short timeout so the stop event is re-checked regularly.
                item = response_queue.get(timeout=0.5)
            except stdlib_queue.Empty:
                continue
            except (BrokenPipeError, EOFError):
                break

            request_id = item.get("request_id")
            if not request_id or agent_id not in pending_responses:
                # Responses without an id cannot be matched to a caller.
                continue

            future = pending_responses[agent_id].pop(request_id, None)
            if future is not None:
                # call_soon_threadsafe raises RuntimeError once the loop is closed.
                with contextlib.suppress(RuntimeError):
                    loop.call_soon_threadsafe(_resolve, future, item)

    listener_thread = threading.Thread(target=_listener, daemon=True)
    listener_thread.start()

    agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event}
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
    """Return the ``(request, response)`` queues of an agent's worker process.

    On the first call for an ``agent_id`` a daemon worker process is spawned
    together with its two queues, an empty pending-response table and a
    response listener; later calls return the queues already registered.
    """
    # Fast path: the agent already has a worker.
    if agent_id in agent_processes:
        existing = agent_queues[agent_id]
        return existing["request"], existing["response"]

    inbound: Queue[Any] = Queue()
    outbound: Queue[Any] = Queue()

    worker = Process(target=agent_worker, args=(agent_id, inbound, outbound), daemon=True)
    worker.start()

    # Register everything before the listener starts consuming responses.
    agent_processes[agent_id] = {"process": worker, "pid": worker.pid}
    agent_queues[agent_id] = {"request": inbound, "response": outbound}
    pending_responses[agent_id] = {}

    _ensure_response_listener(agent_id, outbound)

    registered = agent_queues[agent_id]
    return registered["request"], registered["response"]
# POST /execute: run one tool call in the agent's worker process.
#
# The request is tagged with a fresh request_id, queued for the worker, and
# the handler then waits up to REQUEST_TIMEOUT seconds for the response
# listener to resolve the matching future. Tool failures, timeouts and queue
# errors are reported through the `error` field of the response body.
# Invalid bearer credentials are rejected by verify_token (HTTP 401).
@app.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
    request: ToolExecutionRequest, credentials: HTTPAuthorizationCredentials = security_dependency
) -> ToolExecutionResponse:
    verify_token(credentials)

    request_queue, _response_queue = ensure_agent_process(request.agent_id)

    # Inside a coroutine the running loop is always available.
    loop = asyncio.get_running_loop()
    req_id = uuid4().hex
    future: asyncio.Future[Any] = loop.create_future()
    agent_pending = pending_responses[request.agent_id]
    agent_pending[req_id] = future

    try:
        # Queueing happens inside the try block so that a failing put is
        # reported as a worker error instead of escaping as a server error.
        request_queue.put(
            {
                "request_id": req_id,
                "tool_name": request.tool_name,
                "kwargs": request.kwargs,
            }
        )

        response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT)

        if "error" in response:
            return ToolExecutionResponse(error=response["error"])
        return ToolExecutionResponse(result=response.get("result"))

    # asyncio.TimeoutError is what wait_for raises on every Python version;
    # the builtin TimeoutError is only the same class from 3.11 onwards.
    except asyncio.TimeoutError:
        return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds")
    except (RuntimeError, ValueError, OSError) as e:
        return ToolExecutionResponse(error=f"Worker error: {e}")
    finally:
        # Drop the bookkeeping entry on every exit path (timeout, error,
        # cancellation); it is a no-op when the listener already removed it.
        agent_pending.pop(req_id, None)
# POST /register_agent: make sure a worker process exists for `agent_id`
# and confirm the registration. Requires valid bearer credentials.
@app.post("/register_agent")
async def register_agent(
    agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency
) -> dict[str, str]:
    # Authenticate first so an invalid token never spawns a worker.
    verify_token(credentials)

    ensure_agent_process(agent_id)

    confirmation = {"status": "registered", "agent_id": agent_id}
    return confirmation
# GET /health: report server status and the registered agents.
# Takes no credentials, so it is reachable without a token.
@app.get("/health")
async def health_check() -> dict[str, Any]:
    agent_ids = list(agent_processes)

    environment = "main"
    if SANDBOX_MODE:
        environment = "sandbox"

    auth_configured = "false"
    if EXPECTED_TOKEN:
        auth_configured = "true"

    return {
        "status": "healthy",
        "sandbox_mode": str(SANDBOX_MODE),
        "environment": environment,
        "auth_configured": auth_configured,
        "active_agents": len(agent_ids),
        "agents": agent_ids,
    }
def cleanup_all_agents() -> None:
    """Stop every registered agent's listener thread and worker process.

    For each agent the listener is told to stop, a ``None`` sentinel is
    queued so the worker can exit on its own, and the process is then joined,
    escalating to ``terminate()`` and finally ``kill()`` if it is still
    alive. The function is best-effort: failures are logged at debug level
    and never abort the cleanup of the remaining agents.
    """
    logger = logging.getLogger(__name__)

    for agent_id in list(agent_processes):
        listener = agent_listeners.get(agent_id)
        if listener is not None:
            listener["stop_event"].set()

        # Sending the sentinel is handled on its own: if the queue is already
        # broken or closed, the process must still be shut down below.
        try:
            agent_queues[agent_id]["request"].put(None)
        except (EOFError, OSError, ValueError) as e:
            logger.debug("Could not send shutdown sentinel to agent %s: %s", agent_id, e)

        try:
            process = agent_processes[agent_id]["process"]

            process.join(timeout=1)

            if process.is_alive():
                process.terminate()
                process.join(timeout=1)

            if process.is_alive():
                process.kill()

            if listener is not None:
                listener["thread"].join(timeout=0.5)

        except (EOFError, OSError) as e:
            # BrokenPipeError is a subclass of OSError and is covered here.
            logger.debug("I/O error during agent cleanup: %s", e)
        except (RuntimeError, ValueError) as e:
            logger.debug("Error during agent cleanup: %s", e)
def signal_handler(_signum: int, _frame: Any) -> None:
    """Handle a termination signal: clean up all agents, then exit with status 0.

    SIGPIPE is set to be ignored first on platforms that define it. Both
    handler arguments are unused.
    """
    sigpipe = getattr(signal, "SIGPIPE", None)
    if sigpipe is not None:
        signal.signal(sigpipe, signal.SIG_IGN)

    cleanup_all_agents()
    sys.exit(0)
# Ignore SIGPIPE on platforms that define it.
if hasattr(signal, "SIGPIPE"):
    signal.signal(signal.SIGPIPE, signal.SIG_IGN)

# Route SIGTERM and SIGINT through signal_handler, which cleans up the agents
# and exits. These registrations happen at module level.
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

if __name__ == "__main__":
    try:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")
    finally:
        # Runs whether uvicorn.run returns normally or raises.
        cleanup_all_agents()