refactor: simplify tool server to asyncio tasks with per-agent isolation
- Replace multiprocessing/threading with single asyncio task per agent - Add task cancellation: new request cancels previous for same agent - Add per-agent state isolation via ContextVar for Terminal, Browser, Python managers - Add posthog telemetry for tool execution errors (timeout, http, sandbox) - Fix proxy manager singleton pattern - Increase client timeout buffer over server timeout - Add context.py to Dockerfile
This commit is contained in:
@@ -172,7 +172,7 @@ COPY strix/config/ /app/strix/config/
|
|||||||
COPY strix/utils/ /app/strix/utils/
|
COPY strix/utils/ /app/strix/utils/
|
||||||
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
|
COPY strix/runtime/tool_server.py strix/runtime/__init__.py strix/runtime/runtime.py /app/strix/runtime/
|
||||||
|
|
||||||
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py /app/strix/tools/
|
COPY strix/tools/__init__.py strix/tools/registry.py strix/tools/executor.py strix/tools/argument_parser.py strix/tools/context.py /app/strix/tools/
|
||||||
|
|
||||||
COPY strix/tools/browser/ /app/strix/tools/browser/
|
COPY strix/tools/browser/ /app/strix/tools/browser/
|
||||||
COPY strix/tools/file_edit/ /app/strix/tools/file_edit/
|
COPY strix/tools/file_edit/ /app/strix/tools/file_edit/
|
||||||
|
|||||||
@@ -2,16 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import queue as stdlib_queue
|
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
from multiprocessing import Process, Queue
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
from fastapi import Depends, FastAPI, HTTPException, status
|
||||||
@@ -40,13 +34,9 @@ REQUEST_TIMEOUT = args.timeout
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
|
|
||||||
security_dependency = Depends(security)
|
security_dependency = Depends(security)
|
||||||
|
|
||||||
agent_processes: dict[str, dict[str, Any]] = {}
|
agent_tasks: dict[str, asyncio.Task[Any]] = {}
|
||||||
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
|
|
||||||
pending_responses: dict[str, dict[str, asyncio.Future[Any]]] = {}
|
|
||||||
agent_listeners: dict[str, dict[str, Any]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
|
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
|
||||||
@@ -78,107 +68,19 @@ class ToolExecutionResponse(BaseModel):
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
|
async def _run_tool(agent_id: str, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||||
null_handler = logging.NullHandler()
|
from strix.tools.argument_parser import convert_arguments
|
||||||
|
from strix.tools.context import set_current_agent_id
|
||||||
root_logger = logging.getLogger()
|
|
||||||
root_logger.handlers = [null_handler]
|
|
||||||
root_logger.setLevel(logging.CRITICAL)
|
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
|
|
||||||
from strix.tools.registry import get_tool_by_name
|
from strix.tools.registry import get_tool_by_name
|
||||||
|
|
||||||
def _execute_request(request: dict[str, Any]) -> None:
|
set_current_agent_id(agent_id)
|
||||||
request_id = request.get("request_id", "")
|
|
||||||
tool_name = request["tool_name"]
|
|
||||||
kwargs = request["kwargs"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
tool_func = get_tool_by_name(tool_name)
|
tool_func = get_tool_by_name(tool_name)
|
||||||
if not tool_func:
|
if not tool_func:
|
||||||
response_queue.put(
|
raise ValueError(f"Tool '{tool_name}' not found")
|
||||||
{"request_id": request_id, "error": f"Tool '{tool_name}' not found"}
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
converted_kwargs = convert_arguments(tool_func, kwargs)
|
converted_kwargs = convert_arguments(tool_func, kwargs)
|
||||||
result = tool_func(**converted_kwargs)
|
return await asyncio.to_thread(tool_func, **converted_kwargs)
|
||||||
|
|
||||||
response_queue.put({"request_id": request_id, "result": result})
|
|
||||||
|
|
||||||
except (ArgumentConversionError, ValidationError) as e:
|
|
||||||
response_queue.put({"request_id": request_id, "error": f"Invalid arguments: {e}"})
|
|
||||||
except (RuntimeError, ValueError, ImportError) as e:
|
|
||||||
response_queue.put({"request_id": request_id, "error": f"Tool execution error: {e}"})
|
|
||||||
except Exception as e: # noqa: BLE001
|
|
||||||
response_queue.put({"request_id": request_id, "error": f"Unexpected error: {e}"})
|
|
||||||
|
|
||||||
with ThreadPoolExecutor() as executor:
|
|
||||||
while True:
|
|
||||||
request = None
|
|
||||||
try:
|
|
||||||
request = request_queue.get()
|
|
||||||
|
|
||||||
if request is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
executor.submit(_execute_request, request)
|
|
||||||
|
|
||||||
except (RuntimeError, ValueError, ImportError) as e:
|
|
||||||
req_id = request.get("request_id", "") if request else ""
|
|
||||||
response_queue.put({"request_id": req_id, "error": f"Worker error: {e}"})
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_response_listener(agent_id: str, response_queue: Queue[Any]) -> None:
|
|
||||||
if agent_id in agent_listeners:
|
|
||||||
return
|
|
||||||
|
|
||||||
stop_event = threading.Event()
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
|
|
||||||
def _listener() -> None:
|
|
||||||
while not stop_event.is_set():
|
|
||||||
try:
|
|
||||||
item = response_queue.get(timeout=0.5)
|
|
||||||
except stdlib_queue.Empty:
|
|
||||||
continue
|
|
||||||
except (BrokenPipeError, EOFError):
|
|
||||||
break
|
|
||||||
|
|
||||||
request_id = item.get("request_id")
|
|
||||||
if not request_id or agent_id not in pending_responses:
|
|
||||||
continue
|
|
||||||
|
|
||||||
future = pending_responses[agent_id].pop(request_id, None)
|
|
||||||
if future and not future.done():
|
|
||||||
with contextlib.suppress(RuntimeError):
|
|
||||||
loop.call_soon_threadsafe(future.set_result, item)
|
|
||||||
|
|
||||||
listener_thread = threading.Thread(target=_listener, daemon=True)
|
|
||||||
listener_thread.start()
|
|
||||||
|
|
||||||
agent_listeners[agent_id] = {"thread": listener_thread, "stop_event": stop_event}
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
|
|
||||||
if agent_id not in agent_processes:
|
|
||||||
request_queue: Queue[Any] = Queue()
|
|
||||||
response_queue: Queue[Any] = Queue()
|
|
||||||
|
|
||||||
process = Process(
|
|
||||||
target=agent_worker, args=(agent_id, request_queue, response_queue), daemon=True
|
|
||||||
)
|
|
||||||
process.start()
|
|
||||||
|
|
||||||
agent_processes[agent_id] = {"process": process, "pid": process.pid}
|
|
||||||
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
|
|
||||||
pending_responses[agent_id] = {}
|
|
||||||
|
|
||||||
_ensure_response_listener(agent_id, response_queue)
|
|
||||||
|
|
||||||
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/execute", response_model=ToolExecutionResponse)
|
@app.post("/execute", response_model=ToolExecutionResponse)
|
||||||
@@ -187,33 +89,42 @@ async def execute_tool(
|
|||||||
) -> ToolExecutionResponse:
|
) -> ToolExecutionResponse:
|
||||||
verify_token(credentials)
|
verify_token(credentials)
|
||||||
|
|
||||||
request_queue, _response_queue = ensure_agent_process(request.agent_id)
|
agent_id = request.agent_id
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
if agent_id in agent_tasks:
|
||||||
req_id = uuid4().hex
|
old_task = agent_tasks[agent_id]
|
||||||
future: asyncio.Future[Any] = loop.create_future()
|
if not old_task.done():
|
||||||
pending_responses[request.agent_id][req_id] = future
|
old_task.cancel()
|
||||||
|
|
||||||
request_queue.put(
|
task = asyncio.create_task(
|
||||||
{
|
asyncio.wait_for(
|
||||||
"request_id": req_id,
|
_run_tool(agent_id, request.tool_name, request.kwargs), timeout=REQUEST_TIMEOUT
|
||||||
"tool_name": request.tool_name,
|
|
||||||
"kwargs": request.kwargs,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
agent_tasks[agent_id] = task
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await asyncio.wait_for(future, timeout=REQUEST_TIMEOUT)
|
result = await task
|
||||||
|
return ToolExecutionResponse(result=result)
|
||||||
|
|
||||||
if "error" in response:
|
except asyncio.CancelledError:
|
||||||
return ToolExecutionResponse(error=response["error"])
|
return ToolExecutionResponse(error="Cancelled by newer request")
|
||||||
return ToolExecutionResponse(result=response.get("result"))
|
|
||||||
|
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
pending_responses[request.agent_id].pop(req_id, None)
|
return ToolExecutionResponse(error=f"Tool timed out after {REQUEST_TIMEOUT}s")
|
||||||
return ToolExecutionResponse(error=f"Request timed out after {REQUEST_TIMEOUT} seconds")
|
|
||||||
except (RuntimeError, ValueError, OSError) as e:
|
except ValidationError as e:
|
||||||
return ToolExecutionResponse(error=f"Worker error: {e}")
|
return ToolExecutionResponse(error=f"Invalid arguments: {e}")
|
||||||
|
|
||||||
|
except (ValueError, RuntimeError, ImportError) as e:
|
||||||
|
return ToolExecutionResponse(error=f"Tool execution error: {e}")
|
||||||
|
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return ToolExecutionResponse(error=f"Unexpected error: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if agent_tasks.get(agent_id) is task:
|
||||||
|
del agent_tasks[agent_id]
|
||||||
|
|
||||||
|
|
||||||
@app.post("/register_agent")
|
@app.post("/register_agent")
|
||||||
@@ -221,8 +132,6 @@ async def register_agent(
|
|||||||
agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency
|
agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
verify_token(credentials)
|
verify_token(credentials)
|
||||||
|
|
||||||
ensure_agent_process(agent_id)
|
|
||||||
return {"status": "registered", "agent_id": agent_id}
|
return {"status": "registered", "agent_id": agent_id}
|
||||||
|
|
||||||
|
|
||||||
@@ -233,42 +142,16 @@ async def health_check() -> dict[str, Any]:
|
|||||||
"sandbox_mode": str(SANDBOX_MODE),
|
"sandbox_mode": str(SANDBOX_MODE),
|
||||||
"environment": "sandbox" if SANDBOX_MODE else "main",
|
"environment": "sandbox" if SANDBOX_MODE else "main",
|
||||||
"auth_configured": "true" if EXPECTED_TOKEN else "false",
|
"auth_configured": "true" if EXPECTED_TOKEN else "false",
|
||||||
"active_agents": len(agent_processes),
|
"active_agents": len(agent_tasks),
|
||||||
"agents": list(agent_processes.keys()),
|
"agents": list(agent_tasks.keys()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def cleanup_all_agents() -> None:
|
|
||||||
for agent_id in list(agent_processes.keys()):
|
|
||||||
try:
|
|
||||||
if agent_id in agent_listeners:
|
|
||||||
agent_listeners[agent_id]["stop_event"].set()
|
|
||||||
|
|
||||||
agent_queues[agent_id]["request"].put(None)
|
|
||||||
process = agent_processes[agent_id]["process"]
|
|
||||||
|
|
||||||
process.join(timeout=1)
|
|
||||||
|
|
||||||
if process.is_alive():
|
|
||||||
process.terminate()
|
|
||||||
process.join(timeout=1)
|
|
||||||
|
|
||||||
if process.is_alive():
|
|
||||||
process.kill()
|
|
||||||
|
|
||||||
if agent_id in agent_listeners:
|
|
||||||
listener_thread = agent_listeners[agent_id]["thread"]
|
|
||||||
listener_thread.join(timeout=0.5)
|
|
||||||
|
|
||||||
except (BrokenPipeError, EOFError, OSError):
|
|
||||||
pass
|
|
||||||
except (RuntimeError, ValueError) as e:
|
|
||||||
logging.getLogger(__name__).debug(f"Error during agent cleanup: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def signal_handler(_signum: int, _frame: Any) -> None:
|
def signal_handler(_signum: int, _frame: Any) -> None:
|
||||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN) if hasattr(signal, "SIGPIPE") else None
|
if hasattr(signal, "SIGPIPE"):
|
||||||
cleanup_all_agents()
|
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
||||||
|
for task in agent_tasks.values():
|
||||||
|
task.cancel()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
@@ -279,7 +162,4 @@ signal.signal(signal.SIGTERM, signal_handler)
|
|||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
|
||||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||||
finally:
|
|
||||||
cleanup_all_agents()
|
|
||||||
|
|||||||
@@ -3,39 +3,54 @@ import contextlib
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from strix.tools.context import get_current_agent_id
|
||||||
|
|
||||||
from .browser_instance import BrowserInstance
|
from .browser_instance import BrowserInstance
|
||||||
|
|
||||||
|
|
||||||
class BrowserTabManager:
|
class BrowserTabManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.browser_instance: BrowserInstance | None = None
|
self._browsers_by_agent: dict[str, BrowserInstance] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
self._register_cleanup_handlers()
|
self._register_cleanup_handlers()
|
||||||
|
|
||||||
|
def _get_agent_browser(self) -> BrowserInstance | None:
|
||||||
|
agent_id = get_current_agent_id()
|
||||||
|
with self._lock:
|
||||||
|
return self._browsers_by_agent.get(agent_id)
|
||||||
|
|
||||||
|
def _set_agent_browser(self, browser: BrowserInstance | None) -> None:
|
||||||
|
agent_id = get_current_agent_id()
|
||||||
|
with self._lock:
|
||||||
|
if browser is None:
|
||||||
|
self._browsers_by_agent.pop(agent_id, None)
|
||||||
|
else:
|
||||||
|
self._browsers_by_agent[agent_id] = browser
|
||||||
|
|
||||||
def launch_browser(self, url: str | None = None) -> dict[str, Any]:
|
def launch_browser(self, url: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self.browser_instance is not None:
|
agent_id = get_current_agent_id()
|
||||||
|
if agent_id in self._browsers_by_agent:
|
||||||
raise ValueError("Browser is already launched")
|
raise ValueError("Browser is already launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.browser_instance = BrowserInstance()
|
browser = BrowserInstance()
|
||||||
result = self.browser_instance.launch(url)
|
result = browser.launch(url)
|
||||||
|
self._browsers_by_agent[agent_id] = browser
|
||||||
result["message"] = "Browser launched successfully"
|
result["message"] = "Browser launched successfully"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
if self.browser_instance:
|
|
||||||
self.browser_instance = None
|
|
||||||
raise RuntimeError(f"Failed to launch browser: {e}") from e
|
raise RuntimeError(f"Failed to launch browser: {e}") from e
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
|
def goto_url(self, url: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.goto(url, tab_id)
|
result = browser.goto(url, tab_id)
|
||||||
result["message"] = f"Navigated to {url}"
|
result["message"] = f"Navigated to {url}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to navigate to URL: {e}") from e
|
raise RuntimeError(f"Failed to navigate to URL: {e}") from e
|
||||||
@@ -43,12 +58,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
def click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.click(coordinate, tab_id)
|
result = browser.click(coordinate, tab_id)
|
||||||
result["message"] = f"Clicked at {coordinate}"
|
result["message"] = f"Clicked at {coordinate}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to click: {e}") from e
|
raise RuntimeError(f"Failed to click: {e}") from e
|
||||||
@@ -56,12 +71,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
|
def type_text(self, text: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.type_text(text, tab_id)
|
result = browser.type_text(text, tab_id)
|
||||||
result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}"
|
result["message"] = f"Typed text: {text[:50]}{'...' if len(text) > 50 else ''}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to type text: {e}") from e
|
raise RuntimeError(f"Failed to type text: {e}") from e
|
||||||
@@ -69,12 +84,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
|
def scroll(self, direction: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.scroll(direction, tab_id)
|
result = browser.scroll(direction, tab_id)
|
||||||
result["message"] = f"Scrolled {direction}"
|
result["message"] = f"Scrolled {direction}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to scroll: {e}") from e
|
raise RuntimeError(f"Failed to scroll: {e}") from e
|
||||||
@@ -82,12 +97,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def back(self, tab_id: str | None = None) -> dict[str, Any]:
|
def back(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.back(tab_id)
|
result = browser.back(tab_id)
|
||||||
result["message"] = "Navigated back"
|
result["message"] = "Navigated back"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to go back: {e}") from e
|
raise RuntimeError(f"Failed to go back: {e}") from e
|
||||||
@@ -95,12 +110,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def forward(self, tab_id: str | None = None) -> dict[str, Any]:
|
def forward(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.forward(tab_id)
|
result = browser.forward(tab_id)
|
||||||
result["message"] = "Navigated forward"
|
result["message"] = "Navigated forward"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to go forward: {e}") from e
|
raise RuntimeError(f"Failed to go forward: {e}") from e
|
||||||
@@ -108,12 +123,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def new_tab(self, url: str | None = None) -> dict[str, Any]:
|
def new_tab(self, url: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.new_tab(url)
|
result = browser.new_tab(url)
|
||||||
result["message"] = f"Created new tab {result.get('tab_id', '')}"
|
result["message"] = f"Created new tab {result.get('tab_id', '')}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to create new tab: {e}") from e
|
raise RuntimeError(f"Failed to create new tab: {e}") from e
|
||||||
@@ -121,12 +136,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def switch_tab(self, tab_id: str) -> dict[str, Any]:
|
def switch_tab(self, tab_id: str) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.switch_tab(tab_id)
|
result = browser.switch_tab(tab_id)
|
||||||
result["message"] = f"Switched to tab {tab_id}"
|
result["message"] = f"Switched to tab {tab_id}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to switch tab: {e}") from e
|
raise RuntimeError(f"Failed to switch tab: {e}") from e
|
||||||
@@ -134,12 +149,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def close_tab(self, tab_id: str) -> dict[str, Any]:
|
def close_tab(self, tab_id: str) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.close_tab(tab_id)
|
result = browser.close_tab(tab_id)
|
||||||
result["message"] = f"Closed tab {tab_id}"
|
result["message"] = f"Closed tab {tab_id}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to close tab: {e}") from e
|
raise RuntimeError(f"Failed to close tab: {e}") from e
|
||||||
@@ -147,12 +162,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
|
def wait_browser(self, duration: float, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.wait(duration, tab_id)
|
result = browser.wait(duration, tab_id)
|
||||||
result["message"] = f"Waited {duration}s"
|
result["message"] = f"Waited {duration}s"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to wait: {e}") from e
|
raise RuntimeError(f"Failed to wait: {e}") from e
|
||||||
@@ -160,12 +175,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
|
def execute_js(self, js_code: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.execute_js(js_code, tab_id)
|
result = browser.execute_js(js_code, tab_id)
|
||||||
result["message"] = "JavaScript executed successfully"
|
result["message"] = "JavaScript executed successfully"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to execute JavaScript: {e}") from e
|
raise RuntimeError(f"Failed to execute JavaScript: {e}") from e
|
||||||
@@ -173,12 +188,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
def double_click(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.double_click(coordinate, tab_id)
|
result = browser.double_click(coordinate, tab_id)
|
||||||
result["message"] = f"Double clicked at {coordinate}"
|
result["message"] = f"Double clicked at {coordinate}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to double click: {e}") from e
|
raise RuntimeError(f"Failed to double click: {e}") from e
|
||||||
@@ -186,12 +201,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
def hover(self, coordinate: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.hover(coordinate, tab_id)
|
result = browser.hover(coordinate, tab_id)
|
||||||
result["message"] = f"Hovered at {coordinate}"
|
result["message"] = f"Hovered at {coordinate}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to hover: {e}") from e
|
raise RuntimeError(f"Failed to hover: {e}") from e
|
||||||
@@ -199,12 +214,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
|
def press_key(self, key: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.press_key(key, tab_id)
|
result = browser.press_key(key, tab_id)
|
||||||
result["message"] = f"Pressed key {key}"
|
result["message"] = f"Pressed key {key}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to press key: {e}") from e
|
raise RuntimeError(f"Failed to press key: {e}") from e
|
||||||
@@ -212,12 +227,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
|
def save_pdf(self, file_path: str, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.save_pdf(file_path, tab_id)
|
result = browser.save_pdf(file_path, tab_id)
|
||||||
result["message"] = f"Page saved as PDF: {file_path}"
|
result["message"] = f"Page saved as PDF: {file_path}"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to save PDF: {e}") from e
|
raise RuntimeError(f"Failed to save PDF: {e}") from e
|
||||||
@@ -225,12 +240,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
|
def get_console_logs(self, tab_id: str | None = None, clear: bool = False) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.get_console_logs(tab_id, clear)
|
result = browser.get_console_logs(tab_id, clear)
|
||||||
action_text = "cleared and retrieved" if clear else "retrieved"
|
action_text = "cleared and retrieved" if clear else "retrieved"
|
||||||
|
|
||||||
logs = result.get("console_logs", [])
|
logs = result.get("console_logs", [])
|
||||||
@@ -247,12 +262,12 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
|
def view_source(self, tab_id: str | None = None) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.browser_instance.view_source(tab_id)
|
result = browser.view_source(tab_id)
|
||||||
result["message"] = "Page source retrieved"
|
result["message"] = "Page source retrieved"
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to get page source: {e}") from e
|
raise RuntimeError(f"Failed to get page source: {e}") from e
|
||||||
@@ -260,18 +275,18 @@ class BrowserTabManager:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def list_tabs(self) -> dict[str, Any]:
|
def list_tabs(self) -> dict[str, Any]:
|
||||||
with self._lock:
|
browser = self._get_agent_browser()
|
||||||
if self.browser_instance is None:
|
if browser is None:
|
||||||
return {"tabs": {}, "total_count": 0, "current_tab": None}
|
return {"tabs": {}, "total_count": 0, "current_tab": None}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tab_info = {}
|
tab_info = {}
|
||||||
for tid, tab_page in self.browser_instance.pages.items():
|
for tid, tab_page in browser.pages.items():
|
||||||
try:
|
try:
|
||||||
tab_info[tid] = {
|
tab_info[tid] = {
|
||||||
"url": tab_page.url,
|
"url": tab_page.url,
|
||||||
"title": "Unknown" if tab_page.is_closed() else "Active",
|
"title": "Unknown" if tab_page.is_closed() else "Active",
|
||||||
"is_current": tid == self.browser_instance.current_page_id,
|
"is_current": tid == browser.current_page_id,
|
||||||
}
|
}
|
||||||
except (AttributeError, RuntimeError):
|
except (AttributeError, RuntimeError):
|
||||||
tab_info[tid] = {
|
tab_info[tid] = {
|
||||||
@@ -283,19 +298,20 @@ class BrowserTabManager:
|
|||||||
return {
|
return {
|
||||||
"tabs": tab_info,
|
"tabs": tab_info,
|
||||||
"total_count": len(tab_info),
|
"total_count": len(tab_info),
|
||||||
"current_tab": self.browser_instance.current_page_id,
|
"current_tab": browser.current_page_id,
|
||||||
}
|
}
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to list tabs: {e}") from e
|
raise RuntimeError(f"Failed to list tabs: {e}") from e
|
||||||
|
|
||||||
def close_browser(self) -> dict[str, Any]:
|
def close_browser(self) -> dict[str, Any]:
|
||||||
|
agent_id = get_current_agent_id()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self.browser_instance is None:
|
browser = self._browsers_by_agent.pop(agent_id, None)
|
||||||
|
if browser is None:
|
||||||
raise ValueError("Browser not launched")
|
raise ValueError("Browser not launched")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.browser_instance.close()
|
browser.close()
|
||||||
self.browser_instance = None
|
|
||||||
except (OSError, ValueError, RuntimeError) as e:
|
except (OSError, ValueError, RuntimeError) as e:
|
||||||
raise RuntimeError(f"Failed to close browser: {e}") from e
|
raise RuntimeError(f"Failed to close browser: {e}") from e
|
||||||
else:
|
else:
|
||||||
@@ -305,19 +321,34 @@ class BrowserTabManager:
|
|||||||
"is_running": False,
|
"is_running": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def cleanup_agent(self, agent_id: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
browser = self._browsers_by_agent.pop(agent_id, None)
|
||||||
|
|
||||||
|
if browser:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
browser.close()
|
||||||
|
|
||||||
def cleanup_dead_browser(self) -> None:
|
def cleanup_dead_browser(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self.browser_instance and not self.browser_instance.is_alive():
|
dead_agents = []
|
||||||
|
for agent_id, browser in self._browsers_by_agent.items():
|
||||||
|
if not browser.is_alive():
|
||||||
|
dead_agents.append(agent_id)
|
||||||
|
|
||||||
|
for agent_id in dead_agents:
|
||||||
|
browser = self._browsers_by_agent.pop(agent_id)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self.browser_instance.close()
|
browser.close()
|
||||||
self.browser_instance = None
|
|
||||||
|
|
||||||
def close_all(self) -> None:
|
def close_all(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self.browser_instance:
|
browsers = list(self._browsers_by_agent.values())
|
||||||
|
self._browsers_by_agent.clear()
|
||||||
|
|
||||||
|
for browser in browsers:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self.browser_instance.close()
|
browser.close()
|
||||||
self.browser_instance = None
|
|
||||||
|
|
||||||
def _register_cleanup_handlers(self) -> None:
|
def _register_cleanup_handlers(self) -> None:
|
||||||
atexit.register(self.close_all)
|
atexit.register(self.close_all)
|
||||||
|
|||||||
12
strix/tools/context.py
Normal file
12
strix/tools/context.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
|
||||||
|
current_agent_id: ContextVar[str] = ContextVar("current_agent_id", default="default")
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_agent_id() -> str:
|
||||||
|
return current_agent_id.get()
|
||||||
|
|
||||||
|
|
||||||
|
def set_current_agent_id(agent_id: str) -> None:
|
||||||
|
current_agent_id.set(agent_id)
|
||||||
@@ -5,6 +5,7 @@ from typing import Any
|
|||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from strix.config import Config
|
from strix.config import Config
|
||||||
|
from strix.telemetry import posthog
|
||||||
|
|
||||||
|
|
||||||
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
|
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
|
||||||
@@ -20,7 +21,8 @@ from .registry import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
|
_SERVER_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout") or "120")
|
||||||
|
SANDBOX_EXECUTION_TIMEOUT = _SERVER_TIMEOUT + 30
|
||||||
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
|
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")
|
||||||
|
|
||||||
|
|
||||||
@@ -82,14 +84,17 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
if response_data.get("error"):
|
if response_data.get("error"):
|
||||||
|
posthog.error("tool_execution_error", f"{tool_name}: {response_data['error']}")
|
||||||
raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
|
raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
|
||||||
return response_data.get("result")
|
return response_data.get("result")
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
|
posthog.error("tool_http_error", f"{tool_name}: HTTP {e.response.status_code}")
|
||||||
if e.response.status_code == 401:
|
if e.response.status_code == 401:
|
||||||
raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
|
raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
|
||||||
raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
|
raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
error_type = type(e).__name__
|
error_type = type(e).__name__
|
||||||
|
posthog.error("tool_request_error", f"{tool_name}: {error_type}")
|
||||||
raise RuntimeError(f"Request error calling tool server: {error_type}") from e
|
raise RuntimeError(f"Request error calling tool server: {error_type}") from e
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -785,6 +785,7 @@ _PROXY_MANAGER: ProxyManager | None = None
|
|||||||
|
|
||||||
|
|
||||||
def get_proxy_manager() -> ProxyManager:
|
def get_proxy_manager() -> ProxyManager:
|
||||||
|
global _PROXY_MANAGER # noqa: PLW0603
|
||||||
if _PROXY_MANAGER is None:
|
if _PROXY_MANAGER is None:
|
||||||
return ProxyManager()
|
_PROXY_MANAGER = ProxyManager()
|
||||||
return _PROXY_MANAGER
|
return _PROXY_MANAGER
|
||||||
|
|||||||
@@ -3,29 +3,39 @@ import contextlib
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from strix.tools.context import get_current_agent_id
|
||||||
|
|
||||||
from .python_instance import PythonInstance
|
from .python_instance import PythonInstance
|
||||||
|
|
||||||
|
|
||||||
class PythonSessionManager:
|
class PythonSessionManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.sessions: dict[str, PythonInstance] = {}
|
self._sessions_by_agent: dict[str, dict[str, PythonInstance]] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self.default_session_id = "default"
|
self.default_session_id = "default"
|
||||||
|
|
||||||
self._register_cleanup_handlers()
|
self._register_cleanup_handlers()
|
||||||
|
|
||||||
|
def _get_agent_sessions(self) -> dict[str, PythonInstance]:
|
||||||
|
agent_id = get_current_agent_id()
|
||||||
|
with self._lock:
|
||||||
|
if agent_id not in self._sessions_by_agent:
|
||||||
|
self._sessions_by_agent[agent_id] = {}
|
||||||
|
return self._sessions_by_agent[agent_id]
|
||||||
|
|
||||||
def create_session(
|
def create_session(
|
||||||
self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30
|
self, session_id: str | None = None, initial_code: str | None = None, timeout: int = 30
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if session_id is None:
|
if session_id is None:
|
||||||
session_id = self.default_session_id
|
session_id = self.default_session_id
|
||||||
|
|
||||||
|
sessions = self._get_agent_sessions()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if session_id in self.sessions:
|
if session_id in sessions:
|
||||||
raise ValueError(f"Python session '{session_id}' already exists")
|
raise ValueError(f"Python session '{session_id}' already exists")
|
||||||
|
|
||||||
session = PythonInstance(session_id)
|
session = PythonInstance(session_id)
|
||||||
self.sessions[session_id] = session
|
sessions[session_id] = session
|
||||||
|
|
||||||
if initial_code:
|
if initial_code:
|
||||||
result = session.execute_code(initial_code, timeout)
|
result = session.execute_code(initial_code, timeout)
|
||||||
@@ -49,11 +59,12 @@ class PythonSessionManager:
|
|||||||
if not code:
|
if not code:
|
||||||
raise ValueError("No code provided for execution")
|
raise ValueError("No code provided for execution")
|
||||||
|
|
||||||
|
sessions = self._get_agent_sessions()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if session_id not in self.sessions:
|
if session_id not in sessions:
|
||||||
raise ValueError(f"Python session '{session_id}' not found")
|
raise ValueError(f"Python session '{session_id}' not found")
|
||||||
|
|
||||||
session = self.sessions[session_id]
|
session = sessions[session_id]
|
||||||
|
|
||||||
result = session.execute_code(code, timeout)
|
result = session.execute_code(code, timeout)
|
||||||
result["message"] = f"Code executed in session '{session_id}'"
|
result["message"] = f"Code executed in session '{session_id}'"
|
||||||
@@ -63,11 +74,12 @@ class PythonSessionManager:
|
|||||||
if session_id is None:
|
if session_id is None:
|
||||||
session_id = self.default_session_id
|
session_id = self.default_session_id
|
||||||
|
|
||||||
|
sessions = self._get_agent_sessions()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if session_id not in self.sessions:
|
if session_id not in sessions:
|
||||||
raise ValueError(f"Python session '{session_id}' not found")
|
raise ValueError(f"Python session '{session_id}' not found")
|
||||||
|
|
||||||
session = self.sessions.pop(session_id)
|
session = sessions.pop(session_id)
|
||||||
|
|
||||||
session.close()
|
session.close()
|
||||||
return {
|
return {
|
||||||
@@ -77,9 +89,10 @@ class PythonSessionManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_sessions(self) -> dict[str, Any]:
|
def list_sessions(self) -> dict[str, Any]:
|
||||||
|
sessions = self._get_agent_sessions()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
session_info = {}
|
session_info = {}
|
||||||
for sid, session in self.sessions.items():
|
for sid, session in sessions.items():
|
||||||
session_info[sid] = {
|
session_info[sid] = {
|
||||||
"is_running": session.is_running,
|
"is_running": session.is_running,
|
||||||
"is_alive": session.is_alive(),
|
"is_alive": session.is_alive(),
|
||||||
@@ -87,24 +100,35 @@ class PythonSessionManager:
|
|||||||
|
|
||||||
return {"sessions": session_info, "total_count": len(session_info)}
|
return {"sessions": session_info, "total_count": len(session_info)}
|
||||||
|
|
||||||
|
def cleanup_agent(self, agent_id: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
sessions = self._sessions_by_agent.pop(agent_id, {})
|
||||||
|
|
||||||
|
for session in sessions.values():
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
session.close()
|
||||||
|
|
||||||
def cleanup_dead_sessions(self) -> None:
|
def cleanup_dead_sessions(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
for sessions in self._sessions_by_agent.values():
|
||||||
dead_sessions = []
|
dead_sessions = []
|
||||||
for sid, session in self.sessions.items():
|
for sid, session in sessions.items():
|
||||||
if not session.is_alive():
|
if not session.is_alive():
|
||||||
dead_sessions.append(sid)
|
dead_sessions.append(sid)
|
||||||
|
|
||||||
for sid in dead_sessions:
|
for sid in dead_sessions:
|
||||||
session = self.sessions.pop(sid)
|
session = sessions.pop(sid)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def close_all_sessions(self) -> None:
|
def close_all_sessions(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
sessions_to_close = list(self.sessions.values())
|
all_sessions: list[PythonInstance] = []
|
||||||
self.sessions.clear()
|
for sessions in self._sessions_by_agent.values():
|
||||||
|
all_sessions.extend(sessions.values())
|
||||||
|
self._sessions_by_agent.clear()
|
||||||
|
|
||||||
for session in sessions_to_close:
|
for session in all_sessions:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|||||||
@@ -3,18 +3,27 @@ import contextlib
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from strix.tools.context import get_current_agent_id
|
||||||
|
|
||||||
from .terminal_session import TerminalSession
|
from .terminal_session import TerminalSession
|
||||||
|
|
||||||
|
|
||||||
class TerminalManager:
|
class TerminalManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.sessions: dict[str, TerminalSession] = {}
|
self._sessions_by_agent: dict[str, dict[str, TerminalSession]] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self.default_terminal_id = "default"
|
self.default_terminal_id = "default"
|
||||||
self.default_timeout = 30.0
|
self.default_timeout = 30.0
|
||||||
|
|
||||||
self._register_cleanup_handlers()
|
self._register_cleanup_handlers()
|
||||||
|
|
||||||
|
def _get_agent_sessions(self) -> dict[str, TerminalSession]:
|
||||||
|
agent_id = get_current_agent_id()
|
||||||
|
with self._lock:
|
||||||
|
if agent_id not in self._sessions_by_agent:
|
||||||
|
self._sessions_by_agent[agent_id] = {}
|
||||||
|
return self._sessions_by_agent[agent_id]
|
||||||
|
|
||||||
def execute_command(
|
def execute_command(
|
||||||
self,
|
self,
|
||||||
command: str,
|
command: str,
|
||||||
@@ -62,24 +71,26 @@ class TerminalManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _get_or_create_session(self, terminal_id: str) -> TerminalSession:
|
def _get_or_create_session(self, terminal_id: str) -> TerminalSession:
|
||||||
|
sessions = self._get_agent_sessions()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if terminal_id not in self.sessions:
|
if terminal_id not in sessions:
|
||||||
self.sessions[terminal_id] = TerminalSession(terminal_id)
|
sessions[terminal_id] = TerminalSession(terminal_id)
|
||||||
return self.sessions[terminal_id]
|
return sessions[terminal_id]
|
||||||
|
|
||||||
def close_session(self, terminal_id: str | None = None) -> dict[str, Any]:
|
def close_session(self, terminal_id: str | None = None) -> dict[str, Any]:
|
||||||
if terminal_id is None:
|
if terminal_id is None:
|
||||||
terminal_id = self.default_terminal_id
|
terminal_id = self.default_terminal_id
|
||||||
|
|
||||||
|
sessions = self._get_agent_sessions()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if terminal_id not in self.sessions:
|
if terminal_id not in sessions:
|
||||||
return {
|
return {
|
||||||
"terminal_id": terminal_id,
|
"terminal_id": terminal_id,
|
||||||
"message": f"Terminal '{terminal_id}' not found",
|
"message": f"Terminal '{terminal_id}' not found",
|
||||||
"status": "not_found",
|
"status": "not_found",
|
||||||
}
|
}
|
||||||
|
|
||||||
session = self.sessions.pop(terminal_id)
|
session = sessions.pop(terminal_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session.close()
|
session.close()
|
||||||
@@ -97,9 +108,10 @@ class TerminalManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_sessions(self) -> dict[str, Any]:
|
def list_sessions(self) -> dict[str, Any]:
|
||||||
|
sessions = self._get_agent_sessions()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
session_info: dict[str, dict[str, Any]] = {}
|
session_info: dict[str, dict[str, Any]] = {}
|
||||||
for tid, session in self.sessions.items():
|
for tid, session in sessions.items():
|
||||||
session_info[tid] = {
|
session_info[tid] = {
|
||||||
"is_running": session.is_running(),
|
"is_running": session.is_running(),
|
||||||
"working_dir": session.get_working_dir(),
|
"working_dir": session.get_working_dir(),
|
||||||
@@ -107,24 +119,35 @@ class TerminalManager:
|
|||||||
|
|
||||||
return {"sessions": session_info, "total_count": len(session_info)}
|
return {"sessions": session_info, "total_count": len(session_info)}
|
||||||
|
|
||||||
|
def cleanup_agent(self, agent_id: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
sessions = self._sessions_by_agent.pop(agent_id, {})
|
||||||
|
|
||||||
|
for session in sessions.values():
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
session.close()
|
||||||
|
|
||||||
def cleanup_dead_sessions(self) -> None:
|
def cleanup_dead_sessions(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
for sessions in self._sessions_by_agent.values():
|
||||||
dead_sessions: list[str] = []
|
dead_sessions: list[str] = []
|
||||||
for tid, session in self.sessions.items():
|
for tid, session in sessions.items():
|
||||||
if not session.is_running():
|
if not session.is_running():
|
||||||
dead_sessions.append(tid)
|
dead_sessions.append(tid)
|
||||||
|
|
||||||
for tid in dead_sessions:
|
for tid in dead_sessions:
|
||||||
session = self.sessions.pop(tid)
|
session = sessions.pop(tid)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def close_all_sessions(self) -> None:
|
def close_all_sessions(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
sessions_to_close = list(self.sessions.values())
|
all_sessions: list[TerminalSession] = []
|
||||||
self.sessions.clear()
|
for sessions in self._sessions_by_agent.values():
|
||||||
|
all_sessions.extend(sessions.values())
|
||||||
|
self._sessions_by_agent.clear()
|
||||||
|
|
||||||
for session in sessions_to_close:
|
for session in all_sessions:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user