Files
strix/strix/runtime/tool_server.py
2025-08-08 20:36:44 -07:00

98 lines
3.4 KiB
Python

import logging
import os
import secrets
from typing import Any

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, ValidationError
# The tool server refuses to start unless STRIX_SANDBOX_MODE is "true"
# (case-insensitive); any other value, or an unset variable, aborts import.
SANDBOX_MODE = os.environ.get("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not SANDBOX_MODE:
    raise RuntimeError("Tool server should only run in sandbox mode (STRIX_SANDBOX_MODE=true)")

# Shared secret that every /execute caller must present as its Bearer token.
# An unset or empty value aborts import.
EXPECTED_TOKEN = os.environ.get("STRIX_SANDBOX_TOKEN")
if EXPECTED_TOKEN is None or EXPECTED_TOKEN == "":
    raise RuntimeError("STRIX_SANDBOX_TOKEN environment variable is required in sandbox mode")

logger = logging.getLogger(__name__)
app = FastAPI()

# Bearer-token extractor. The Depends() wrapper is built once at module level
# and reused as a parameter default (presumably to satisfy ruff B008).
security = HTTPBearer()
security_dependency = Depends(security)
def verify_token(credentials: HTTPAuthorizationCredentials) -> str:
    """Check a Bearer credential against ``EXPECTED_TOKEN``.

    Args:
        credentials: Parsed ``Authorization`` header, as produced by the
            ``HTTPBearer`` dependency.

    Returns:
        The raw token string when it matches ``EXPECTED_TOKEN``.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the scheme
            is not Bearer, no token is configured, or the token does not match.
    """
    # Auth scheme names are case-insensitive (RFC 7235 / RFC 6750), and
    # HTTPBearer itself already accepts "bearer"; compare case-insensitively
    # so such clients are not spuriously rejected here.
    if not credentials or credentials.scheme.lower() != "bearer":
        logger.warning("Authentication failed: Invalid or missing Bearer token scheme")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Bearer token required.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    expected = EXPECTED_TOKEN
    # Constant-time comparison: a plain `!=` short-circuits on the first
    # differing byte, letting response timing leak the token prefix. Compare
    # bytes because compare_digest rejects non-ASCII str arguments. A missing
    # or empty configured token never authenticates anyone.
    if not expected or not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), expected.encode("utf-8")
    ):
        logger.warning("Authentication failed: Invalid token provided from remote host")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    logger.debug("Authentication successful for tool execution request")
    return credentials.credentials
class ToolExecutionRequest(BaseModel):
    # Request body for POST /execute.
    # Name looked up via get_tool_by_name(); required (no default).
    tool_name: str
    # Keyword arguments for the tool. They are passed through
    # convert_arguments() before the call. Required (no default).
    kwargs: dict[str, Any]
class ToolExecutionResponse(BaseModel):
    # Response body for POST /execute. execute_tool populates either `result`
    # (success) or `error` (failure), never both.
    # Value returned by the tool function; may legitimately be None.
    result: Any | None = None
    # Human-readable failure description; None on success.
    error: str | None = None
@app.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
    request: ToolExecutionRequest, credentials: HTTPAuthorizationCredentials = security_dependency
) -> ToolExecutionResponse:
    """Authenticate the caller, then look up and run the requested tool.

    Only authentication failures surface as HTTP errors (401, raised by
    ``verify_token``). Every other failure is reported in-band as a
    ``ToolExecutionResponse`` with ``error`` set.

    NOTE(review): the tool function is called synchronously inside this
    ``async def``, so a slow tool blocks the event loop, and with it every
    other request including /health, until it returns. Offloading it (e.g.
    ``asyncio.to_thread``) would change which thread tools run on; confirm
    the tools tolerate that before changing it.
    """
    verify_token(credentials)

    # Imported inside the handler, after authentication has succeeded.
    from strix.tools.argument_parser import ArgumentConversionError, convert_arguments
    from strix.tools.registry import get_tool_by_name

    name = request.tool_name
    try:
        tool_func = get_tool_by_name(name)
        if not tool_func:
            return ToolExecutionResponse(error=f"Tool '{name}' not found")
        call_kwargs = convert_arguments(tool_func, request.kwargs)
        return ToolExecutionResponse(result=tool_func(**call_kwargs))
    except (ArgumentConversionError, ValidationError) as exc:
        logger.warning("Invalid tool arguments: %s", exc)
        return ToolExecutionResponse(error=f"Invalid arguments: {exc}")
    # The two handlers below also catch TypeError/ValueError raised from
    # inside the tool body, not only call-signature mismatches.
    except TypeError as exc:
        logger.warning("Tool execution type error: %s", exc)
        return ToolExecutionResponse(error=f"Tool execution error: {exc}")
    except ValueError as exc:
        logger.warning("Tool execution value error: %s", exc)
        return ToolExecutionResponse(error=f"Tool execution error: {exc}")
    except Exception:
        # Details go to the server log only; the client gets a generic message.
        logger.exception("Unexpected error during tool execution")
        return ToolExecutionResponse(error="Internal server error")
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Report liveness plus configuration flags, every value as a string.

    Module import aborts unless SANDBOX_MODE is true and EXPECTED_TOKEN is
    set. In practice this therefore always reports sandbox mode with
    authentication configured.
    """
    return dict(
        status="healthy",
        sandbox_mode=str(SANDBOX_MODE),
        environment="sandbox" if SANDBOX_MODE else "main",
        auth_configured="true" if EXPECTED_TOKEN else "false",
    )