refactor: simplify tool server to asyncio tasks with per-agent isolation

- Replace multiprocessing/threading with single asyncio task per agent
- Add task cancellation: new request cancels previous for same agent
- Add per-agent state isolation via ContextVar for Terminal, Browser, Python managers
- Add posthog telemetry for tool execution errors (timeout, http, sandbox)
- Fix proxy manager singleton pattern
- Increase client timeout buffer over server timeout
- Add context.py to Dockerfile
This commit is contained in:
0xallam
2026-01-17 21:41:57 -08:00
committed by Ahmed Allam
parent a80ecac7bd
commit 918a151892
8 changed files with 271 additions and 295 deletions

View File

@@ -172,7 +172,7 @@ COPY strix/config/ /app/strix/config/
COPY strix/utils/ /app/strix/utils/ COPY strix/utils/ /app/strix/utils/
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/ COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/ COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py strix/tools/context.py /app/strix/tools/
COPY strix/tools/browser/ /app/strix/tools/browser/ COPY strix/tools/browser/ /app/strix/tools/browser/
COPY strix/tools/file_edit/ /app/strix/tools/file_edit/ COPY strix/tools/file_edit/ /app/strix/tools/file_edit/

View File

@@ -2,16 +2,10 @@ from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import contextlib
import logging
import os import os
import queue as stdlib_queue
import signal import signal
import sys import sys
import threading
from multiprocessing import Process, Queue
from typing import Any from typing import Any
from uuid import uuid4
import uvicorn import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
@@ -40,13 +34,9 @@ REQUEST_TIMEOUT = args.timeout
app = FastAPI() app = FastAPI()
security = HTTPBearer() security = HTTPBearer()
security_dependency = Depends(security) security_dependency = Depends(security)
agent_processes: dict[str, dict[str, Any]] = {} agent_tasks: dict[str, asyncio.Task[Any]] = {}
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {}
agent_listeners: dict[str, dict[str, Any]] = {}
def verify_token(credentials: HTTPAuthorizationCredentials) -> str: def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
@@ -78,107 +68,19 @@ class ToolExecutionResponse(BaseModel):
error: str | None = None error: str | None = None
def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None: async def _run_tool(agent_id: str, tool_name: str, kwargs: dict[str, Any]) -> Any:
null_handler = logging.NullHandler() from strix.tools.argument_parser import convert_arguments
from strix.tools.context import set_current_agent_id
root_logger = logging.getLogger()
root_logger.handlers = [null_handler]
root_logger.setLevel(logging.CRITICAL)
from concurrent.futures import ThreadPoolExecutor
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
from strix.tools.registry import get_tool_by_name from strix.tools.registry import get_tool_by_name
def _execute_request(request: dict[str, Any]) -> None: set_current_agent_id(agent_id)
request_id = request.get("request_id", "")
tool_name = request["tool_name"]
kwargs = request["kwargs"]
try:
tool_func = get_tool_by_name(tool_name) tool_func = get_tool_by_name(tool_name)
if not tool_func: if not tool_func:
response_queue.put( raise ValueError(f"Tool '{tool_name}' not found")
{"request_id": request_id, "error": f"Tool '{tool_name}' not found"}
)
return
converted_kwargs = convert_arguments(tool_func, kwargs) converted_kwargs = convert_arguments(tool_func, kwargs)
result = tool_func(**converted_kwargs) return await asyncio.to_thread(tool_func, **converted_kwargs)
response_queue.put({"request_id": request_id, "result": result})
except (ArgumentConversionError, ValidationError) as e:
response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"})
except (RuntimeError, ValueError, ImportError) as e:
response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"})
except Exception as e: # noqa: BLE001
response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"})
with ThreadPoolExecutor() as executor:
while True:
request = None
try:
request = request_queue.get()
if request is None:
break
executor.submit(_execute_request, request)
except (RuntimeError, ValueError, ImportError) as e:
req_id = request.get("request_id", "") if request else ""
response_queue.put({"request_id": req_id, "error": f"Worker error: {e}"})
def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None:
if agent_id in agent_listeners:
return
stop_event = threading.Event()
loop = asyncio.get_running_loop()
def _listener() -> None:
while not stop_event.is_set():
try:
item = response_queue.get(timeout=0.5)
except stdlib_queue.Empty:
continue
except (BrokenPipeError, EOFError):
break
request_id = item.get("request_id")
if not request_id or agent_id not in pending_responses:
continue
future = pending_responses[agent_id].pop(request_id, None)
if future and not future.done():
with contextlib.suppress(RuntimeError):
loop.call_soon_threadsafe(future.set_result, item)
listener_thread = threading.Thread(target=_listener, daemon=True)
listener_thread.start()
agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event}
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
if agent_id not in agent_processes:
request_queue: Queue[Any] = Queue()
response_queue: Queue[Any] = Queue()
process = Process(
target=agent_worker, args=(agent_id, request_queue, response_queue), daemon=True
)
process.start()
agent_processes[agent_id] = {"process": process, "pid": process.pid}
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
pending_responses[agent_id] = {}
_ensure_response_listener(agent_id, response_queue)
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
@app.post("/execute", response_model=ToolExecutionResponse) @app.post("/execute", response_model=ToolExecutionResponse)
@@ -187,33 +89,42 @@ async def execute_tool(
) -> ToolExecutionResponse: ) -> ToolExecutionResponse:
verify_token(credentials) verify_token(credentials)
request_queue, _response_queue = ensure_agent_process(request.agent_id) agent_id = request.agent_id
loop = asyncio.get_running_loop() if agent_id in agent_tasks:
req_id = uuid4().hex old_task = agent_tasks[agent_id]
future: asyncio.Future[Any] = loop.create_future() if not old_task.done():
pending_responses[request.agent_id][req_id] = future old_task.cancel()
request_queue.put( task = asyncio.create_task(
{ asyncio.wait_for(
"request_id": req_id, _run_tool(agent_id, request.tool_name, request.kwargs), timeout=REQUEST_TIMEOUT
"tool_name": request.tool_name,
"kwargs": request.kwargs,
}
) )
)
agent_tasks[agent_id] = task
try: try:
response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT) result = await task
return ToolExecutionResponse(result=result)
if "error" in response: except asyncio.CancelledError:
return ToolExecutionResponse(error=response["error"]) return ToolExecutionResponse(error="Cancelled by newer request")
return ToolExecutionResponse(result=response.get("result"))
except TimeoutError: except TimeoutError:
pending_responses[request.agent_id].pop(req_id, None) return ToolExecutionResponse(error=f"Tool timed out after {REQUEST_TIMEOUT}s")
return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds")
except (RuntimeError, ValueError, OSError) as e: except ValidationError as e:
return ToolExecutionResponse(error=f"Worker error: {e}") return ToolExecutionResponse(error=f"Invalid arguments: {e}")
except (ValueError, RuntimeError, ImportError) as e:
return ToolExecutionResponse(error=f"Tool execution error: {e}")
except Exception as e: # noqa: BLE001
return ToolExecutionResponse(error=f"Unexpected error: {e}")
finally:
if agent_tasks.get(agent_id) is task:
del agent_tasks[agent_id]
@app.post("/register_agent") @app.post("/register_agent")
@@ -221,8 +132,6 @@ async def register_agent(
agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency
) -> dict[str, str]: ) -> dict[str, str]:
verify_token(credentials) verify_token(credentials)
ensure_agent_process(agent_id)
return {"status": "registered", "agent_id": agent_id} return {"status": "registered", "agent_id": agent_id}
@@ -233,42 +142,16 @@ async def health_check() -> dict[str, Any]:
"sandbox_mode": str(SANDBOX_MODE), "sandbox_mode": str(SANDBOX_MODE),
"environment": "sandbox" if SANDBOX_MODE else "main", "environment": "sandbox" if SANDBOX_MODE else "main",
"auth_configured": "true" if EXPECTED_TOKEN else "false", "auth_configured": "true" if EXPECTED_TOKEN else "false",
"active_agents": len(agent_processes), "active_agents": len(agent_tasks),
"agents": list(agent_processes.keys()), "agents": list(agent_tasks.keys()),
} }
def cleanup_all_agents() -> None:
for agent_id in list(agent_processes.keys()):
try:
if agent_id in agent_listeners:
agent_listeners[agent_id]["stop_event"].set()
agent_queues[agent_id]["request"].put(None)
process = agent_processes[agent_id]["process"]
process.join(timeout=1)
if process.is_alive():
process.terminate()
process.join(timeout=1)
if process.is_alive():
process.kill()
if agent_id in agent_listeners:
listener_thread = agent_listeners[agent_id]["thread"]
listener_thread.join(timeout=0.5)
except (BrokenPipeError, EOFError, OSError):
pass
except (RuntimeError, ValueError) as e:
logging.getLogger(__name__).debug(f"Error during agent cleanup: {e}")
def signal_handler(_signum: int, _frame: Any) -> None: def signal_handler(_signum: int, _frame: Any) -> None:
signal.signal(signal.SIGPIPE, signal.SIG_IGN) if hasattr(signal, "SIGPIPE") else None if hasattr(signal, "SIGPIPE"):
cleanup_all_agents() signal.signal(signal.SIGPIPE, signal.SIG_IGN)
for task in agent_tasks.values():
task.cancel()
sys.exit(0) sys.exit(0)
@@ -279,7 +162,4 @@ signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
if __name__ == "__main__": if __name__ == "__main__":
try:
uvicorn.run(app, host=args.host, port=args.port, log_level="info") uvicorn.run(app, host=args.host, port=args.port, log_level="info")
finally:
cleanup_all_agents()

