refactor: simplify tool server to asyncio tasks with per-agent isolation

- Replace multiprocessing/threading with single asyncio task per agent
- Add task cancellation: new request cancels previous for same agent
- Add per-agent state isolation via ContextVar for Terminal, Browser, Python managers
- Add posthog telemetry for tool execution errors (timeout, http, sandbox)
- Fix proxy manager singleton pattern
- Increase client timeout buffer over server timeout
- Add context.py to Dockerfile
This commit is contained in:
0xallam
2026-01-17 21:41:57 -08:00
committed by Ahmed Allam
parent a80ecac7bd
commit 918a151892
8 changed files with 271 additions and 295 deletions

View File

@@ -3,39 +3,54 @@ import contextlib
import threading
from typing import Any
from strix.tools.context import get_current_agent_id
from .browser_instance import BrowserInstance
class BrowserTabManager:
def __init__(self) -> None:
self.browser_instance: BrowserInstance | None = None
self._browsers_by_agent: dict[str, BrowserInstance] = {}
self._lock = threading.Lock()
self._register_cleanup_handlers()
def _get_agent_browser(self) -> BrowserInstance | None:
    """Return the browser registered for the current agent, or ``None``.

    The agent is identified by ``get_current_agent_id()``; the lookup in
    ``_browsers_by_agent`` is done while holding ``self._lock``.
    """
    owner = get_current_agent_id()
    with self._lock:
        browser = self._browsers_by_agent.get(owner)
    return browser
def _set_agent_browser(self, browser: BrowserInstance | None) -> None:
    """Register ``browser`` for the current agent, or unregister with ``None``.

    Args:
        browser: Instance to store under the current agent id. Passing
            ``None`` removes any existing entry (a missing entry is ignored).
    """
    owner = get_current_agent_id()
    with self._lock:
        if browser is not None:
            self._browsers_by_agent[owner] = browser
            return
        # Default of None makes the removal a no-op when nothing is stored.
        self._browsers_by_agent.pop(owner, None)
def launch_browser(self, url: str | None = None) -> dict[str, Any]:
    """Launch a browser for the current agent and register it.

    The check for an existing browser, the launch, and the registration all
    run while holding ``self._lock`` so that the "already launched" check and
    the insert into ``_browsers_by_agent`` cannot interleave with another
    caller.

    Args:
        url: Optional URL forwarded to ``BrowserInstance.launch``.

    Returns:
        The dict returned by ``BrowserInstance.launch`` with a ``"message"``
        key added.

    Raises:
        ValueError: The current agent already has a registered browser.
        RuntimeError: ``OSError``, ``ValueError`` or ``RuntimeError`` was
            raised while creating or launching the browser.
    """
    with self._lock:
        agent_id = get_current_agent_id()
        if agent_id in self._browsers_by_agent:
            raise ValueError("Browser is already launched")

        try:
            browser = BrowserInstance()
            result = browser.launch(url)
            # Register only after a successful launch so a failed attempt
            # leaves no entry behind for this agent.
            self._browsers_by_agent[agent_id] = browser
            result["message"] = "Browser launched successfully"
        except (OSError, ValueError, RuntimeError) as e:
            raise RuntimeError(f"Failed to launch browser: {e}") from e
        else:
            return result
