Running all agents under same container (#12)
This commit is contained in:
@@ -7,19 +7,15 @@ from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import docker
|
||||
from docker.errors import DockerException, NotFound
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from docker.models.containers import Container
|
||||
|
||||
from .runtime import AbstractRuntime, SandboxInfo
|
||||
|
||||
|
||||
STRIX_AGENT_LABEL = "StrixAgent_ID"
|
||||
STRIX_SCAN_LABEL = "StrixScan_ID"
|
||||
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.4")
|
||||
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_initialized_volumes: set[str] = set()
|
||||
|
||||
|
||||
class DockerRuntime(AbstractRuntime):
|
||||
def __init__(self) -> None:
|
||||
@@ -29,9 +25,18 @@ class DockerRuntime(AbstractRuntime):
|
||||
logger.exception("Failed to connect to Docker daemon")
|
||||
raise RuntimeError("Docker is not available or not configured correctly.") from e
|
||||
|
||||
self._scan_container: Container | None = None
|
||||
self._tool_server_port: int | None = None
|
||||
self._tool_server_token: str | None = None
|
||||
|
||||
def _generate_sandbox_token(self) -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def _find_available_port(self) -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return cast("int", s.getsockname()[1])
|
||||
|
||||
def _get_scan_id(self, agent_id: str) -> str:
|
||||
try:
|
||||
from strix.cli.tracer import get_global_tracer
|
||||
@@ -46,37 +51,151 @@ class DockerRuntime(AbstractRuntime):
|
||||
|
||||
return f"scan-{agent_id.split('-')[0]}"
|
||||
|
||||
def _find_available_port(self) -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return cast("int", s.getsockname()[1])
|
||||
def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None:
|
||||
def _validate_image(image: docker.models.images.Image) -> None:
|
||||
if not image.id or not image.attrs:
|
||||
raise ImageNotFound(f"Image {image_name} metadata incomplete")
|
||||
|
||||
def _get_workspace_volume_name(self, scan_id: str) -> str:
|
||||
return f"strix-workspace-{scan_id}"
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
image = self.client.images.get(image_name)
|
||||
_validate_image(image)
|
||||
except ImageNotFound:
|
||||
if attempt == max_retries - 1:
|
||||
logger.exception(f"Image {image_name} not found after {max_retries} attempts")
|
||||
raise
|
||||
logger.warning(f"Image {image_name} not ready, attempt {attempt + 1}/{max_retries}")
|
||||
time.sleep(2**attempt)
|
||||
except DockerException:
|
||||
if attempt == max_retries - 1:
|
||||
logger.exception(f"Failed to verify image {image_name}")
|
||||
raise
|
||||
logger.warning(f"Docker error verifying image, attempt {attempt + 1}/{max_retries}")
|
||||
time.sleep(2**attempt)
|
||||
else:
|
||||
logger.debug(f"Image {image_name} verified as available")
|
||||
return
|
||||
|
||||
def _get_sandbox_by_agent_id(self, agent_id: str) -> Container | None:
|
||||
try:
|
||||
containers = self.client.containers.list(
|
||||
filters={"label": f"{STRIX_AGENT_LABEL}={agent_id}"}
|
||||
)
|
||||
if not containers:
|
||||
return None
|
||||
if len(containers) > 1:
|
||||
logger.warning(
|
||||
"Multiple sandboxes found for agent ID %s, using the first one.", agent_id
|
||||
def _create_container_with_retry(self, scan_id: str, max_retries: int = 3) -> Container:
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
self._verify_image_available(STRIX_IMAGE)
|
||||
|
||||
caido_port = self._find_available_port()
|
||||
tool_server_port = self._find_available_port()
|
||||
tool_server_token = self._generate_sandbox_token()
|
||||
|
||||
self._tool_server_port = tool_server_port
|
||||
self._tool_server_token = tool_server_token
|
||||
|
||||
container = self.client.containers.run(
|
||||
STRIX_IMAGE,
|
||||
command="sleep infinity",
|
||||
detach=True,
|
||||
name=f"strix-scan-{scan_id}",
|
||||
hostname=f"strix-scan-{scan_id}",
|
||||
ports={
|
||||
f"{caido_port}/tcp": caido_port,
|
||||
f"{tool_server_port}/tcp": tool_server_port,
|
||||
},
|
||||
cap_add=["NET_ADMIN", "NET_RAW"],
|
||||
labels={"strix-scan-id": scan_id},
|
||||
environment={
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"CAIDO_PORT": str(caido_port),
|
||||
"TOOL_SERVER_PORT": str(tool_server_port),
|
||||
"TOOL_SERVER_TOKEN": tool_server_token,
|
||||
},
|
||||
tty=True,
|
||||
)
|
||||
return cast("Container", containers[0])
|
||||
except DockerException as e:
|
||||
logger.warning("Failed to get sandbox by agent ID %s: %s", agent_id, e)
|
||||
return None
|
||||
|
||||
def _ensure_workspace_volume(self, volume_name: str) -> None:
|
||||
self._scan_container = container
|
||||
logger.info("Created container %s for scan %s", container.id, scan_id)
|
||||
|
||||
self._initialize_container(
|
||||
container, caido_port, tool_server_port, tool_server_token
|
||||
)
|
||||
except DockerException as e:
|
||||
last_exception = e
|
||||
if attempt == max_retries - 1:
|
||||
logger.exception(f"Failed to create container after {max_retries} attempts")
|
||||
break
|
||||
|
||||
logger.warning(f"Container creation attempt {attempt + 1}/{max_retries} failed")
|
||||
|
||||
self._tool_server_port = None
|
||||
self._tool_server_token = None
|
||||
|
||||
sleep_time = (2**attempt) + (0.1 * attempt)
|
||||
time.sleep(sleep_time)
|
||||
else:
|
||||
return container
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to create Docker container after {max_retries} attempts: {last_exception}"
|
||||
) from last_exception
|
||||
|
||||
def _get_or_create_scan_container(self, scan_id: str) -> Container:
|
||||
if self._scan_container:
|
||||
try:
|
||||
self._scan_container.reload()
|
||||
if self._scan_container.status == "running":
|
||||
return self._scan_container
|
||||
except NotFound:
|
||||
self._scan_container = None
|
||||
self._tool_server_port = None
|
||||
self._tool_server_token = None
|
||||
|
||||
try:
|
||||
self.client.volumes.get(volume_name)
|
||||
logger.info(f"Using existing workspace volume: {volume_name}")
|
||||
except NotFound:
|
||||
self.client.volumes.create(name=volume_name, driver="local")
|
||||
logger.info(f"Created new workspace volume: {volume_name}")
|
||||
containers = self.client.containers.list(filters={"label": f"strix-scan-id={scan_id}"})
|
||||
if containers:
|
||||
container = cast("Container", containers[0])
|
||||
if container.status != "running":
|
||||
container.start()
|
||||
time.sleep(2)
|
||||
self._scan_container = container
|
||||
|
||||
for env_var in container.attrs["Config"]["Env"]:
|
||||
if env_var.startswith("TOOL_SERVER_PORT="):
|
||||
self._tool_server_port = int(env_var.split("=")[1])
|
||||
elif env_var.startswith("TOOL_SERVER_TOKEN="):
|
||||
self._tool_server_token = env_var.split("=")[1]
|
||||
|
||||
return container
|
||||
except DockerException as e:
|
||||
logger.warning("Failed to find existing container for scan %s: %s", scan_id, e)
|
||||
|
||||
logger.info("Creating new Docker container for scan %s", scan_id)
|
||||
return self._create_container_with_retry(scan_id)
|
||||
|
||||
def _initialize_container(
|
||||
self, container: Container, caido_port: int, tool_server_port: int, tool_server_token: str
|
||||
) -> None:
|
||||
logger.info("Initializing Caido proxy on port %s", caido_port)
|
||||
result = container.exec_run(
|
||||
f"bash -c 'export CAIDO_PORT={caido_port} && /usr/local/bin/docker-entrypoint.sh true'",
|
||||
detach=False,
|
||||
)
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
result = container.exec_run(
|
||||
"bash -c 'source /etc/profile.d/proxy.sh && echo $CAIDO_API_TOKEN'", user="pentester"
|
||||
)
|
||||
caido_token = result.output.decode().strip() if result.exit_code == 0 else ""
|
||||
|
||||
container.exec_run(
|
||||
f"bash -c 'source /etc/profile.d/proxy.sh && cd /app && "
|
||||
f"STRIX_SANDBOX_MODE=true CAIDO_API_TOKEN={caido_token} CAIDO_PORT={caido_port} "
|
||||
f"poetry run python strix/runtime/tool_server.py --token {tool_server_token} "
|
||||
f"--host 0.0.0.0 --port {tool_server_port} &'",
|
||||
detach=True,
|
||||
user="pentester",
|
||||
)
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
def _copy_local_directory_to_container(self, container: Container, local_path: str) -> None:
|
||||
import tarfile
|
||||
@@ -85,10 +204,10 @@ class DockerRuntime(AbstractRuntime):
|
||||
try:
|
||||
local_path_obj = Path(local_path).resolve()
|
||||
if not local_path_obj.exists() or not local_path_obj.is_dir():
|
||||
logger.warning(f"Local path does not exist or is not a directory: {local_path_obj}")
|
||||
logger.warning(f"Local path does not exist or is not directory: {local_path_obj}")
|
||||
return
|
||||
|
||||
logger.info(f"Copying local directory {local_path_obj} to container {container.id}")
|
||||
logger.info(f"Copying local directory {local_path_obj} to container")
|
||||
|
||||
tar_buffer = BytesIO()
|
||||
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
|
||||
@@ -98,18 +217,14 @@ class DockerRuntime(AbstractRuntime):
|
||||
tar.add(item, arcname=arcname)
|
||||
|
||||
tar_buffer.seek(0)
|
||||
|
||||
container.put_archive("/shared_workspace", tar_buffer.getvalue())
|
||||
container.put_archive("/workspace", tar_buffer.getvalue())
|
||||
|
||||
container.exec_run(
|
||||
"chown -R pentester:pentester /shared_workspace && chmod -R 755 /shared_workspace",
|
||||
"chown -R pentester:pentester /workspace && chmod -R 755 /workspace",
|
||||
user="root",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully copied {local_path_obj} to /shared_workspace in container "
|
||||
f"{container.id}"
|
||||
)
|
||||
logger.info("Successfully copied local directory to /workspace")
|
||||
|
||||
except (OSError, DockerException):
|
||||
logger.exception("Failed to copy local directory to container")
|
||||
@@ -117,94 +232,56 @@ class DockerRuntime(AbstractRuntime):
|
||||
async def create_sandbox(
|
||||
self, agent_id: str, existing_token: str | None = None, local_source_path: str | None = None
|
||||
) -> SandboxInfo:
|
||||
sandbox = self._get_sandbox_by_agent_id(agent_id)
|
||||
auth_token = existing_token or self._generate_sandbox_token()
|
||||
|
||||
scan_id = self._get_scan_id(agent_id)
|
||||
volume_name = self._get_workspace_volume_name(scan_id)
|
||||
container = self._get_or_create_scan_container(scan_id)
|
||||
|
||||
self._ensure_workspace_volume(volume_name)
|
||||
source_copied_key = f"_source_copied_{scan_id}"
|
||||
if local_source_path and not hasattr(self, source_copied_key):
|
||||
self._copy_local_directory_to_container(container, local_source_path)
|
||||
setattr(self, source_copied_key, True)
|
||||
|
||||
if not sandbox:
|
||||
logger.info("Creating new Docker sandbox for agent %s", agent_id)
|
||||
try:
|
||||
tool_server_port = self._find_available_port()
|
||||
caido_port = self._find_available_port()
|
||||
|
||||
volumes_config = {volume_name: {"bind": "/shared_workspace", "mode": "rw"}}
|
||||
container_name = f"strix-{agent_id}"
|
||||
|
||||
sandbox = self.client.containers.run(
|
||||
STRIX_IMAGE,
|
||||
command="sleep infinity",
|
||||
detach=True,
|
||||
name=container_name,
|
||||
hostname=container_name,
|
||||
ports={
|
||||
f"{tool_server_port}/tcp": tool_server_port,
|
||||
f"{caido_port}/tcp": caido_port,
|
||||
},
|
||||
cap_add=["NET_ADMIN", "NET_RAW"],
|
||||
labels={
|
||||
STRIX_AGENT_LABEL: agent_id,
|
||||
STRIX_SCAN_LABEL: scan_id,
|
||||
},
|
||||
environment={
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"STRIX_AGENT_ID": agent_id,
|
||||
"STRIX_SANDBOX_TOKEN": auth_token,
|
||||
"STRIX_TOOL_SERVER_PORT": str(tool_server_port),
|
||||
"CAIDO_PORT": str(caido_port),
|
||||
},
|
||||
volumes=volumes_config,
|
||||
tty=True,
|
||||
)
|
||||
logger.info(
|
||||
"Created new sandbox %s for agent %s with shared workspace %s",
|
||||
sandbox.id,
|
||||
agent_id,
|
||||
volume_name,
|
||||
)
|
||||
except DockerException as e:
|
||||
raise RuntimeError(f"Failed to create Docker sandbox: {e}") from e
|
||||
|
||||
assert sandbox is not None
|
||||
if sandbox.status != "running":
|
||||
sandbox.start()
|
||||
time.sleep(15)
|
||||
|
||||
if local_source_path and volume_name not in _initialized_volumes:
|
||||
self._copy_local_directory_to_container(sandbox, local_source_path)
|
||||
_initialized_volumes.add(volume_name)
|
||||
|
||||
sandbox_id = sandbox.id
|
||||
if sandbox_id is None:
|
||||
container_id = container.id
|
||||
if container_id is None:
|
||||
raise RuntimeError("Docker container ID is unexpectedly None")
|
||||
|
||||
tool_server_port_str = sandbox.attrs["Config"]["Env"][
|
||||
next(
|
||||
(
|
||||
i
|
||||
for i, s in enumerate(sandbox.attrs["Config"]["Env"])
|
||||
if s.startswith("STRIX_TOOL_SERVER_PORT=")
|
||||
),
|
||||
-1,
|
||||
)
|
||||
].split("=")[1]
|
||||
tool_server_port = int(tool_server_port_str)
|
||||
token = existing_token if existing_token is not None else self._tool_server_token
|
||||
|
||||
api_url = await self.get_sandbox_url(sandbox_id, tool_server_port)
|
||||
if self._tool_server_port is None or token is None:
|
||||
raise RuntimeError("Tool server not initialized or no token available")
|
||||
|
||||
api_url = await self.get_sandbox_url(container_id, self._tool_server_port)
|
||||
|
||||
await self._register_agent_with_tool_server(api_url, agent_id, token)
|
||||
|
||||
return {
|
||||
"workspace_id": sandbox_id,
|
||||
"workspace_id": container_id,
|
||||
"api_url": api_url,
|
||||
"auth_token": auth_token,
|
||||
"tool_server_port": tool_server_port,
|
||||
"auth_token": token,
|
||||
"tool_server_port": self._tool_server_port,
|
||||
"agent_id": agent_id,
|
||||
}
|
||||
|
||||
async def get_sandbox_url(self, sandbox_id: str, port: int) -> str:
|
||||
async def _register_agent_with_tool_server(
|
||||
self, api_url: str, agent_id: str, token: str
|
||||
) -> None:
|
||||
import httpx
|
||||
|
||||
try:
|
||||
container = self.client.containers.get(sandbox_id)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{api_url}/register_agent",
|
||||
params={"agent_id": agent_id},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(f"Registered agent {agent_id} with tool server")
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||
logger.warning(f"Failed to register agent {agent_id}: {e}")
|
||||
|
||||
async def get_sandbox_url(self, container_id: str, port: int) -> str:
|
||||
try:
|
||||
container = self.client.containers.get(container_id)
|
||||
container.reload()
|
||||
|
||||
host = "localhost"
|
||||
@@ -214,58 +291,25 @@ class DockerRuntime(AbstractRuntime):
|
||||
host = docker_host.split("://")[1].split(":")[0]
|
||||
|
||||
except NotFound:
|
||||
raise ValueError(f"Sandbox {sandbox_id} not found.") from None
|
||||
raise ValueError(f"Container {container_id} not found.") from None
|
||||
except DockerException as e:
|
||||
raise RuntimeError(f"Failed to get sandbox URL for {sandbox_id}: {e}") from e
|
||||
raise RuntimeError(f"Failed to get container URL for {container_id}: {e}") from e
|
||||
else:
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
async def destroy_sandbox(self, sandbox_id: str) -> None:
|
||||
logger.info("Destroying Docker sandbox %s", sandbox_id)
|
||||
async def destroy_sandbox(self, container_id: str) -> None:
|
||||
logger.info("Destroying scan container %s", container_id)
|
||||
try:
|
||||
container = self.client.containers.get(sandbox_id)
|
||||
|
||||
scan_id = None
|
||||
if container.labels and STRIX_SCAN_LABEL in container.labels:
|
||||
scan_id = container.labels[STRIX_SCAN_LABEL]
|
||||
|
||||
container = self.client.containers.get(container_id)
|
||||
container.stop()
|
||||
container.remove()
|
||||
logger.info("Successfully destroyed sandbox %s", sandbox_id)
|
||||
logger.info("Successfully destroyed container %s", container_id)
|
||||
|
||||
if scan_id:
|
||||
await self._cleanup_workspace_if_empty(scan_id)
|
||||
self._scan_container = None
|
||||
self._tool_server_port = None
|
||||
self._tool_server_token = None
|
||||
|
||||
except NotFound:
|
||||
logger.warning("Sandbox %s not found for destruction.", sandbox_id)
|
||||
logger.warning("Container %s not found for destruction.", container_id)
|
||||
except DockerException as e:
|
||||
logger.warning("Failed to destroy sandbox %s: %s", sandbox_id, e)
|
||||
|
||||
async def _cleanup_workspace_if_empty(self, scan_id: str) -> None:
|
||||
try:
|
||||
volume_name = self._get_workspace_volume_name(scan_id)
|
||||
|
||||
containers = self.client.containers.list(
|
||||
all=True, filters={"label": f"{STRIX_SCAN_LABEL}={scan_id}"}
|
||||
)
|
||||
|
||||
if not containers:
|
||||
try:
|
||||
volume = self.client.volumes.get(volume_name)
|
||||
volume.remove()
|
||||
logger.info(
|
||||
f"Cleaned up workspace volume {volume_name} for completed scan {scan_id}"
|
||||
)
|
||||
|
||||
_initialized_volumes.discard(volume_name)
|
||||
|
||||
except NotFound:
|
||||
logger.debug(f"Volume {volume_name} already removed")
|
||||
except DockerException as e:
|
||||
logger.warning(f"Failed to remove volume {volume_name}: {e}")
|
||||
|
||||
except DockerException as e:
|
||||
logger.warning("Error during workspace cleanup for scan %s: %s", scan_id, e)
|
||||
|
||||
async def cleanup_scan_workspace(self, scan_id: str) -> None:
|
||||
await self._cleanup_workspace_if_empty(scan_id)
|
||||
logger.warning("Failed to destroy container %s: %s", container_id, e)
|
||||
|
||||
@@ -7,6 +7,7 @@ class SandboxInfo(TypedDict):
|
||||
api_url: str
|
||||
auth_token: str | None
|
||||
tool_server_port: int
|
||||
agent_id: str
|
||||
|
||||
|
||||
class AbstractRuntime(ABC):
|
||||
@@ -17,9 +18,9 @@ class AbstractRuntime(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_sandbox_url(self, sandbox_id: str, port: int) -> str:
|
||||
async def get_sandbox_url(self, container_id: str, port: int) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def destroy_sandbox(self, sandbox_id: str) -> None:
|
||||
async def destroy_sandbox(self, container_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from multiprocessing import Process, Queue
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel, ValidationError
|
||||
@@ -11,20 +19,25 @@ SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
||||
if not SANDBOX_MODE:
|
||||
raise RuntimeError("Tool server should only run in sandbox mode (STRIX_SANDBOX_MODE=true)")
|
||||
|
||||
EXPECTED_TOKEN = os.getenv("STRIX_SANDBOX_TOKEN")
|
||||
if not EXPECTED_TOKEN:
|
||||
raise RuntimeError("STRIX_SANDBOX_TOKEN environment variable is required in sandbox mode")
|
||||
parser = argparse.ArgumentParser(description="Start Strix tool server")
|
||||
parser.add_argument("--token", required=True, help="Authentication token")
|
||||
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") # nosec
|
||||
parser.add_argument("--port", type=int, required=True, help="Port to bind to")
|
||||
|
||||
args = parser.parse_args()
|
||||
EXPECTED_TOKEN = args.token
|
||||
|
||||
app = FastAPI()
|
||||
logger = logging.getLogger(__name__)
|
||||
security = HTTPBearer()
|
||||
|
||||
security_dependency = Depends(security)
|
||||
|
||||
agent_processes: dict[str, dict[str, Any]] = {}
|
||||
agent_queues: dict[str, dict[str, Queue[Any]]] = {}
|
||||
|
||||
|
||||
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
|
||||
if not credentials or credentials.scheme != "Bearer":
|
||||
logger.warning("Authentication failed: Invalid or missing Bearer token scheme")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication scheme. Bearer token required.",
|
||||
@@ -32,18 +45,17 @@ def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
|
||||
)
|
||||
|
||||
if credentials.credentials != EXPECTED_TOKEN:
|
||||
logger.warning("Authentication failed: Invalid token provided from remote host")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
logger.debug("Authentication successful for tool execution request")
|
||||
return credentials.credentials
|
||||
|
||||
|
||||
class ToolExecutionRequest(BaseModel):
|
||||
agent_id: str
|
||||
tool_name: str
|
||||
kwargs: dict[str, Any]
|
||||
|
||||
@@ -53,45 +65,141 @@ class ToolExecutionResponse(BaseModel):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def agent_worker(_agent_id: str, request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
|
||||
null_handler = logging.NullHandler()
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.handlers = [null_handler]
|
||||
root_logger.setLevel(logging.CRITICAL)
|
||||
|
||||
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
|
||||
from strix.tools.registry import get_tool_by_name
|
||||
|
||||
while True:
|
||||
try:
|
||||
request = request_queue.get()
|
||||
|
||||
if request is None:
|
||||
break
|
||||
|
||||
tool_name = request["tool_name"]
|
||||
kwargs = request["kwargs"]
|
||||
|
||||
try:
|
||||
tool_func = get_tool_by_name(tool_name)
|
||||
if not tool_func:
|
||||
response_queue.put({"error": f"Tool '{tool_name}' not found"})
|
||||
continue
|
||||
|
||||
converted_kwargs = convert_arguments(tool_func, kwargs)
|
||||
result = tool_func(**converted_kwargs)
|
||||
|
||||
response_queue.put({"result": result})
|
||||
|
||||
except (ArgumentConversionError, ValidationError) as e:
|
||||
response_queue.put({"error": f"Invalid arguments: {e}"})
|
||||
except (RuntimeError, ValueError, ImportError) as e:
|
||||
response_queue.put({"error": f"Tool execution error: {e}"})
|
||||
|
||||
except (RuntimeError, ValueError, ImportError) as e:
|
||||
response_queue.put({"error": f"Worker error: {e}"})
|
||||
|
||||
|
||||
def ensure_agent_process(agent_id: str) -> tuple[Queue[Any], Queue[Any]]:
|
||||
if agent_id not in agent_processes:
|
||||
request_queue: Queue[Any] = Queue()
|
||||
response_queue: Queue[Any] = Queue()
|
||||
|
||||
process = Process(
|
||||
target=agent_worker, args=(agent_id, request_queue, response_queue), daemon=True
|
||||
)
|
||||
process.start()
|
||||
|
||||
agent_processes[agent_id] = {"process": process, "pid": process.pid}
|
||||
agent_queues[agent_id] = {"request": request_queue, "response": response_queue}
|
||||
|
||||
return agent_queues[agent_id]["request"], agent_queues[agent_id]["response"]
|
||||
|
||||
|
||||
@app.post("/execute", response_model=ToolExecutionResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecutionRequest, credentials: HTTPAuthorizationCredentials = security_dependency
|
||||
) -> ToolExecutionResponse:
|
||||
verify_token(credentials)
|
||||
|
||||
from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
|
||||
from strix.tools.registry import get_tool_by_name
|
||||
request_queue, response_queue = ensure_agent_process(request.agent_id)
|
||||
|
||||
request_queue.put({"tool_name": request.tool_name, "kwargs": request.kwargs})
|
||||
|
||||
try:
|
||||
tool_func = get_tool_by_name(request.tool_name)
|
||||
if not tool_func:
|
||||
return ToolExecutionResponse(error=f"Tool '{request.tool_name}' not found")
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, response_queue.get)
|
||||
|
||||
converted_kwargs = convert_arguments(tool_func, request.kwargs)
|
||||
if "error" in response:
|
||||
return ToolExecutionResponse(error=response["error"])
|
||||
return ToolExecutionResponse(result=response.get("result"))
|
||||
|
||||
result = tool_func(**converted_kwargs)
|
||||
except (RuntimeError, ValueError, OSError) as e:
|
||||
return ToolExecutionResponse(error=f"Worker error: {e}")
|
||||
|
||||
return ToolExecutionResponse(result=result)
|
||||
|
||||
except (ArgumentConversionError, ValidationError) as e:
|
||||
logger.warning("Invalid tool arguments: %s", e)
|
||||
return ToolExecutionResponse(error=f"Invalid arguments: {e}")
|
||||
except TypeError as e:
|
||||
logger.warning("Tool execution type error: %s", e)
|
||||
return ToolExecutionResponse(error=f"Tool execution error: {e}")
|
||||
except ValueError as e:
|
||||
logger.warning("Tool execution value error: %s", e)
|
||||
return ToolExecutionResponse(error=f"Tool execution error: {e}")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during tool execution")
|
||||
return ToolExecutionResponse(error="Internal server error")
|
||||
@app.post("/register_agent")
|
||||
async def register_agent(
|
||||
agent_id: str, credentials: HTTPAuthorizationCredentials = security_dependency
|
||||
) -> dict[str, str]:
|
||||
verify_token(credentials)
|
||||
|
||||
ensure_agent_process(agent_id)
|
||||
return {"status": "registered", "agent_id": agent_id}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
async def health_check() -> dict[str, Any]:
|
||||
return {
|
||||
"status": "healthy",
|
||||
"sandbox_mode": str(SANDBOX_MODE),
|
||||
"environment": "sandbox" if SANDBOX_MODE else "main",
|
||||
"auth_configured": "true" if EXPECTED_TOKEN else "false",
|
||||
"active_agents": len(agent_processes),
|
||||
"agents": list(agent_processes.keys()),
|
||||
}
|
||||
|
||||
|
||||
def cleanup_all_agents() -> None:
|
||||
for agent_id in list(agent_processes.keys()):
|
||||
try:
|
||||
agent_queues[agent_id]["request"].put(None)
|
||||
process = agent_processes[agent_id]["process"]
|
||||
|
||||
process.join(timeout=1)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
process.join(timeout=1)
|
||||
|
||||
if process.is_alive():
|
||||
process.kill()
|
||||
|
||||
except (BrokenPipeError, EOFError, OSError):
|
||||
pass
|
||||
except (RuntimeError, ValueError) as e:
|
||||
logging.getLogger(__name__).debug(f"Error during agent cleanup: {e}")
|
||||
|
||||
|
||||
def signal_handler(_signum: int, _frame: Any) -> None:
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN) if hasattr(signal, "SIGPIPE") else None
|
||||
cleanup_all_agents()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if hasattr(signal, "SIGPIPE"):
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
finally:
|
||||
cleanup_all_agents()
|
||||
|
||||
Reference in New Issue
Block a user