View File

@@ -3,39 +3,54 @@ import contextlib
import threading import threading
from typing import Any from typing import Any
from strix.tools.context import get_current_agent_id
from .browser_instance import BrowserInstance from .browser_instance import BrowserInstance
class BrowserTabManager: class BrowserTabManager:
def __init__(self) -> None: def __init__(self) -> None:
self.browser_instance: BrowserInstance | None = None self._browsers_by_agent: dict[str, BrowserInstance] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
self._register_cleanup_handlers() self._register_cleanup_handlers()
def _get_agent_browser(self) -> BrowserInstance | None:
agent_id = get_current_agent_id()
with self._lock:
return self._browsers_by_agent.get(agent_id)
def _set_agent_browser(self, browser: BrowserInstance | None) -> None:
agent_id = get_current_agent_id()
with self._lock:
if browser is None:
self._browsers_by_agent.pop(agent_id, None)
else:
self._browsers_by_agent[agent_id] = browser
def launch_browser(self, url: str | None = None) -> dict[str, Any]: def launch_browser(self, url: str | None = None) -> dict[str, Any]:
with self._lock: with self._lock:
if self.browser_instance is not None: agent_id = get_current_agent_id()
if agent_id in self._browsers_by_agent:
raise ValueError("Browser is already launched") raise ValueError("Browser is already launched")
try: try:
self.browser_instance = BrowserInstance() browser = BrowserInstance()
result = self.browser_instance.launch(url) result = browser.launch(url)
self._browsers_by_agent[agent_id] = browser
result["message"] = "Browser launched successfully" result["message"] = "Browser launched successfully"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
if self.browser_instance:
self.browser_instance = None
raise RuntimeError(f"Failed to launch browser: {e}") from e raise RuntimeError(f"Failed to launch browser: {e}") from e
else: else:
return result return result
def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]: def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.goto(url, tab_id) result = browser.goto(url, tab_id)
result["message"] = f"Navigated to {url}" result["message"] = f"Navigated to {url}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to navigate to URL: {e}") from e raise RuntimeError(f"Failed to navigate to URL: {e}") from e
@@ -43,12 +58,12 @@ class BrowserTabManager:
return result return result
def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]: def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.click(coordinate, tab_id) result = browser.click(coordinate, tab_id)
result["message"] = f"Clicked at {coordinate}" result["message"] = f"Clicked at {coordinate}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to click: {e}") from e raise RuntimeError(f"Failed to click: {e}") from e
@@ -56,12 +71,12 @@ class BrowserTabManager:
return result return result
def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]: def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.type_text(text, tab_id) result = browser.type_text(text, tab_id)
result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}" result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to type text: {e}") from e raise RuntimeError(f"Failed to type text: {e}") from e
@@ -69,12 +84,12 @@ class BrowserTabManager:
return result return result
def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]: def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.scroll(direction, tab_id) result = browser.scroll(direction, tab_id)
result["message"] = f"Scrolled {direction}" result["message"] = f"Scrolled {direction}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to scroll: {e}") from e raise RuntimeError(f"Failed to scroll: {e}") from e
@@ -82,12 +97,12 @@ class BrowserTabManager:
return result return result
def back(self, tab_id: str | None = None) -> dict[str, Any]: def back(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.back(tab_id) result = browser.back(tab_id)
result["message"] = "Navigated back" result["message"] = "Navigated back"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to go back: {e}") from e raise RuntimeError(f"Failed to go back: {e}") from e
@@ -95,12 +110,12 @@ class BrowserTabManager:
return result return result
def forward(self, tab_id: str | None = None) -> dict[str, Any]: def forward(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.forward(tab_id) result = browser.forward(tab_id)
result["message"] = "Navigated forward" result["message"] = "Navigated forward"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to go forward: {e}") from e raise RuntimeError(f"Failed to go forward: {e}") from e
@@ -108,12 +123,12 @@ class BrowserTabManager:
return result return result
def new_tab(self, url: str | None = None) -> dict[str, Any]: def new_tab(self, url: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.new_tab(url) result = browser.new_tab(url)
result["message"] = f"Created new tab {result.get('tab_id', '')}" result["message"] = f"Created new tab {result.get('tab_id', '')}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to create new tab: {e}") from e raise RuntimeError(f"Failed to create new tab: {e}") from e
@@ -121,12 +136,12 @@ class BrowserTabManager:
return result return result
def switch_tab(self, tab_id: str) -> dict[str, Any]: def switch_tab(self, tab_id: str) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.switch_tab(tab_id) result = browser.switch_tab(tab_id)
result["message"] = f"Switched to tab {tab_id}" result["message"] = f"Switched to tab {tab_id}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to switch tab: {e}") from e raise RuntimeError(f"Failed to switch tab: {e}") from e
@@ -134,12 +149,12 @@ class BrowserTabManager:
return result return result
def close_tab(self, tab_id: str) -> dict[str, Any]: def close_tab(self, tab_id: str) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.close_tab(tab_id) result = browser.close_tab(tab_id)
result["message"] = f"Closed tab {tab_id}" result["message"] = f"Closed tab {tab_id}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to close tab: {e}") from e raise RuntimeError(f"Failed to close tab: {e}") from e
@@ -147,12 +162,12 @@ class BrowserTabManager:
return result return result
def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]: def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.wait(duration, tab_id) result = browser.wait(duration, tab_id)
result["message"] = f"Waited {duration}s" result["message"] = f"Waited {duration}s"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to wait: {e}") from e raise RuntimeError(f"Failed to wait: {e}") from e
@@ -160,12 +175,12 @@ class BrowserTabManager:
return result return result
def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]: def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.execute_js(js_code, tab_id) result = browser.execute_js(js_code, tab_id)
result["message"] = "JavaScript executed successfully" result["message"] = "JavaScript executed successfully"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to execute JavaScript: {e}") from e raise RuntimeError(f"Failed to execute JavaScript: {e}") from e
@@ -173,12 +188,12 @@ class BrowserTabManager:
return result return result
def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]: def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.double_click(coordinate, tab_id) result = browser.double_click(coordinate, tab_id)
result["message"] = f"Double clicked at {coordinate}" result["message"] = f"Double clicked at {coordinate}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to double click: {e}") from e raise RuntimeError(f"Failed to double click: {e}") from e
@@ -186,12 +201,12 @@ class BrowserTabManager:
return result return result
def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]: def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.hover(coordinate, tab_id) result = browser.hover(coordinate, tab_id)
result["message"] = f"Hovered at {coordinate}" result["message"] = f"Hovered at {coordinate}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to hover: {e}") from e raise RuntimeError(f"Failed to hover: {e}") from e
@@ -199,12 +214,12 @@ class BrowserTabManager:
return result return result
def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]: def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.press_key(key, tab_id) result = browser.press_key(key, tab_id)
result["message"] = f"Pressed key {key}" result["message"] = f"Pressed key {key}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to press key: {e}") from e raise RuntimeError(f"Failed to press key: {e}") from e
@@ -212,12 +227,12 @@ class BrowserTabManager:
return result return result
def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]: def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.save_pdf(file_path, tab_id) result = browser.save_pdf(file_path, tab_id)
result["message"] = f"Page saved as PDF: {file_path}" result["message"] = f"Page saved as PDF: {file_path}"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to save PDF: {e}") from e raise RuntimeError(f"Failed to save PDF: {e}") from e
@@ -225,12 +240,12 @@ class BrowserTabManager:
return result return result
def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]: def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.get_console_logs(tab_id, clear) result = browser.get_console_logs(tab_id, clear)
action_text = "cleared and retrieved" if clear else "retrieved" action_text = "cleared and retrieved" if clear else "retrieved"
logs = result.get("console_logs", []) logs = result.get("console_logs", [])
@@ -247,12 +262,12 @@ class BrowserTabManager:
return result return result
def view_source(self, tab_id: str | None = None) -> dict[str, Any]: def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
result = self.browser_instance.view_source(tab_id) result = browser.view_source(tab_id)
result["message"] = "Page source retrieved" result["message"] = "Page source retrieved"
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to get page source: {e}") from e raise RuntimeError(f"Failed to get page source: {e}") from e
@@ -260,18 +275,18 @@ class BrowserTabManager:
return result return result
def list_tabs(self) -> dict[str, Any]: def list_tabs(self) -> dict[str, Any]:
with self._lock: browser = self._get_agent_browser()
if self.browser_instance is None: if browser is None:
return {"tabs": {}, "total_count": 0, "current_tab": None} return {"tabs": {}, "total_count": 0, "current_tab": None}
try: try:
tab_info = {} tab_info = {}
for tid, tab_page in self.browser_instance.pages.items(): for tid, tab_page in browser.pages.items():
try: try:
tab_info[tid] = { tab_info[tid] = {
"url": tab_page.url, "url": tab_page.url,
"title": "Unknown" if tab_page.is_closed() else "Active", "title": "Unknown" if tab_page.is_closed() else "Active",
"is_current": tid == self.browser_instance.current_page_id, "is_current": tid == browser.current_page_id,
} }
except (AttributeError, RuntimeError): except (AttributeError, RuntimeError):
tab_info[tid] = { tab_info[tid] = {
@@ -283,19 +298,20 @@ class BrowserTabManager:
return { return {
"tabs": tab_info, "tabs": tab_info,
"total_count": len(tab_info), "total_count": len(tab_info),
"current_tab": self.browser_instance.current_page_id, "current_tab": browser.current_page_id,
} }
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to list tabs: {e}") from e raise RuntimeError(f"Failed to list tabs: {e}") from e
def close_browser(self) -> dict[str, Any]: def close_browser(self) -> dict[str, Any]:
agent_id = get_current_agent_id()
with self._lock: with self._lock:
if self.browser_instance is None: browser = self._browsers_by_agent.pop(agent_id, None)
if browser is None:
raise ValueError("Browser not launched") raise ValueError("Browser not launched")
try: try:
self.browser_instance.close() browser.close()
self.browser_instance = None
except (OSError, ValueError, RuntimeError) as e: except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to close browser: {e}") from e raise RuntimeError(f"Failed to close browser: {e}") from e
else: else:
@@ -305,19 +321,34 @@ class BrowserTabManager:
"is_running": False, "is_running": False,
} }
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
browser = self._browsers_by_agent.pop(agent_id, None)
if browser:
with contextlib.suppress(Exception):
browser.close()
def cleanup_dead_browser(self) -> None: def cleanup_dead_browser(self) -> None:
with self._lock: with self._lock:
if self.browser_instance and not self.browser_instance.is_alive(): dead_agents = []
for agent_id, browser in self._browsers_by_agent.items():
if not browser.is_alive():
dead_agents.append(agent_id)
for agent_id in dead_agents:
browser = self._browsers_by_agent.pop(agent_id)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.browser_instance.close() browser.close()
self.browser_instance = None
def close_all(self) -> None: def close_all(self) -> None:
with self._lock: with self._lock:
if self.browser_instance: browsers = list(self._browsers_by_agent.values())
self._browsers_by_agent.clear()
for browser in browsers:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.browser_instance.close() browser.close()
self.browser_instance = None
def _register_cleanup_handlers(self) -> None: def _register_cleanup_handlers(self) -> None:
atexit.register(self.close_all) atexit.register(self.close_all)