def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.goto(url, tab_id)
result = browser.goto(url, tab_id)
result["message"] = f"Navigated to {url}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to navigate to URL: {e}") from e
@@ -43,12 +58,12 @@ class BrowserTabManager:
return result
def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.click(coordinate, tab_id)
result = browser.click(coordinate, tab_id)
result["message"] = f"Clicked at {coordinate}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to click: {e}") from e
@@ -56,12 +71,12 @@ class BrowserTabManager:
return result
def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.type_text(text, tab_id)
result = browser.type_text(text, tab_id)
result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to type text: {e}") from e
@@ -69,12 +84,12 @@ class BrowserTabManager:
return result
def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.scroll(direction, tab_id)
result = browser.scroll(direction, tab_id)
result["message"] = f"Scrolled {direction}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to scroll: {e}") from e
@@ -82,12 +97,12 @@ class BrowserTabManager:
return result
def back(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.back(tab_id)
result = browser.back(tab_id)
result["message"] = "Navigated back"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to go back: {e}") from e
@@ -95,12 +110,12 @@ class BrowserTabManager:
return result
def forward(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.forward(tab_id)
result = browser.forward(tab_id)
result["message"] = "Navigated forward"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to go forward: {e}") from e
@@ -108,12 +123,12 @@ class BrowserTabManager:
return result
def new_tab(self, url: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.new_tab(url)
result = browser.new_tab(url)
result["message"] = f"Created new tab {result.get('tab_id', '')}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to create new tab: {e}") from e
@@ -121,12 +136,12 @@ class BrowserTabManager:
return result
def switch_tab(self, tab_id: str) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.switch_tab(tab_id)
result = browser.switch_tab(tab_id)
result["message"] = f"Switched to tab {tab_id}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to switch tab: {e}") from e
@@ -134,12 +149,12 @@ class BrowserTabManager:
return result
def close_tab(self, tab_id: str) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.close_tab(tab_id)
result = browser.close_tab(tab_id)
result["message"] = f"Closed tab {tab_id}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to close tab: {e}") from e
@@ -147,12 +162,12 @@ class BrowserTabManager:
return result
def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.wait(duration, tab_id)
result = browser.wait(duration, tab_id)
result["message"] = f"Waited {duration}s"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to wait: {e}") from e
@@ -160,12 +175,12 @@ class BrowserTabManager:
return result
def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.execute_js(js_code, tab_id)
result = browser.execute_js(js_code, tab_id)
result["message"] = "JavaScript executed successfully"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to execute JavaScript: {e}") from e
@@ -173,12 +188,12 @@ class BrowserTabManager:
return result
def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.double_click(coordinate, tab_id)
result = browser.double_click(coordinate, tab_id)
result["message"] = f"Double clicked at {coordinate}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to double click: {e}") from e
@@ -186,12 +201,12 @@ class BrowserTabManager:
return result
def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.hover(coordinate, tab_id)
result = browser.hover(coordinate, tab_id)
result["message"] = f"Hovered at {coordinate}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to hover: {e}") from e
@@ -199,12 +214,12 @@ class BrowserTabManager:
return result
def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.press_key(key, tab_id)
result = browser.press_key(key, tab_id)
result["message"] = f"Pressed key {key}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to press key: {e}") from e
@@ -212,12 +227,12 @@ class BrowserTabManager:
return result
def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.save_pdf(file_path, tab_id)
result = browser.save_pdf(file_path, tab_id)
result["message"] = f"Page saved as PDF: {file_path}"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to save PDF: {e}") from e
@@ -225,12 +240,12 @@ class BrowserTabManager:
return result
def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.get_console_logs(tab_id, clear)
result = browser.get_console_logs(tab_id, clear)
action_text = "cleared and retrieved" if clear else "retrieved"
logs = result.get("console_logs", [])
@@ -247,12 +262,12 @@ class BrowserTabManager:
return result
def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
raise ValueError("Browser not launched")
browser = self._get_agent_browser()
if browser is None:
raise ValueError("Browser not launched")
try:
result = self.browser_instance.view_source(tab_id)
result = browser.view_source(tab_id)
result["message"] = "Page source retrieved"
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to get page source: {e}") from e
@@ -260,18 +275,18 @@ class BrowserTabManager:
return result
def list_tabs(self) -> dict[str, Any]:
with self._lock:
if self.browser_instance is None:
return {"tabs": {}, "total_count": 0, "current_tab": None}
browser = self._get_agent_browser()
if browser is None:
return {"tabs": {}, "total_count": 0, "current_tab": None}
try:
tab_info = {}
for tid, tab_page in self.browser_instance.pages.items():
for tid, tab_page in browser.pages.items():
try:
tab_info[tid] = {
"url": tab_page.url,
"title": "Unknown" if tab_page.is_closed() else "Active",
"is_current": tid == self.browser_instance.current_page_id,
"is_current": tid == browser.current_page_id,
}
except (AttributeError, RuntimeError):
tab_info[tid] = {
@@ -283,19 +298,20 @@ class BrowserTabManager:
return {
"tabs": tab_info,
"total_count": len(tab_info),
"current_tab": self.browser_instance.current_page_id,
"current_tab": browser.current_page_id,
}
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to list tabs: {e}") from e
def close_browser(self) -> dict[str, Any]:
agent_id = get_current_agent_id()
with self._lock:
if self.browser_instance is None:
browser = self._browsers_by_agent.pop(agent_id, None)
if browser is None:
raise ValueError("Browser not launched")
try:
self.browser_instance.close()
self.browser_instance = None
browser.close()
except (OSError, ValueError, RuntimeError) as e:
raise RuntimeError(f"Failed to close browser: {e}") from e
else:
@@ -305,19 +321,34 @@ class BrowserTabManager:
"is_running": False,
}
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
browser = self._browsers_by_agent.pop(agent_id, None)
if browser:
with contextlib.suppress(Exception):
browser.close()
def cleanup_dead_browser(self) -> None:
with self._lock:
if self.browser_instance and not self.browser_instance.is_alive():
dead_agents = []
for agent_id, browser in self._browsers_by_agent.items():
if not browser.is_alive():
dead_agents.append(agent_id)
for agent_id in dead_agents:
browser = self._browsers_by_agent.pop(agent_id)
with contextlib.suppress(Exception):
self.browser_instance.close()
self.browser_instance = None
browser.close()
def close_all(self) -> None:
with self._lock:
if self.browser_instance:
with contextlib.suppress(Exception):
self.browser_instance.close()
self.browser_instance = None
browsers = list(self._browsers_by_agent.values())
self._browsers_by_agent.clear()
for browser in browsers:
with contextlib.suppress(Exception):
browser.close()
def _register_cleanup_handlers(self) -> None:
    """Register ``close_all`` to run at interpreter exit via ``atexit``."""
    exit_hook = self.close_all
    atexit.register(exit_hook)

