From f08014cf51a127e068967b5fb1eece79ae88bf79 Mon Sep 17 00:00:00 2001 From: 0xallam Date: Wed, 14 Jan 2026 10:34:40 -0800 Subject: [PATCH] fix(agent): fix tool schemas not retrieved on pyinstaller binary and validate tool call args --- strix/agents/StrixAgent/system_prompt.jinja | 4 +- strix/agents/base_agent.py | 5 +- strix/llm/llm.py | 14 ++-- strix/llm/utils.py | 13 ++-- strix/skills/__init__.py | 13 ++-- strix/tools/executor.py | 46 ++++++++++++- strix/tools/registry.py | 71 +++++++++++++++++++-- strix/utils/__init__.py | 0 strix/utils/resource_paths.py | 13 ++++ 9 files changed, 152 insertions(+), 27 deletions(-) create mode 100644 strix/utils/__init__.py create mode 100644 strix/utils/resource_paths.py diff --git a/strix/agents/StrixAgent/system_prompt.jinja b/strix/agents/StrixAgent/system_prompt.jinja index 99c21b0..038e039 100644 --- a/strix/agents/StrixAgent/system_prompt.jinja +++ b/strix/agents/StrixAgent/system_prompt.jinja @@ -310,7 +310,7 @@ CRITICAL RULES: 0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses. 1. One tool call per message 2. Tool call must be last in message -3. End response after </function> tag. It's your stop word. Do not continue after it. +3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. The tag is your stop word - end your response immediately after it. 4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters. 5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants). - Correct: ... @@ -331,6 +331,8 @@ SPRAYING EXECUTION NOTE: - When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial +REMINDER: Always close each tool call with </function> before going into the next. Incomplete tool calls will fail. + {{ get_tools_prompt() }} diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 3763541..ded04c6 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -1,7 +1,6 @@ import asyncio import contextlib import logging -from pathlib import Path from typing import TYPE_CHECKING, Any, Optional @@ -18,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError from strix.llm.utils import clean_content from strix.runtime import SandboxInitializationError from strix.tools import process_tool_invocations +from strix.utils.resource_paths import get_strix_resource_path from .state import AgentState @@ -35,8 +35,7 @@ class AgentMeta(type): if name == "BaseAgent": return new_cls - agents_dir = Path(__file__).parent - prompt_dir = agents_dir / name + prompt_dir = get_strix_resource_path("agents", name) new_cls.agent_name = name new_cls.jinja_env = Environment( diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 3bbc27b..34132a8 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -3,7 +3,6 @@ import logging from collections.abc import AsyncIterator from dataclasses import dataclass from enum import Enum -from pathlib import Path from typing import Any import litellm @@ -19,9 +18,14 @@ from strix.config import Config from strix.llm.config import LLMConfig from strix.llm.memory_compressor import MemoryCompressor from strix.llm.request_queue import get_global_queue -from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations +from strix.llm.utils import ( + _truncate_to_first_function, + fix_incomplete_tool_call, + parse_tool_invocations, +) from strix.skills import load_skills from strix.tools import get_tools_prompt +from strix.utils.resource_paths import
get_strix_resource_path MAX_RETRIES = 5 @@ -124,8 +128,8 @@ class LLM: ) if agent_name: - prompt_dir = Path(__file__).parent.parent / "agents" / agent_name - skills_dir = Path(__file__).parent.parent / "skills" + prompt_dir = get_strix_resource_path("agents", agent_name) + skills_dir = get_strix_resource_path("skills") loader = FileSystemLoader([prompt_dir, skills_dir]) self.jinja_env = Environment( @@ -298,6 +302,8 @@ class LLM: function_end = accumulated_content.find("") + len("") accumulated_content = accumulated_content[:function_end] + accumulated_content = fix_incomplete_tool_call(accumulated_content) + tool_invocations = parse_tool_invocations(accumulated_content) # Extract thinking blocks from the complete response if available diff --git a/strix/llm/utils.py b/strix/llm/utils.py index e775cff..81431f0 100644 --- a/strix/llm/utils.py +++ b/strix/llm/utils.py @@ -18,7 +18,7 @@ def _truncate_to_first_function(content: str) -> str: def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None: - content = _fix_stopword(content) + content = fix_incomplete_tool_call(content) tool_invocations: list[dict[str, Any]] = [] @@ -46,16 +46,15 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None: return tool_invocations if tool_invocations else None -def _fix_stopword(content: str) -> str: +def fix_incomplete_tool_call(content: str) -> str: + """Fix incomplete tool calls by adding missing tag.""" if ( "" not in content ): - if content.endswith("" - else: - content = content + "\n" + content = content.rstrip() + content = content + "function>" if content.endswith("" return content @@ -74,7 +73,7 @@ def clean_content(content: str) -> str: if not content: return "" - content = _fix_stopword(content) + content = fix_incomplete_tool_call(content) tool_pattern = r"]+>.*?" 
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL) diff --git a/strix/skills/__init__.py b/strix/skills/__init__.py index 1d148f9..f75cb55 100644 --- a/strix/skills/__init__.py +++ b/strix/skills/__init__.py @@ -1,11 +1,14 @@ -from pathlib import Path - from jinja2 import Environment +from strix.utils.resource_paths import get_strix_resource_path + def get_available_skills() -> dict[str, list[str]]: - skills_dir = Path(__file__).parent - available_skills = {} + skills_dir = get_strix_resource_path("skills") + available_skills: dict[str, list[str]] = {} + + if not skills_dir.exists(): + return available_skills for category_dir in skills_dir.iterdir(): if category_dir.is_dir() and not category_dir.name.startswith("__"): @@ -72,7 +75,7 @@ def load_skills(skill_names: list[str], jinja_env: Environment) -> dict[str, str logger = logging.getLogger(__name__) skill_content = {} - skills_dir = Path(__file__).parent + skills_dir = get_strix_resource_path("skills") available_skills = get_available_skills() diff --git a/strix/tools/executor.py b/strix/tools/executor.py index 9dbc74e..ad0aeef 100644 --- a/strix/tools/executor.py +++ b/strix/tools/executor.py @@ -14,6 +14,7 @@ from .argument_parser import convert_arguments from .registry import ( get_tool_by_name, get_tool_names, + get_tool_param_schema, needs_agent_state, should_execute_in_sandbox, ) @@ -110,14 +111,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwarg def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]: if tool_name is None: - return False, "Tool name is missing" + available = ", ".join(sorted(get_tool_names())) + return False, f"Tool name is missing. Available tools: {available}" if tool_name not in get_tool_names(): - return False, f"Tool '{tool_name}' is not available" + available = ", ".join(sorted(get_tool_names())) + return False, f"Tool '{tool_name}' is not available. 
Available tools: {available}" return True, "" +def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None: + param_schema = get_tool_param_schema(tool_name) + if not param_schema or not param_schema.get("has_params"): + return None + + allowed_params: set[str] = param_schema.get("params", set()) + required_params: set[str] = param_schema.get("required", set()) + optional_params = allowed_params - required_params + + schema_hint = _format_schema_hint(tool_name, required_params, optional_params) + + unknown_params = set(kwargs.keys()) - allowed_params + if unknown_params: + unknown_list = ", ".join(sorted(unknown_params)) + return f"Tool '{tool_name}' received unknown parameter(s): {unknown_list}\n{schema_hint}" + + missing_required = [ + param for param in required_params if param not in kwargs or kwargs.get(param) in (None, "") + ] + if missing_required: + missing_list = ", ".join(sorted(missing_required)) + return f"Tool '{tool_name}' missing required parameter(s): {missing_list}\n{schema_hint}" + + return None + + +def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str: + parts = [f"Valid parameters for '{tool_name}':"] + if required: + parts.append(f" Required: {', '.join(sorted(required))}") + if optional: + parts.append(f" Optional: {', '.join(sorted(optional))}") + return "\n".join(parts) + + async def execute_tool_with_validation( tool_name: str | None, agent_state: Any | None = None, **kwargs: Any ) -> Any: @@ -127,6 +165,10 @@ async def execute_tool_with_validation( assert tool_name is not None + arg_error = _validate_tool_arguments(tool_name, kwargs) + if arg_error: + return f"Error: {arg_error}" + try: result = await execute_tool(tool_name, agent_state, **kwargs) except Exception as e: # noqa: BLE001 diff --git a/strix/tools/registry.py b/strix/tools/registry.py index 9ed50d4..ac6e15f 100644 --- a/strix/tools/registry.py +++ b/strix/tools/registry.py @@ -7,9 +7,14 @@ from inspect import signature 
from pathlib import Path from typing import Any +import defusedxml.ElementTree as DefusedET + +from strix.utils.resource_paths import get_strix_resource_path + tools: list[dict[str, Any]] = [] _tools_by_name: dict[str, Callable[..., Any]] = {} +_tool_param_schemas: dict[str, dict[str, Any]] = {} logger = logging.getLogger(__name__) @@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any: return tools_dict +def _parse_param_schema(tool_xml: str) -> dict[str, Any]: + params: set[str] = set() + required: set[str] = set() + + params_start = tool_xml.find("") + params_end = tool_xml.find("") + + if params_start == -1 or params_end == -1: + return {"params": set(), "required": set(), "has_params": False} + + params_section = tool_xml[params_start : params_end + len("")] + + try: + root = DefusedET.fromstring(params_section) + except DefusedET.ParseError: + return {"params": set(), "required": set(), "has_params": False} + + for param in root.findall(".//parameter"): + name = param.attrib.get("name") + if not name: + continue + params.add(name) + if param.attrib.get("required", "false").lower() == "true": + required.add(name) + + return {"params": params, "required": required, "has_params": bool(params or required)} + + def _get_module_name(func: Callable[..., Any]) -> str: module = inspect.getmodule(func) if not module: @@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str: return "unknown" +def _get_schema_path(func: Callable[..., Any]) -> Path | None: + module = inspect.getmodule(func) + if not module or not module.__name__: + return None + + module_name = module.__name__ + + if ".tools." 
not in module_name: + return None + + parts = module_name.split(".tools.")[-1].split(".") + if len(parts) < 2: + return None + + folder = parts[0] + file_stem = parts[1] + schema_file = f"{file_stem}_schema.xml" + + return get_strix_resource_path("tools", folder, schema_file) + + def register_tool( func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True ) -> Callable[..., Any]: @@ -109,11 +163,8 @@ def register_tool( sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" if not sandbox_mode: try: - module_path = Path(inspect.getfile(f)) - schema_file_name = f"{module_path.stem}_schema.xml" - schema_path = module_path.parent / schema_file_name - - xml_tools = _load_xml_schema(schema_path) + schema_path = _get_schema_path(f) + xml_tools = _load_xml_schema(schema_path) if schema_path else None if xml_tools is not None and f.__name__ in xml_tools: func_dict["xml_schema"] = xml_tools[f.__name__] @@ -131,6 +182,11 @@ def register_tool( "" ) + if not sandbox_mode: + xml_schema = func_dict.get("xml_schema") + param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "") + _tool_param_schemas[str(func_dict["name"])] = param_schema + tools.append(func_dict) _tools_by_name[str(func_dict["name"])] = f @@ -153,6 +209,10 @@ def get_tool_names() -> list[str]: return list(_tools_by_name.keys()) +def get_tool_param_schema(name: str) -> dict[str, Any] | None: + return _tool_param_schemas.get(name) + + def needs_agent_state(tool_name: str) -> bool: tool_func = get_tool_by_name(tool_name) if not tool_func: @@ -194,3 +254,4 @@ def get_tools_prompt() -> str: def clear_registry() -> None: tools.clear() _tools_by_name.clear() + _tool_param_schemas.clear() diff --git a/strix/utils/__init__.py b/strix/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strix/utils/resource_paths.py b/strix/utils/resource_paths.py new file mode 100644 index 0000000..6a0ab44 --- /dev/null +++ b/strix/utils/resource_paths.py @@ 
-0,0 +1,13 @@ +import sys +from pathlib import Path + + +def get_strix_resource_path(*parts: str) -> Path: + frozen_base = getattr(sys, "_MEIPASS", None) + if frozen_base: + base = Path(frozen_base) / "strix" + if base.exists(): + return base.joinpath(*parts) + + base = Path(__file__).resolve().parent.parent + return base.joinpath(*parts)