From 918a151892747d9f58428f4baf8926537aee1739 Mon Sep 17 00:00:00 2001 From: 0xallam Date: Sat, 17 Jan 2026 21:41:57 -0800 Subject: [PATCH] refactor: simplify tool server to asyncio tasks with per-agent isolation - Replace multiprocessing/threading with single asyncio task per agent - Add task cancellation: new request cancels previous for same agent - Add per-agent state isolation via ContextVar for Terminal, Browser, Python managers - Add posthog telemetry for tool execution errors (timeout, http, sandbox) - Fix proxy manager singleton pattern - Increase client timeout buffer over server timeout - Add context.py to Dockerfile --- containers/Dockerfile | 2 +- strix/runtime/tool_server.py | 210 +++++----------------- strix/tools/browser/tab_manager.py | 211 +++++++++++++---------- strix/tools/context.py | 12 ++ strix/tools/executor.py | 7 +- strix/tools/proxy/proxy_manager.py | 3 +- strix/tools/python/python_manager.py | 62 +++++-- strix/tools/terminal/terminal_manager.py | 59 +++++-- 8 files changed, 271 insertions(+), 295 deletions(-) create mode 100644 strix/tools/context.py diff --git a/containers/Dockerfile b/containers/Dockerfile index 40d9573..4a1b121 100644 --- a/containers/Dockerfile +++ b/containers/Dockerfile @@ -172,7 +172,7 @@ COPY strix/config/ /app/strix/config/ COPY strix/utils/ /app/strix/utils/ COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/ -COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/ +COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py strix/tools/context.py /app/strix/tools/ COPY strix/tools/browser/ /app/strix/tools/browser/ COPY strix/tools/file_edit/ /app/strix/tools/file_edit/ diff --git a/strix/runtime/tool_server.py b/strix/runtime/tool_server.py index b410dc4..ee5fb49 100644 --- a/strix/runtime/tool_server.py +++ b/strix/runtime/tool_server.py @@ -2,16 +2,10 @@ from __future__ import annotations import argparse import asyncio -import contextlib -import logging import os -import queue as stdlib_queue import signal import sys -import threading -from multiprocessing import Process, Queue from typing import Any -from uuid import uuid4 import uvicorn from fastapi import Depends, FastAPI, HTTPException, status @@ -40,13 +34,9 @@ REQUEST_TIMEOUT = args.timeout app = FastAPI() security = HTTPBearer() - security_dependency = Depends(security) -agent_processes: dict[str, dict[str, Any]] = {} -agent_queues: dict[str, dict[str, Queue[Any]]] = {} -pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {} -agent_listeners: dict[str, dict[str, Any]] = {} +agent_tasks: dict[str, asyncio.Task[Any]] = {} def verify_token(credentials: HTTPAuthorizationCredentials) -> str: @@ -78,107 +68,19 @@ class ToolExecutionResponse(BaseModel): error: str | None = None -def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None: - null_handler = logging.NullHandler() - - root_logger = logging.getLogger() - root_logger.handlers = [null_handler] - root_logger.setLevel(logging.CRITICAL) - - from concurrent.futures import ThreadPoolExecutor - - from strix.tools.argument_parser import ArgumentConversionError, convert_arguments +async def _run_tool(agent_id: str, tool_name: str, kwargs: dict[str, Any]) -> Any: + from strix.tools.argument_parser import convert_arguments + from strix.tools.context import set_current_agent_id from strix.tools.registry import get_tool_by_name - def _execute_request(request: dict[str, Any]) -> None: - request_id = request.get("request_id", "") - tool_name = request["tool_name"] - kwargs = request["kwargs"] + set_current_agent_id(agent_id) - try: - tool_func = get_tool_by_name(tool_name) - if not tool_func: - response_queue.put( - {"request_id": request_id, "error": f"Tool '{tool_name}' not found"} - ) - return + tool_func = get_tool_by_name(tool_name) + if not tool_func: + raise ValueError(f"Tool '{tool_name}' not found") - converted_kwargs = convert_arguments(tool_func, kwargs) - result = tool_func(**converted_kwargs) - - response_queue.put({"request_id": request_id, "result": result}) - - except (ArgumentConversionError, ValidationError) as e: - response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"}) - except (RuntimeError, ValueError, ImportError) as e: - response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"}) - except Exception as e: # noqa: BLE001 - response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"}) - - with ThreadPoolExecutor() as executor: - while True: - request = None - try: - request = request_queue.get() - - if request is None: - break - - executor.submit(_execute_request, request) - - except (RuntimeError, ValueError, ImportError) as e: - req_id = request.get("request_id", "") if request else "" - response_queue.put({"request_id": req_id, "error": f"Worker error: {e}"}) - - -def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None: - if agent_id in agent_listeners: - return - - stop_event = threading.Event() - loop = asyncio.get_running_loop() - - def _listener() -> None: - while not stop_event.is_set(): - try: - item = response_queue.get(timeout=0.5) - except stdlib_queue.Empty: - continue - except (BrokenPipeError, EOFError): - break - - request_id = item.get("request_id") - if not request_id or agent_id not in pending_responses: - continue - - future = pending_responses[agent_id].pop(request_id, None) - if future and not future.done(): - with contextlib.suppress(RuntimeError): - loop.call_soon_threadsafe(future.set_result, item) - - listener_thread = threading.Thread(target=_listener, daemon=True) - listener_thread.start() - - agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event} - - -def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]: - if agent_id not in agent_processes: - request_queue: Queue[Any] = Queue() - response_queue: Queue[Any] = Queue() - - process = Process( - target=agent_worker, args=(agent_id, request_queue, response_queue), daemon=True - ) - process.start() - - agent_processes[agent_id] = {"process": process, "pid": process.pid} - agent_queues[agent_id] = {"request": request_queue, "response": response_queue} - pending_responses[agent_id] = {} - - _ensure_response_listener(agent_id, response_queue) - - return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"] + converted_kwargs = convert_arguments(tool_func, kwargs) + return await asyncio.to_thread(tool_func, **converted_kwargs) @app.post("/execute", response_model=ToolExecutionResponse) @@ -187,33 +89,42 @@ async def execute_tool( ) -> ToolExecutionResponse: verify_token(credentials) - request_queue, _response_queue = ensure_agent_process(request.agent_id) + agent_id = request.agent_id - loop = asyncio.get_running_loop() - req_id = uuid4().hex - future: asyncio.Future[Any] = loop.create_future() - pending_responses[request.agent_id][req_id] = future + if agent_id in agent_tasks: + old_task = agent_tasks[agent_id] + if not old_task.done(): + old_task.cancel() - request_queue.put( - { - "request_id": req_id, - "tool_name": request.tool_name, - "kwargs": request.kwargs, - } + task = asyncio.create_task( + asyncio.wait_for( + _run_tool(agent_id, request.tool_name, request.kwargs), timeout=REQUEST_TIMEOUT + ) ) + agent_tasks[agent_id] = task try: - response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT) + result = await task + return ToolExecutionResponse(result=result) - if "error" in response: - return ToolExecutionResponse(error=response["error"]) - return ToolExecutionResponse(result=response.get("result")) + except asyncio.CancelledError: + return ToolExecutionResponse(error="Cancelled by newer request") except TimeoutError: - pending_responses[request.agent_id].pop(req_id, None) - return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds") - except (RuntimeError, ValueError, OSError) as e: - return ToolExecutionResponse(error=f"Worker error: {e}") + return ToolExecutionResponse(error=f"Tool timed out after {REQUEST_TIMEOUT}s") + + except ValidationError as e: + return ToolExecutionResponse(error=f"Invalid arguments: {e}") + + except (ValueError, RuntimeError, ImportError) as e: + return ToolExecutionResponse(error=f"Tool execution error: {e}") + + except Exception as e: # noqa: BLE001 + return ToolExecutionResponse(error=f"Unexpected error: {e}") + + finally: + if agent_tasks.get(agent_id) is task: + del agent_tasks[agent_id] @app.post("/register_agent") @@ -221,8 +132,6 @@ async def register_agent( agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency ) -> dict[str, str]: verify_token(credentials) - - ensure_agent_process(agent_id) return {"status": "registered", "agent_id": agent_id} @@ -233,42 +142,16 @@ async def health_check() -> dict[str, Any]: "sandbox_mode": str(SANDBOX_MODE), "environment": "sandbox" if SANDBOX_MODE else "main", "auth_configured": "true" if EXPECTED_TOKEN else "false", - "active_agents": len(agent_processes), - "agents": list(agent_processes.keys()), + "active_agents": len(agent_tasks), + "agents": list(agent_tasks.keys()), } -def cleanup_all_agents() -> None: - for agent_id in list(agent_processes.keys()): - try: - if agent_id in agent_listeners: - agent_listeners[agent_id]["stop_event"].set() - - agent_queues[agent_id]["request"].put(None) - process = agent_processes[agent_id]["process"] - - process.join(timeout=1) - - if process.is_alive(): - process.terminate() - process.join(timeout=1) - - if process.is_alive(): - process.kill() - - if agent_id in agent_listeners: - listener_thread = agent_listeners[agent_id]["thread"] - listener_thread.join(timeout=0.5) - - except (BrokenPipeError, EOFError, OSError): - pass - except (RuntimeError, ValueError) as e: - logging.getLogger(__name__).debug(f"Error during agent cleanup: {e}") - - def signal_handler(_signum: int, _frame: Any) -> None: - signal.signal(signal.SIGPIPE, signal.SIG_IGN) if hasattr(signal, "SIGPIPE") else None - cleanup_all_agents() + if hasattr(signal, "SIGPIPE"): + signal.signal(signal.SIGPIPE, signal.SIG_IGN) + for task in agent_tasks.values(): + task.cancel() sys.exit(0) @@ -279,7 +162,4 @@ signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) if __name__ == "__main__": - try: - uvicorn.run(app, host=args.host, port=args.port, log_level="info") - finally: - cleanup_all_agents() + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/strix/tools/browser/tab_manager.py b/strix/tools/browser/tab_manager.py index a77dda2..b40eecf 100644 --- a/strix/tools/browser/tab_manager.py +++ b/strix/tools/browser/tab_manager.py @@ -3,39 +3,54 @@ import contextlib import threading from typing import Any +from strix.tools.context import get_current_agent_id + from .browser_instance import BrowserInstance class BrowserTabManager: def __init__(self) -> None: - self.browser_instance: BrowserInstance | None = None + self._browsers_by_agent: dict[str, BrowserInstance] = {} self._lock = threading.Lock() self._register_cleanup_handlers() + def _get_agent_browser(self) -> BrowserInstance | None: + agent_id = get_current_agent_id() + with self._lock: + return self._browsers_by_agent.get(agent_id) + + def _set_agent_browser(self, browser: BrowserInstance | None) -> None: + agent_id = get_current_agent_id() + with self._lock: + if browser is None: + self._browsers_by_agent.pop(agent_id, None) + else: + self._browsers_by_agent[agent_id] = browser + def launch_browser(self, url: str | None = None) -> dict[str, Any]: with self._lock: - if self.browser_instance is not None: + agent_id = get_current_agent_id() + if agent_id in self._browsers_by_agent: raise ValueError("Browser is already launched") try: - self.browser_instance = BrowserInstance() - result = self.browser_instance.launch(url) + browser = BrowserInstance() + result = browser.launch(url) + self._browsers_by_agent[agent_id] = browser result["message"] = "Browser launched successfully" except (OSError, ValueError, RuntimeError) as e: - if self.browser_instance: - self.browser_instance = None raise RuntimeError(f"Failed to launch browser: {e}") from e else: return result def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.goto(url, tab_id) + result = browser.goto(url, tab_id) result["message"] = f"Navigated to {url}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to navigate to URL: {e}") from e @@ -43,12 +58,12 @@ class BrowserTabManager: return result def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.click(coordinate, tab_id) + result = browser.click(coordinate, tab_id) result["message"] = f"Clicked at {coordinate}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to click: {e}") from e @@ -56,12 +71,12 @@ class BrowserTabManager: return result def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.type_text(text, tab_id) + result = browser.type_text(text, tab_id) result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to type text: {e}") from e @@ -69,12 +84,12 @@ class BrowserTabManager: return result def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.scroll(direction, tab_id) + result = browser.scroll(direction, tab_id) result["message"] = f"Scrolled {direction}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to scroll: {e}") from e @@ -82,12 +97,12 @@ class BrowserTabManager: return result def back(self, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.back(tab_id) + result = browser.back(tab_id) result["message"] = "Navigated back" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to go back: {e}") from e @@ -95,12 +110,12 @@ class BrowserTabManager: return result def forward(self, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.forward(tab_id) + result = browser.forward(tab_id) result["message"] = "Navigated forward" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to go forward: {e}") from e @@ -108,12 +123,12 @@ class BrowserTabManager: return result def new_tab(self, url: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.new_tab(url) + result = browser.new_tab(url) result["message"] = f"Created new tab {result.get('tab_id', '')}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to create new tab: {e}") from e @@ -121,12 +136,12 @@ class BrowserTabManager: return result def switch_tab(self, tab_id: str) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.switch_tab(tab_id) + result = browser.switch_tab(tab_id) result["message"] = f"Switched to tab {tab_id}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to switch tab: {e}") from e @@ -134,12 +149,12 @@ class BrowserTabManager: return result def close_tab(self, tab_id: str) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.close_tab(tab_id) + result = browser.close_tab(tab_id) result["message"] = f"Closed tab {tab_id}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to close tab: {e}") from e @@ -147,12 +162,12 @@ class BrowserTabManager: return result def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.wait(duration, tab_id) + result = browser.wait(duration, tab_id) result["message"] = f"Waited {duration}s" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to wait: {e}") from e @@ -160,12 +175,12 @@ class BrowserTabManager: return result def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.execute_js(js_code, tab_id) + result = browser.execute_js(js_code, tab_id) result["message"] = "JavaScript executed successfully" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to execute JavaScript: {e}") from e @@ -173,12 +188,12 @@ class BrowserTabManager: return result def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.double_click(coordinate, tab_id) + result = browser.double_click(coordinate, tab_id) result["message"] = f"Double clicked at {coordinate}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to double click: {e}") from e @@ -186,12 +201,12 @@ class BrowserTabManager: return result def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.hover(coordinate, tab_id) + result = browser.hover(coordinate, tab_id) result["message"] = f"Hovered at {coordinate}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to hover: {e}") from e @@ -199,12 +214,12 @@ class BrowserTabManager: return result def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.press_key(key, tab_id) + result = browser.press_key(key, tab_id) result["message"] = f"Pressed key {key}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to press key: {e}") from e @@ -212,12 +227,12 @@ class BrowserTabManager: return result def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.save_pdf(file_path, tab_id) + result = browser.save_pdf(file_path, tab_id) result["message"] = f"Page saved as PDF: {file_path}" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to save PDF: {e}") from e @@ -225,12 +240,12 @@ class BrowserTabManager: return result def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.get_console_logs(tab_id, clear) + result = browser.get_console_logs(tab_id, clear) action_text = "cleared and retrieved" if clear else "retrieved" logs = result.get("console_logs", []) @@ -247,12 +262,12 @@ class BrowserTabManager: return result def view_source(self, tab_id: str | None = None) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - raise ValueError("Browser not launched") + browser = self._get_agent_browser() + if browser is None: + raise ValueError("Browser not launched") try: - result = self.browser_instance.view_source(tab_id) + result = browser.view_source(tab_id) result["message"] = "Page source retrieved" except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to get page source: {e}") from e @@ -260,18 +275,18 @@ class BrowserTabManager: return result def list_tabs(self) -> dict[str, Any]: - with self._lock: - if self.browser_instance is None: - return {"tabs": {}, "total_count": 0, "current_tab": None} + browser = self._get_agent_browser() + if browser is None: + return {"tabs": {}, "total_count": 0, "current_tab": None} try: tab_info = {} - for tid, tab_page in self.browser_instance.pages.items(): + for tid, tab_page in browser.pages.items(): try: tab_info[tid] = { "url": tab_page.url, "title": "Unknown" if tab_page.is_closed() else "Active", - "is_current": tid == self.browser_instance.current_page_id, + "is_current": tid == browser.current_page_id, } except (AttributeError, RuntimeError): tab_info[tid] = { @@ -283,19 +298,20 @@ class BrowserTabManager: return { "tabs": tab_info, "total_count": len(tab_info), - "current_tab": self.browser_instance.current_page_id, + "current_tab": browser.current_page_id, } except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to list tabs: {e}") from e def close_browser(self) -> dict[str, Any]: + agent_id = get_current_agent_id() with self._lock: - if self.browser_instance is None: + browser = self._browsers_by_agent.pop(agent_id, None) + if browser is None: raise ValueError("Browser not launched") try: - self.browser_instance.close() - self.browser_instance = None + browser.close() except (OSError, ValueError, RuntimeError) as e: raise RuntimeError(f"Failed to close browser: {e}") from e else: @@ -305,19 +321,34 @@ class BrowserTabManager: "is_running": False, } + def cleanup_agent(self, agent_id: str) -> None: + with self._lock: + browser = self._browsers_by_agent.pop(agent_id, None) + + if browser: + with contextlib.suppress(Exception): + browser.close() + def cleanup_dead_browser(self) -> None: with self._lock: - if self.browser_instance and not self.browser_instance.is_alive(): + dead_agents = [] + for agent_id, browser in self._browsers_by_agent.items(): + if not browser.is_alive(): + dead_agents.append(agent_id) + + for agent_id in dead_agents: + browser = self._browsers_by_agent.pop(agent_id) with contextlib.suppress(Exception): - self.browser_instance.close() - self.browser_instance = None + browser.close() def close_all(self) -> None: with self._lock: - if self.browser_instance: - with contextlib.suppress(Exception): - self.browser_instance.close() - self.browser_instance = None + browsers = list(self._browsers_by_agent.values()) + self._browsers_by_agent.clear() + + for browser in browsers: + with contextlib.suppress(Exception): + browser.close() def _register_cleanup_handlers(self) -> None: atexit.register(self.close_all) diff --git a/strix/tools/context.py b/strix/tools/context.py new file mode 100644 index 0000000..e61f447 --- /dev/null +++ b/strix/tools/context.py @@ -0,0 +1,12 @@ +from contextvars import ContextVar + + +current_agent_id: ContextVar[str] = ContextVar("current_agent_id", default="default") + + +def get_current_agent_id() -> str: + return current_agent_id.get() + + +def set_current_agent_id(agent_id: str) -> None: + current_agent_id.set(agent_id) diff --git a/strix/tools/executor.py b/strix/tools/executor.py index eb34b38..1c24087 100644 --- a/strix/tools/executor.py +++ b/strix/tools/executor.py @@ -5,6 +5,7 @@ from typing import Any import httpx from strix.config import Config +from strix.telemetry import posthog if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false": @@ -20,7 +21,8 @@ from .registry import ( ) -SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120") +_SERVER_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120") +SANDBOX_EXECUTION_TIMEOUT = _SERVER_TIMEOUT + 30 SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10") @@ -82,14 +84,17 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A response.raise_for_status() response_data = response.json() if response_data.get("error"): + posthog.error("tool_execution_error", f"{tool_name}: {response_data['error']}") raise RuntimeError(f"Sandbox execution error: {response_data['error']}") return response_data.get("result") except httpx.HTTPStatusError as e: + posthog.error("tool_http_error", f"{tool_name}: HTTP {e.response.status_code}") if e.response.status_code == 401: raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e except httpx.RequestError as e: error_type = type(e).__name__ + posthog.error("tool_request_error", f"{tool_name}: {error_type}") raise RuntimeError(f"Request error calling tool server: {error_type}") from e diff --git a/strix/tools/proxy/proxy_manager.py b/strix/tools/proxy/proxy_manager.py index 7679be6..dafd3f1 100644 --- a/strix/tools/proxy/proxy_manager.py +++ b/strix/tools/proxy/proxy_manager.py @@ -785,6 +785,7 @@ _PROXY_MANAGER: ProxyManager | None = None def get_proxy_manager() -> ProxyManager: + global _PROXY_MANAGER # noqa: PLW0603 if _PROXY_MANAGER is None: - return ProxyManager() + _PROXY_MANAGER = ProxyManager() return _PROXY_MANAGER diff --git a/strix/tools/python/python_manager.py b/strix/tools/python/python_manager.py index 73376ab..4d80e1e 100644 --- a/strix/tools/python/python_manager.py +++ b/strix/tools/python/python_manager.py @@ -3,29 +3,39 @@ import contextlib import threading from typing import Any +from strix.tools.context import get_current_agent_id + from .python_instance import PythonInstance class PythonSessionManager: def __init__(self) -> None: - self.sessions: dict[str, PythonInstance] = {} + self._sessions_by_agent: dict[str, dict[str, PythonInstance]] = {} self._lock = threading.Lock() self.default_session_id = "default" self._register_cleanup_handlers() + def _get_agent_sessions(self) -> dict[str, PythonInstance]: + agent_id = get_current_agent_id() + with self._lock: + if agent_id not in self._sessions_by_agent: + self._sessions_by_agent[agent_id] = {} + return self._sessions_by_agent[agent_id] + def create_session( self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30 ) -> dict[str, Any]: if session_id is None: session_id = self.default_session_id + sessions = self._get_agent_sessions() with self._lock: - if session_id in self.sessions: + if session_id in sessions: raise ValueError(f"Python session '{session_id}' already exists") session = PythonInstance(session_id) - self.sessions[session_id] = session + sessions[session_id] = session if initial_code: result = session.execute_code(initial_code, timeout) @@ -49,11 +59,12 @@ class PythonSessionManager: if not code: raise ValueError("No code provided for execution") + sessions = self._get_agent_sessions() with self._lock: - if session_id not in self.sessions: + if session_id not in sessions: raise ValueError(f"Python session '{session_id}' not found") - session = self.sessions[session_id] + session = sessions[session_id] result = session.execute_code(code, timeout) result["message"] = f"Code executed in session '{session_id}'" @@ -63,11 +74,12 @@ class PythonSessionManager: if session_id is None: session_id = self.default_session_id + sessions = self._get_agent_sessions() with self._lock: - if session_id not in self.sessions: + if session_id not in sessions: raise ValueError(f"Python session '{session_id}' not found") - session = self.sessions.pop(session_id) + session = sessions.pop(session_id) session.close() return { @@ -77,9 +89,10 @@ class PythonSessionManager: } def list_sessions(self) -> dict[str, Any]: + sessions = self._get_agent_sessions() with self._lock: session_info = {} - for sid, session in self.sessions.items(): + for sid, session in sessions.items(): session_info[sid] = { "is_running": session.is_running, "is_alive": session.is_alive(), @@ -87,24 +100,35 @@ class PythonSessionManager: return {"sessions": session_info, "total_count": len(session_info)} + def cleanup_agent(self, agent_id: str) -> None: + with self._lock: + sessions = self._sessions_by_agent.pop(agent_id, {}) + + for session in sessions.values(): + with contextlib.suppress(Exception): + session.close() + def cleanup_dead_sessions(self) -> None: with self._lock: - dead_sessions = [] - for sid, session in self.sessions.items(): - if not session.is_alive(): - dead_sessions.append(sid) + for sessions in self._sessions_by_agent.values(): + dead_sessions = [] + for sid, session in sessions.items(): + if not session.is_alive(): + dead_sessions.append(sid) - for sid in dead_sessions: - session = self.sessions.pop(sid) - with contextlib.suppress(Exception): - session.close() + for sid in dead_sessions: + session = sessions.pop(sid) + with contextlib.suppress(Exception): + session.close() def close_all_sessions(self) -> None: with self._lock: - sessions_to_close = list(self.sessions.values()) - self.sessions.clear() + all_sessions: list[PythonInstance] = [] + for sessions in self._sessions_by_agent.values(): + all_sessions.extend(sessions.values()) + self._sessions_by_agent.clear() - for session in sessions_to_close: + for session in all_sessions: with contextlib.suppress(Exception): session.close() diff --git a/strix/tools/terminal/terminal_manager.py b/strix/tools/terminal/terminal_manager.py index 320dd18..8192c07 100644 --- a/strix/tools/terminal/terminal_manager.py +++ b/strix/tools/terminal/terminal_manager.py @@ -3,18 +3,27 @@ import contextlib import threading from typing import Any +from strix.tools.context import get_current_agent_id + from .terminal_session import TerminalSession class TerminalManager: def __init__(self) -> None: - self.sessions: dict[str, TerminalSession] = {} + self._sessions_by_agent: dict[str, dict[str, TerminalSession]] = {} self._lock = threading.Lock() self.default_terminal_id = "default" self.default_timeout = 30.0 self._register_cleanup_handlers() + def _get_agent_sessions(self) -> dict[str, TerminalSession]: + agent_id = get_current_agent_id() + with self._lock: + if agent_id not in self._sessions_by_agent: + self._sessions_by_agent[agent_id] = {} + return self._sessions_by_agent[agent_id] + def execute_command( self, command: str, @@ -62,24 +71,26 @@ class TerminalManager: } def _get_or_create_session(self, terminal_id: str) -> TerminalSession: + sessions = self._get_agent_sessions() with self._lock: - if terminal_id not in self.sessions: - self.sessions[terminal_id] = TerminalSession(terminal_id) - return self.sessions[terminal_id] + if terminal_id not in sessions: + sessions[terminal_id] = TerminalSession(terminal_id) + return sessions[terminal_id] def close_session(self, terminal_id: str | None = None) -> dict[str, Any]: if terminal_id is None: terminal_id = self.default_terminal_id + sessions = self._get_agent_sessions() with self._lock: - if terminal_id not in self.sessions: + if terminal_id not in sessions: return { "terminal_id": terminal_id, "message": f"Terminal '{terminal_id}' not found", "status": "not_found", } - session = self.sessions.pop(terminal_id) + session = sessions.pop(terminal_id) try: session.close() @@ -97,9 +108,10 @@ class TerminalManager: } def list_sessions(self) -> dict[str, Any]: + sessions = self._get_agent_sessions() with self._lock: session_info: dict[str, dict[str, Any]] = {} - for tid, session in self.sessions.items(): + for tid, session in sessions.items(): session_info[tid] = { "is_running": session.is_running(), "working_dir": session.get_working_dir(), @@ -107,24 +119,35 @@ class TerminalManager: return {"sessions": session_info, "total_count": len(session_info)} + def cleanup_agent(self, agent_id: str) -> None: + with self._lock: + sessions = self._sessions_by_agent.pop(agent_id, {}) + + for session in sessions.values(): + with contextlib.suppress(Exception): + session.close() + def cleanup_dead_sessions(self) -> None: with self._lock: - dead_sessions: list[str] = [] - for tid, session in self.sessions.items(): - if not session.is_running(): - dead_sessions.append(tid) + for sessions in self._sessions_by_agent.values(): + dead_sessions: list[str] = [] + for tid, session in sessions.items(): + if not session.is_running(): + dead_sessions.append(tid) - for tid in dead_sessions: - session = self.sessions.pop(tid) - with contextlib.suppress(Exception): - session.close() + for tid in dead_sessions: + session = sessions.pop(tid) + with contextlib.suppress(Exception): + session.close() def close_all_sessions(self) -> None: with self._lock: - sessions_to_close = list(self.sessions.values()) - self.sessions.clear() + all_sessions: list[TerminalSession] = [] + for sessions in self._sessions_by_agent.values(): + all_sessions.extend(sessions.values()) + self._sessions_by_agent.clear() - for session in sessions_to_close: + for session in all_sessions: with contextlib.suppress(Exception): session.close()