12
strix/tools/context.py Normal file
View File

@@ -0,0 +1,12 @@
from contextvars import ContextVar
current_agent_id: ContextVar[str] = ContextVar("current_agent_id", default="default")
def get_current_agent_id() -> str:
return current_agent_id.get()
def set_current_agent_id(agent_id: str) -> None:
current_agent_id.set(agent_id)

View File

@@ -5,6 +5,7 @@ from typing import Any
import httpx import httpx
from strix.config import Config from strix.config import Config
from strix.telemetry import posthog
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false": if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
@@ -20,7 +21,8 @@ from .registry import (
) )
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120") _SERVER_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
SANDBOX_EXECUTION_TIMEOUT = _SERVER_TIMEOUT + 30
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10") SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
@@ -82,14 +84,17 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A
response.raise_for_status() response.raise_for_status()
response_data = response.json() response_data = response.json()
if response_data.get("error"): if response_data.get("error"):
posthog.error("tool_execution_error", f"{tool_name}: {response_data['error']}")
raise RuntimeError(f"Sandbox execution error: {response_data['error']}") raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
return response_data.get("result") return response_data.get("result")
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
posthog.error("tool_http_error", f"{tool_name}: HTTP {e.response.status_code}")
if e.response.status_code == 401: if e.response.status_code == 401:
raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
except httpx.RequestError as e: except httpx.RequestError as e:
error_type = type(e).__name__ error_type = type(e).__name__
posthog.error("tool_request_error", f"{tool_name}: {error_type}")
raise RuntimeError(f"Request error calling tool server: {error_type}") from e raise RuntimeError(f"Request error calling tool server: {error_type}") from e