12
strix/tools/context.py Normal file
View File

@@ -0,0 +1,12 @@
from contextvars import ContextVar
# Context-local id of the agent on whose behalf code is currently running.
# Falls back to "default" when nothing has been set in the current context.
current_agent_id: ContextVar[str] = ContextVar("current_agent_id", default="default")


def get_current_agent_id() -> str:
    """Return the agent id bound in the current context (``"default"`` if unset)."""
    agent_id = current_agent_id.get()
    return agent_id


def set_current_agent_id(agent_id: str) -> None:
    """Bind ``agent_id`` as the current agent id for the current context.

    The token returned by ``ContextVar.set`` is discarded, so the previous
    value cannot be restored through this helper.
    """
    _ = current_agent_id.set(agent_id)

View File

@@ -5,6 +5,7 @@ from typing import Any
import httpx
from strix.config import Config
from strix.telemetry import posthog
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
@@ -20,7 +21,8 @@ from .registry import (
)
# Execution timeout read from config (defaults to 120 when unset/empty).
_SERVER_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
# Client-side timeout: the configured value plus a 30-unit buffer.
# NOTE(review): presumably seconds, and presumably sized so the tool server's
# own timeout fires before the HTTP client gives up; confirm.
SANDBOX_EXECUTION_TIMEOUT = _SERVER_TIMEOUT + 30
# Connection timeout read from config (defaults to 10 when unset/empty).
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
@@ -82,14 +84,17 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A
response.raise_for_status()
response_data = response.json()
if response_data.get("error"):
posthog.error("tool_execution_error", f"{tool_name}: {response_data['error']}")
raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
return response_data.get("result")
except httpx.HTTPStatusError as e:
posthog.error("tool_http_error", f"{tool_name}: HTTP {e.response.status_code}")
if e.response.status_code == 401:
raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
except httpx.RequestError as e:
error_type = type(e).__name__
posthog.error("tool_request_error", f"{tool_name}: {error_type}")
raise RuntimeError(f"Request error calling tool server: {error_type}") from e

View File

@@ -785,6 +785,7 @@ _PROXY_MANAGER: ProxyManager | None = None
def get_proxy_manager() -> ProxyManager:
    """Return the module-wide ``ProxyManager``, creating it on first use.

    The instance is stored in ``_PROXY_MANAGER`` so every later call returns
    the same object rather than constructing a fresh manager each time.
    """
    global _PROXY_MANAGER  # noqa: PLW0603
    if _PROXY_MANAGER is None:
        # Cache before returning; returning a new instance here without
        # storing it would defeat the singleton.
        _PROXY_MANAGER = ProxyManager()
    return _PROXY_MANAGER

