Files
strix/strix/tools/registry.py

307 lines
9.0 KiB
Python

import inspect
import logging
import os
from collections.abc import Callable
from functools import wraps
from inspect import signature
from pathlib import Path
from typing import Any
import defusedxml.ElementTree as DefusedET
from strix.utils.resource_paths import get_strix_resource_path
# Registry entries in registration order. Each dict carries "name", "function",
# "module" and "sandbox_execution"; register_tool also adds "xml_schema" when
# not running in sandbox mode.
tools: list[dict[str, Any]] = []
# Tool name -> the original (unwrapped) callable passed to register_tool.
_tools_by_name: dict[str, Callable[..., Any]] = {}
# Tool name -> parameter schema dict produced by _parse_param_schema.
_tool_param_schemas: dict[str, dict[str, Any]] = {}
logger = logging.getLogger(__name__)
class ImplementedInClientSideOnlyError(Exception):
    """Exception for tools that are implemented on the client side only.

    The text is available both through ``str(exc)`` and the ``message``
    attribute.
    """

    def __init__(
        self,
        message: str = "This tool is implemented in the client side only",
    ) -> None:
        super().__init__(message)
        # Mirror the text on the instance so callers can read ``exc.message``.
        self.message = message
def _process_dynamic_content(content: str) -> str:
if "{{DYNAMIC_SKILLS_DESCRIPTION}}" in content:
try:
from strix.skills import generate_skills_description
skills_description = generate_skills_description()
content = content.replace("{{DYNAMIC_SKILLS_DESCRIPTION}}", skills_description)
except ImportError:
logger.warning("Could not import skills utilities for dynamic schema generation")
content = content.replace(
"{{DYNAMIC_SKILLS_DESCRIPTION}}",
"List of skills to load for this agent (max 5). Skill discovery failed.",
)
return content
def _load_xml_schema(path: Path) -> Any:
if not path.exists():
return None
try:
content = path.read_text(encoding="utf-8")
content = _process_dynamic_content(content)
start_tag = '<tool name="'
end_tag = "</tool>"
tools_dict = {}
pos = 0
while True:
start_pos = content.find(start_tag, pos)
if start_pos == -1:
break
name_start = start_pos + len(start_tag)
name_end = content.find('"', name_start)
if name_end == -1:
break
tool_name = content[name_start:name_end]
end_pos = content.find(end_tag, name_end)
if end_pos == -1:
break
end_pos += len(end_tag)
tool_element = content[start_pos:end_pos]
tools_dict[tool_name] = tool_element
pos = end_pos
if pos >= len(content):
break
except (IndexError, ValueError, UnicodeError) as e:
logger.warning(f"Error loading schema file {path}: {e}")
return None
else:
return tools_dict
def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
params: set[str] = set()
required: set[str] = set()
params_start = tool_xml.find("<parameters>")
params_end = tool_xml.find("</parameters>")
if params_start == -1 or params_end == -1:
return {"params": set(), "required": set(), "has_params": False}
params_section = tool_xml[params_start : params_end + len("</parameters>")]
try:
root = DefusedET.fromstring(params_section)
except DefusedET.ParseError:
return {"params": set(), "required": set(), "has_params": False}
for param in root.findall(".//parameter"):
name = param.attrib.get("name")
if not name:
continue
params.add(name)
if param.attrib.get("required", "false").lower() == "true":
required.add(name)
return {"params": params, "required": required, "has_params": bool(params or required)}
def _get_module_name(func: Callable[..., Any]) -> str:
module = inspect.getmodule(func)
if not module:
return "unknown"
module_name = module.__name__
if ".tools." in module_name:
parts = module_name.split(".tools.")[-1].split(".")
if len(parts) >= 1:
return parts[0]
return "unknown"
def _get_schema_path(func: Callable[..., Any]) -> Path | None:
module = inspect.getmodule(func)
if not module or not module.__name__:
return None
module_name = module.__name__
if ".tools." not in module_name:
return None
parts = module_name.split(".tools.")[-1].split(".")
if len(parts) < 2:
return None
folder = parts[0]
file_stem = parts[1]
schema_file = f"{file_stem}_schema.xml"
return get_strix_resource_path("tools", folder, schema_file)
def _is_sandbox_mode() -> bool:
return os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
def _is_browser_disabled() -> bool:
if os.getenv("STRIX_DISABLE_BROWSER", "").lower() == "true":
return True
from strix.config import Config
val: str = Config.load().get("env", {}).get("STRIX_DISABLE_BROWSER", "")
return str(val).lower() == "true"
def _has_perplexity_api() -> bool:
if os.getenv("PERPLEXITY_API_KEY"):
return True
from strix.config import Config
return bool(Config.load().get("env", {}).get("PERPLEXITY_API_KEY"))
def _should_register_tool(
    *,
    sandbox_execution: bool,
    requires_browser_mode: bool,
    requires_web_search_mode: bool,
) -> bool:
    """Decide whether a tool with the given requirements should be registered.

    A tool is rejected when any of these hold:
    - sandbox mode is active and the tool is not marked for sandbox execution;
    - it requires browser mode and the browser is disabled;
    - it requires web search and no Perplexity API key is available.
    """
    if _is_sandbox_mode() and not sandbox_execution:
        return False
    if requires_browser_mode and _is_browser_disabled():
        return False
    if requires_web_search_mode:
        return _has_perplexity_api()
    return True