View File

@@ -785,6 +785,7 @@ _PROXY_MANAGER: ProxyManager | None = None
def get_proxy_manager() -> ProxyManager: def get_proxy_manager() -> ProxyManager:
global _PROXY_MANAGER # noqa: PLW0603
if _PROXY_MANAGER is None: if _PROXY_MANAGER is None:
return ProxyManager() _PROXY_MANAGER = ProxyManager()
return _PROXY_MANAGER return _PROXY_MANAGER

View File

@@ -3,29 +3,39 @@ import contextlib
import threading import threading
from typing import Any from typing import Any
from strix.tools.context import get_current_agent_id
from .python_instance import PythonInstance from .python_instance import PythonInstance
class PythonSessionManager: class PythonSessionManager:
def __init__(self) -> None: def __init__(self) -> None:
self.sessions: dict[str, PythonInstance] = {} self._sessions_by_agent: dict[str, dict[str, PythonInstance]] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
self.default_session_id = "default" self.default_session_id = "default"
self._register_cleanup_handlers() self._register_cleanup_handlers()
def _get_agent_sessions(self) -> dict[str, PythonInstance]:
agent_id = get_current_agent_id()
with self._lock:
if agent_id not in self._sessions_by_agent:
self._sessions_by_agent[agent_id] = {}
return self._sessions_by_agent[agent_id]
def create_session( def create_session(
self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30 self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30
) -> dict[str, Any]: ) -> dict[str, Any]:
if session_id is None: if session_id is None:
session_id = self.default_session_id session_id = self.default_session_id
sessions = self._get_agent_sessions()
with self._lock: with self._lock:
if session_id in self.sessions: if session_id in sessions:
raise ValueError(f"Python session '{session_id}' already exists") raise ValueError(f"Python session '{session_id}' already exists")
session = PythonInstance(session_id) session = PythonInstance(session_id)
self.sessions[session_id] = session sessions[session_id] = session
if initial_code: if initial_code:
result = session.execute_code(initial_code, timeout) result = session.execute_code(initial_code, timeout)
@@ -49,11 +59,12 @@ class PythonSessionManager:
if not code: if not code:
raise ValueError("No code provided for execution") raise ValueError("No code provided for execution")
sessions = self._get_agent_sessions()
with self._lock: with self._lock:
if session_id not in self.sessions: if session_id not in sessions:
raise ValueError(f"Python session '{session_id}' not found") raise ValueError(f"Python session '{session_id}' not found")
session = self.sessions[session_id] session = sessions[session_id]
result = session.execute_code(code, timeout) result = session.execute_code(code, timeout)
result["message"] = f"Code executed in session '{session_id}'" result["message"] = f"Code executed in session '{session_id}'"
@@ -63,11 +74,12 @@ class PythonSessionManager:
if session_id is None: if session_id is None:
session_id = self.default_session_id session_id = self.default_session_id
sessions = self._get_agent_sessions()
with self._lock: with self._lock:
if session_id not in self.sessions: if session_id not in sessions:
raise ValueError(f"Python session '{session_id}' not found") raise ValueError(f"Python session '{session_id}' not found")
session = self.sessions.pop(session_id) session = sessions.pop(session_id)
session.close() session.close()
return { return {
@@ -77,9 +89,10 @@ class PythonSessionManager:
} }
def list_sessions(self) -> dict[str, Any]: def list_sessions(self) -> dict[str, Any]:
sessions = self._get_agent_sessions()
with self._lock: with self._lock:
session_info = {} session_info = {}
for sid, session in self.sessions.items(): for sid, session in sessions.items():
session_info[sid] = { session_info[sid] = {
"is_running": session.is_running, "is_running": session.is_running,
"is_alive": session.is_alive(), "is_alive": session.is_alive(),
@@ -87,24 +100,35 @@ class PythonSessionManager:
return {"sessions": session_info, "total_count": len(session_info)} return {"sessions": session_info, "total_count": len(session_info)}
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
sessions = self._sessions_by_agent.pop(agent_id, {})
for session in sessions.values():
with contextlib.suppress(Exception):
session.close()
def cleanup_dead_sessions(self) -> None: def cleanup_dead_sessions(self) -> None:
with self._lock: with self._lock:
for sessions in self._sessions_by_agent.values():
dead_sessions = [] dead_sessions = []
for sid, session in self.sessions.items(): for sid, session in sessions.items():
if not session.is_alive(): if not session.is_alive():
dead_sessions.append(sid) dead_sessions.append(sid)
for sid in dead_sessions: for sid in dead_sessions:
session = self.sessions.pop(sid) session = sessions.pop(sid)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
session.close() session.close()
def close_all_sessions(self) -> None: def close_all_sessions(self) -> None:
with self._lock: with self._lock:
sessions_to_close = list(self.sessions.values()) all_sessions: list[PythonInstance] = []
self.sessions.clear() for sessions in self._sessions_by_agent.values():
all_sessions.extend(sessions.values())
self._sessions_by_agent.clear()
for session in sessions_to_close: for session in all_sessions:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
session.close() session.close()

View File

@@ -3,18 +3,27 @@ import contextlib
import threading import threading
from typing import Any from typing import Any
from strix.tools.context import get_current_agent_id
from .terminal_session import TerminalSession from .terminal_session import TerminalSession
class TerminalManager: class TerminalManager:
def __init__(self) -> None: def __init__(self) -> None:
self.sessions: dict[str, TerminalSession] = {} self._sessions_by_agent: dict[str, dict[str, TerminalSession]] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
self.default_terminal_id = "default" self.default_terminal_id = "default"
self.default_timeout = 30.0 self.default_timeout = 30.0
self._register_cleanup_handlers() self._register_cleanup_handlers()
def _get_agent_sessions(self) -> dict[str, TerminalSession]:
agent_id = get_current_agent_id()
with self._lock:
if agent_id not in self._sessions_by_agent:
self._sessions_by_agent[agent_id] = {}
return self._sessions_by_agent[agent_id]
def execute_command( def execute_command(
self, self,
command: str, command: str,
@@ -62,24 +71,26 @@ class TerminalManager:
} }
def _get_or_create_session(self, terminal_id: str) -> TerminalSession: def _get_or_create_session(self, terminal_id: str) -> TerminalSession:
sessions = self._get_agent_sessions()
with self._lock: with self._lock:
if terminal_id not in self.sessions: if terminal_id not in sessions:
self.sessions[terminal_id] = TerminalSession(terminal_id) sessions[terminal_id] = TerminalSession(terminal_id)
return self.sessions[terminal_id] return sessions[terminal_id]
def close_session(self, terminal_id: str | None = None) -> dict[str, Any]: def close_session(self, terminal_id: str | None = None) -> dict[str, Any]:
if terminal_id is None: if terminal_id is None:
terminal_id = self.default_terminal_id terminal_id = self.default_terminal_id
sessions = self._get_agent_sessions()
with self._lock: with self._lock:
if terminal_id not in self.sessions: if terminal_id not in sessions:
return { return {
"terminal_id": terminal_id, "terminal_id": terminal_id,
"message": f"Terminal '{terminal_id}' not found", "message": f"Terminal '{terminal_id}' not found",
"status": "not_found", "status": "not_found",
} }
session = self.sessions.pop(terminal_id) session = sessions.pop(terminal_id)
try: try:
session.close() session.close()
@@ -97,9 +108,10 @@ class TerminalManager:
} }
def list_sessions(self) -> dict[str, Any]: def list_sessions(self) -> dict[str, Any]:
sessions = self._get_agent_sessions()
with self._lock: with self._lock:
session_info: dict[str, dict[str, Any]] = {} session_info: dict[str, dict[str, Any]] = {}
for tid, session in self.sessions.items(): for tid, session in sessions.items():
session_info[tid] = { session_info[tid] = {
"is_running": session.is_running(), "is_running": session.is_running(),
"working_dir": session.get_working_dir(), "working_dir": session.get_working_dir(),
@@ -107,24 +119,35 @@ class TerminalManager:
return {"sessions": session_info, "total_count": len(session_info)} return {"sessions": session_info, "total_count": len(session_info)}
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
sessions = self._sessions_by_agent.pop(agent_id, {})
for session in sessions.values():
with contextlib.suppress(Exception):
session.close()
def cleanup_dead_sessions(self) -> None: def cleanup_dead_sessions(self) -> None:
with self._lock: with self._lock:
for sessions in self._sessions_by_agent.values():
dead_sessions: list[str] = [] dead_sessions: list[str] = []
for tid, session in self.sessions.items(): for tid, session in sessions.items():
if not session.is_running(): if not session.is_running():
dead_sessions.append(tid) dead_sessions.append(tid)
for tid in dead_sessions: for tid in dead_sessions:
session = self.sessions.pop(tid) session = sessions.pop(tid)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
session.close() session.close()
def close_all_sessions(self) -> None: def close_all_sessions(self) -> None:
with self._lock: with self._lock:
sessions_to_close = list(self.sessions.values()) all_sessions: list[TerminalSession] = []
self.sessions.clear() for sessions in self._sessions_by_agent.values():
all_sessions.extend(sessions.values())
self._sessions_by_agent.clear()
for session in sessions_to_close: for session in all_sessions:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
session.close() session.close()