View File

@@ -3,29 +3,39 @@ import contextlib
import threading
from typing import Any
from strix.tools.context import get_current_agent_id
from .python_instance import PythonInstance
class PythonSessionManager:
def __init__(self) -> None:
    """Set up empty session bookkeeping and register the exit-time cleanup."""
    # Session id used when a caller passes session_id=None.
    self.default_session_id = "default"
    # Guards every read/write of the session maps below.
    self._lock = threading.Lock()
    # agent id -> {session id -> PythonInstance}
    self._sessions_by_agent: dict[str, dict[str, PythonInstance]] = {}
    # Flat session map.
    # NOTE(review): the per-agent helpers in this class only use
    # _sessions_by_agent; confirm whether this attribute is still needed.
    self.sessions: dict[str, PythonInstance] = {}
    self._register_cleanup_handlers()
def _get_agent_sessions(self) -> dict[str, PythonInstance]:
    """Return the current agent's session map, creating it if absent.

    The returned dict is the live object stored in ``_sessions_by_agent``
    (not a copy), so mutations by the caller are visible to the manager.
    """
    owner = get_current_agent_id()
    with self._lock:
        # setdefault inserts an empty map on first access for this agent.
        return self._sessions_by_agent.setdefault(owner, {})
def create_session(
self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30
) -> dict[str, Any]:
if session_id is None:
session_id = self.default_session_id
sessions = self._get_agent_sessions()
with self._lock:
if session_id in self.sessions:
if session_id in sessions:
raise ValueError(f"Python session '{session_id}' already exists")
session = PythonInstance(session_id)
self.sessions[session_id] = session
sessions[session_id] = session
if initial_code:
result = session.execute_code(initial_code, timeout)
@@ -49,11 +59,12 @@ class PythonSessionManager:
if not code:
raise ValueError("No code provided for execution")
sessions = self._get_agent_sessions()
with self._lock:
if session_id not in self.sessions:
if session_id not in sessions:
raise ValueError(f"Python session '{session_id}' not found")
session = self.sessions[session_id]
session = sessions[session_id]
result = session.execute_code(code, timeout)
result["message"] = f"Code executed in session '{session_id}'"
@@ -63,11 +74,12 @@ class PythonSessionManager:
if session_id is None:
session_id = self.default_session_id
sessions = self._get_agent_sessions()
with self._lock:
if session_id not in self.sessions:
if session_id not in sessions:
raise ValueError(f"Python session '{session_id}' not found")
session = self.sessions.pop(session_id)
session = sessions.pop(session_id)
session.close()
return {
@@ -77,9 +89,10 @@ class PythonSessionManager:
}
def list_sessions(self) -> dict[str, Any]:
sessions = self._get_agent_sessions()
with self._lock:
session_info = {}
for sid, session in self.sessions.items():
for sid, session in sessions.items():
session_info[sid] = {
"is_running": session.is_running,
"is_alive": session.is_alive(),
@@ -87,24 +100,35 @@ class PythonSessionManager:
return {"sessions": session_info, "total_count": len(session_info)}
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
sessions = self._sessions_by_agent.pop(agent_id, {})
for session in sessions.values():
with contextlib.suppress(Exception):
session.close()
def cleanup_dead_sessions(self) -> None:
with self._lock:
dead_sessions = []
for sid, session in self.sessions.items():
if not session.is_alive():
dead_sessions.append(sid)
for sessions in self._sessions_by_agent.values():
dead_sessions = []
for sid, session in sessions.items():
if not session.is_alive():
dead_sessions.append(sid)
for sid in dead_sessions:
session = self.sessions.pop(sid)
with contextlib.suppress(Exception):
session.close()
for sid in dead_sessions:
session = sessions.pop(sid)
with contextlib.suppress(Exception):
session.close()
def close_all_sessions(self) -> None:
with self._lock:
sessions_to_close = list(self.sessions.values())
self.sessions.clear()
all_sessions: list[PythonInstance] = []
for sessions in self._sessions_by_agent.values():
all_sessions.extend(sessions.values())
self._sessions_by_agent.clear()
for session in sessions_to_close:
for session in all_sessions:
with contextlib.suppress(Exception):
session.close()

View File

@@ -3,18 +3,27 @@ import contextlib
import threading
from typing import Any
from strix.tools.context import get_current_agent_id
from .terminal_session import TerminalSession
class TerminalManager:
def __init__(self) -> None:
    """Set up empty terminal bookkeeping and register the exit-time cleanup."""
    # Defaults used when a caller omits the terminal id / timeout.
    self.default_terminal_id = "default"
    self.default_timeout = 30.0
    # Guards every read/write of the session maps below.
    self._lock = threading.Lock()
    # agent id -> {terminal id -> TerminalSession}
    self._sessions_by_agent: dict[str, dict[str, TerminalSession]] = {}
    # Flat session map.
    # NOTE(review): the per-agent helpers in this class only use
    # _sessions_by_agent; confirm whether this attribute is still needed.
    self.sessions: dict[str, TerminalSession] = {}
    self._register_cleanup_handlers()
def _get_agent_sessions(self) -> dict[str, TerminalSession]:
    """Return the current agent's terminal map, creating it if absent.

    The returned dict is the live object stored in ``_sessions_by_agent``
    (not a copy), so mutations by the caller are visible to the manager.
    """
    owner = get_current_agent_id()
    with self._lock:
        # setdefault inserts an empty map on first access for this agent.
        return self._sessions_by_agent.setdefault(owner, {})
def execute_command(
self,
command: str,
@@ -62,24 +71,26 @@ class TerminalManager:
}
def _get_or_create_session(self, terminal_id: str) -> TerminalSession:
    """Return the current agent's terminal ``terminal_id``, creating it if new.

    Args:
        terminal_id: Key within the current agent's session map.

    Returns:
        The existing ``TerminalSession`` for this agent and id, or a newly
        constructed one that has been stored under that id.
    """
    # Resolve the agent's map first: _get_agent_sessions() takes self._lock
    # itself, and threading.Lock is not re-entrant.
    sessions = self._get_agent_sessions()
    with self._lock:
        if terminal_id not in sessions:
            sessions[terminal_id] = TerminalSession(terminal_id)
        return sessions[terminal_id]
def close_session(self, terminal_id: str | None = None) -> dict[str, Any]:
if terminal_id is None:
terminal_id = self.default_terminal_id
sessions = self._get_agent_sessions()
with self._lock:
if terminal_id not in self.sessions:
if terminal_id not in sessions:
return {
"terminal_id": terminal_id,
"message": f"Terminal '{terminal_id}' not found",
"status": "not_found",
}
session = self.sessions.pop(terminal_id)
session = sessions.pop(terminal_id)
try:
session.close()
@@ -97,9 +108,10 @@ class TerminalManager:
}
def list_sessions(self) -> dict[str, Any]:
sessions = self._get_agent_sessions()
with self._lock:
session_info: dict[str, dict[str, Any]] = {}
for tid, session in self.sessions.items():
for tid, session in sessions.items():
session_info[tid] = {
"is_running": session.is_running(),
"working_dir": session.get_working_dir(),
@@ -107,24 +119,35 @@ class TerminalManager:
return {"sessions": session_info, "total_count": len(session_info)}
def cleanup_agent(self, agent_id: str) -> None:
with self._lock:
sessions = self._sessions_by_agent.pop(agent_id, {})
for session in sessions.values():
with contextlib.suppress(Exception):
session.close()
def cleanup_dead_sessions(self) -> None:
with self._lock:
dead_sessions: list[str] = []
for tid, session in self.sessions.items():
if not session.is_running():
dead_sessions.append(tid)
for sessions in self._sessions_by_agent.values():
dead_sessions: list[str] = []
for tid, session in sessions.items():
if not session.is_running():
dead_sessions.append(tid)
for tid in dead_sessions:
session = self.sessions.pop(tid)
with contextlib.suppress(Exception):
session.close()
for tid in dead_sessions:
session = sessions.pop(tid)
with contextlib.suppress(Exception):
session.close()
def close_all_sessions(self) -> None:
with self._lock:
sessions_to_close = list(self.sessions.values())
self.sessions.clear()
all_sessions: list[TerminalSession] = []
for sessions in self._sessions_by_agent.values():
all_sessions.extend(sessions.values())
self._sessions_by_agent.clear()
for session in sessions_to_close:
for session in all_sessions:
with contextlib.suppress(Exception):
session.close()