def register_tool(
    func: Callable[..., Any] | None = None,
    *,
    sandbox_execution: bool = True,
    requires_browser_mode: bool = False,
    requires_web_search_mode: bool = False,
) -> Callable[..., Any]:
    """Register a function as a tool.

    Works both bare (``@register_tool``) and with keyword options
    (``@register_tool(sandbox_execution=False)``).

    Args:
        func: The function to register when used as a bare decorator;
            ``None`` when called with keyword options only.
        sandbox_execution: Stored on the registry entry and passed to
            ``_should_register_tool``.
        requires_browser_mode: Passed to ``_should_register_tool``.
        requires_web_search_mode: Passed to ``_should_register_tool``.

    Returns:
        A decorator when *func* is ``None``; otherwise the decorated
        function. If ``_should_register_tool`` rejects the tool, the original
        function is returned unchanged and nothing is recorded. Otherwise a
        ``functools.wraps`` wrapper that forwards to the original is returned.
    """

    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        if not _should_register_tool(
            sandbox_execution=sandbox_execution,
            requires_browser_mode=requires_browser_mode,
            requires_web_search_mode=requires_web_search_mode,
        ):
            return f

        tool_name = f.__name__
        func_dict: dict[str, Any] = {
            "name": tool_name,
            "function": f,
            "module": _get_module_name(f),
            "sandbox_execution": sandbox_execution,
        }

        # Schemas are loaded and parsed only outside sandbox mode.
        if not _is_sandbox_mode():
            try:
                schema_path = _get_schema_path(f)
                xml_tools = _load_xml_schema(schema_path) if schema_path else None
                if xml_tools is not None and tool_name in xml_tools:
                    func_dict["xml_schema"] = xml_tools[tool_name]
                else:
                    func_dict["xml_schema"] = (
                        f'<tool name="{tool_name}">'
                        "<description>Schema not found for tool.</description>"
                        "</tool>"
                    )
            except (TypeError, OSError) as e:
                # OSError (not just FileNotFoundError): an unreadable schema
                # file, e.g. PermissionError or IsADirectoryError, must not
                # abort registration and with it the import of the tool module.
                logger.warning(f"Error loading schema for {tool_name}: {e}")
                func_dict["xml_schema"] = (
                    f'<tool name="{tool_name}">'
                    "<description>Error loading schema.</description>"
                    "</tool>"
                )

            xml_schema = func_dict.get("xml_schema")
            _tool_param_schemas[tool_name] = _parse_param_schema(
                xml_schema if isinstance(xml_schema, str) else ""
            )

        tools.append(func_dict)
        # The lookup table keeps the original callable, not the wrapper.
        _tools_by_name[tool_name] = f

        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return f(*args, **kwargs)

        return wrapper

    if func is None:
        return decorator
    return decorator(func)
def get_tool_by_name(name: str) -> Callable[..., Any] | None:
    """Return the callable registered under *name*, or ``None`` if there is none."""
    try:
        return _tools_by_name[name]
    except KeyError:
        return None
def get_tool_names() -> list[str]:
    """Return a new list of all registered tool names."""
    return [*_tools_by_name]
def get_tool_param_schema(name: str) -> dict[str, Any] | None:
    """Return the parameter schema stored for tool *name*, or ``None`` if absent."""
    return _tool_param_schemas[name] if name in _tool_param_schemas else None
def needs_agent_state(tool_name: str) -> bool:
    """Report whether the named tool declares an ``agent_state`` parameter.

    Returns ``False`` when no tool is registered under *tool_name*.
    """
    tool_func = get_tool_by_name(tool_name)
    return bool(tool_func) and "agent_state" in signature(tool_func).parameters
def should_execute_in_sandbox(tool_name: str) -> bool:
    """Return the ``sandbox_execution`` flag of the first entry named *tool_name*.

    Defaults to ``True`` when the tool is not registered or the entry lacks
    the flag.
    """
    entry = next((item for item in tools if item.get("name") == tool_name), None)
    if entry is None:
        return True
    return bool(entry.get("sandbox_execution", True))
def get_tools_prompt() -> str:
    """Render all registered tool schemas as XML grouped by module.

    Each module gets a ``<{module}_tools>`` section, emitted in sorted module
    order and separated by a blank line. Inside a section every non-empty
    ``xml_schema`` is included with each of its lines prefixed by a space.
    Entries without a schema contribute nothing.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for entry in tools:
        grouped.setdefault(entry.get("module", "unknown"), []).append(entry)

    sections = []
    for module in sorted(grouped):
        tag_name = f"{module}_tools"
        schemas = (entry.get("xml_schema", "") for entry in grouped[module])
        indented = [
            "\n".join(f" {line}" for line in schema.split("\n"))
            for schema in schemas
            if schema
        ]
        sections.append("\n".join([f"<{tag_name}>", *indented, f"</{tag_name}>"]))

    return "\n\n".join(sections)
def clear_registry() -> None:
    """Empty the tool list, the name lookup table and the parameter schemas in place."""
    for registry in (tools, _tools_by_name, _tool_param_schemas):
        registry.clear()