fix(agent): fix tool schemas not retrieved on pyinstaller binary and validate tool call args
This commit is contained in:
@@ -14,6 +14,7 @@ from .argument_parser import convert_arguments
|
||||
from .registry import (
|
||||
get_tool_by_name,
|
||||
get_tool_names,
|
||||
get_tool_param_schema,
|
||||
needs_agent_state,
|
||||
should_execute_in_sandbox,
|
||||
)
|
||||
@@ -110,14 +111,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwarg
|
||||
|
||||
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
|
||||
if tool_name is None:
|
||||
return False, "Tool name is missing"
|
||||
available = ", ".join(sorted(get_tool_names()))
|
||||
return False, f"Tool name is missing. Available tools: {available}"
|
||||
|
||||
if tool_name not in get_tool_names():
|
||||
return False, f"Tool '{tool_name}' is not available"
|
||||
available = ", ".join(sorted(get_tool_names()))
|
||||
return False, f"Tool '{tool_name}' is not available. Available tools: {available}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None:
|
||||
param_schema = get_tool_param_schema(tool_name)
|
||||
if not param_schema or not param_schema.get("has_params"):
|
||||
return None
|
||||
|
||||
allowed_params: set[str] = param_schema.get("params", set())
|
||||
required_params: set[str] = param_schema.get("required", set())
|
||||
optional_params = allowed_params - required_params
|
||||
|
||||
schema_hint = _format_schema_hint(tool_name, required_params, optional_params)
|
||||
|
||||
unknown_params = set(kwargs.keys()) - allowed_params
|
||||
if unknown_params:
|
||||
unknown_list = ", ".join(sorted(unknown_params))
|
||||
return f"Tool '{tool_name}' received unknown parameter(s): {unknown_list}\n{schema_hint}"
|
||||
|
||||
missing_required = [
|
||||
param for param in required_params if param not in kwargs or kwargs.get(param) in (None, "")
|
||||
]
|
||||
if missing_required:
|
||||
missing_list = ", ".join(sorted(missing_required))
|
||||
return f"Tool '{tool_name}' missing required parameter(s): {missing_list}\n{schema_hint}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str:
|
||||
parts = [f"Valid parameters for '{tool_name}':"]
|
||||
if required:
|
||||
parts.append(f" Required: {', '.join(sorted(required))}")
|
||||
if optional:
|
||||
parts.append(f" Optional: {', '.join(sorted(optional))}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
async def execute_tool_with_validation(
|
||||
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
@@ -127,6 +165,10 @@ async def execute_tool_with_validation(
|
||||
|
||||
assert tool_name is not None
|
||||
|
||||
arg_error = _validate_tool_arguments(tool_name, kwargs)
|
||||
if arg_error:
|
||||
return f"Error: {arg_error}"
|
||||
|
||||
try:
|
||||
result = await execute_tool(tool_name, agent_state, **kwargs)
|
||||
except Exception as e: # noqa: BLE001
|
||||
|
||||
@@ -7,9 +7,14 @@ from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import defusedxml.ElementTree as DefusedET
|
||||
|
||||
from strix.utils.resource_paths import get_strix_resource_path
|
||||
|
||||
|
||||
tools: list[dict[str, Any]] = []
|
||||
_tools_by_name: dict[str, Callable[..., Any]] = {}
|
||||
_tool_param_schemas: dict[str, dict[str, Any]] = {}
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any:
|
||||
return tools_dict
|
||||
|
||||
|
||||
def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
|
||||
params: set[str] = set()
|
||||
required: set[str] = set()
|
||||
|
||||
params_start = tool_xml.find("<parameters>")
|
||||
params_end = tool_xml.find("</parameters>")
|
||||
|
||||
if params_start == -1 or params_end == -1:
|
||||
return {"params": set(), "required": set(), "has_params": False}
|
||||
|
||||
params_section = tool_xml[params_start : params_end + len("</parameters>")]
|
||||
|
||||
try:
|
||||
root = DefusedET.fromstring(params_section)
|
||||
except DefusedET.ParseError:
|
||||
return {"params": set(), "required": set(), "has_params": False}
|
||||
|
||||
for param in root.findall(".//parameter"):
|
||||
name = param.attrib.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params.add(name)
|
||||
if param.attrib.get("required", "false").lower() == "true":
|
||||
required.add(name)
|
||||
|
||||
return {"params": params, "required": required, "has_params": bool(params or required)}
|
||||
|
||||
|
||||
def _get_module_name(func: Callable[..., Any]) -> str:
|
||||
module = inspect.getmodule(func)
|
||||
if not module:
|
||||
@@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _get_schema_path(func: Callable[..., Any]) -> Path | None:
|
||||
module = inspect.getmodule(func)
|
||||
if not module or not module.__name__:
|
||||
return None
|
||||
|
||||
module_name = module.__name__
|
||||
|
||||
if ".tools." not in module_name:
|
||||
return None
|
||||
|
||||
parts = module_name.split(".tools.")[-1].split(".")
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
|
||||
folder = parts[0]
|
||||
file_stem = parts[1]
|
||||
schema_file = f"{file_stem}_schema.xml"
|
||||
|
||||
return get_strix_resource_path("tools", folder, schema_file)
|
||||
|
||||
|
||||
def register_tool(
|
||||
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
|
||||
) -> Callable[..., Any]:
|
||||
@@ -109,11 +163,8 @@ def register_tool(
|
||||
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
||||
if not sandbox_mode:
|
||||
try:
|
||||
module_path = Path(inspect.getfile(f))
|
||||
schema_file_name = f"{module_path.stem}_schema.xml"
|
||||
schema_path = module_path.parent / schema_file_name
|
||||
|
||||
xml_tools = _load_xml_schema(schema_path)
|
||||
schema_path = _get_schema_path(f)
|
||||
xml_tools = _load_xml_schema(schema_path) if schema_path else None
|
||||
|
||||
if xml_tools is not None and f.__name__ in xml_tools:
|
||||
func_dict["xml_schema"] = xml_tools[f.__name__]
|
||||
@@ -131,6 +182,11 @@ def register_tool(
|
||||
"</tool>"
|
||||
)
|
||||
|
||||
if not sandbox_mode:
|
||||
xml_schema = func_dict.get("xml_schema")
|
||||
param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "")
|
||||
_tool_param_schemas[str(func_dict["name"])] = param_schema
|
||||
|
||||
tools.append(func_dict)
|
||||
_tools_by_name[str(func_dict["name"])] = f
|
||||
|
||||
@@ -153,6 +209,10 @@ def get_tool_names() -> list[str]:
|
||||
return list(_tools_by_name.keys())
|
||||
|
||||
|
||||
def get_tool_param_schema(name: str) -> dict[str, Any] | None:
|
||||
return _tool_param_schemas.get(name)
|
||||
|
||||
|
||||
def needs_agent_state(tool_name: str) -> bool:
|
||||
tool_func = get_tool_by_name(tool_name)
|
||||
if not tool_func:
|
||||
@@ -194,3 +254,4 @@ def get_tools_prompt() -> str:
|
||||
def clear_registry() -> None:
|
||||
tools.clear()
|
||||
_tools_by_name.clear()
|
||||
_tool_param_schemas.clear()
|
||||
|
||||
Reference in New Issue
Block a user