fix(agent): fix tool schemas not retrieved on pyinstaller binary and validate tool call args

This commit is contained in:
0xallam
2026-01-14 10:34:40 -08:00
committed by Ahmed Allam
parent bc8e14f68a
commit f08014cf51
9 changed files with 152 additions and 27 deletions

View File

@@ -310,7 +310,7 @@ CRITICAL RULES:
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses. 0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
1. One tool call per message 1. One tool call per message
2. Tool call must be last in message 2. Tool call must be last in message
3. End response after </function> tag. It's your stop word. Do not continue after it. 3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. The </function> tag is your stop word - end your response immediately after it.
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters. 4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants). 5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
- Correct: <function=think> ... </function> - Correct: <function=think> ... </function>
@@ -331,6 +331,8 @@ SPRAYING EXECUTION NOTE:
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload. - When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial - Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial
REMINDER: Always close each tool call with </function> before going into the next. Incomplete tool calls will fail.
{{ get_tools_prompt() }} {{ get_tools_prompt() }}
</tool_usage> </tool_usage>

View File

@@ -1,7 +1,6 @@
import asyncio import asyncio
import contextlib import contextlib
import logging import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@@ -18,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError
from strix.llm.utils import clean_content from strix.llm.utils import clean_content
from strix.runtime import SandboxInitializationError from strix.runtime import SandboxInitializationError
from strix.tools import process_tool_invocations from strix.tools import process_tool_invocations
from strix.utils.resource_paths import get_strix_resource_path
from .state import AgentState from .state import AgentState
@@ -35,8 +35,7 @@ class AgentMeta(type):
if name == "BaseAgent": if name == "BaseAgent":
return new_cls return new_cls
agents_dir = Path(__file__).parent prompt_dir = get_strix_resource_path("agents", name)
prompt_dir = agents_dir / name
new_cls.agent_name = name new_cls.agent_name = name
new_cls.jinja_env = Environment( new_cls.jinja_env = Environment(

View File

@@ -3,7 +3,6 @@ import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Any from typing import Any
import litellm import litellm
@@ -19,9 +18,14 @@ from strix.config import Config
from strix.llm.config import LLMConfig from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue from strix.llm.request_queue import get_global_queue
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations from strix.llm.utils import (
_truncate_to_first_function,
fix_incomplete_tool_call,
parse_tool_invocations,
)
from strix.skills import load_skills from strix.skills import load_skills
from strix.tools import get_tools_prompt from strix.tools import get_tools_prompt
from strix.utils.resource_paths import get_strix_resource_path
MAX_RETRIES = 5 MAX_RETRIES = 5
@@ -124,8 +128,8 @@ class LLM:
) )
if agent_name: if agent_name:
prompt_dir = Path(__file__).parent.parent / "agents" / agent_name prompt_dir = get_strix_resource_path("agents", agent_name)
skills_dir = Path(__file__).parent.parent / "skills" skills_dir = get_strix_resource_path("skills")
loader = FileSystemLoader([prompt_dir, skills_dir]) loader = FileSystemLoader([prompt_dir, skills_dir])
self.jinja_env = Environment( self.jinja_env = Environment(
@@ -298,6 +302,8 @@ class LLM:
function_end = accumulated_content.find("</function>") + len("</function>") function_end = accumulated_content.find("</function>") + len("</function>")
accumulated_content = accumulated_content[:function_end] accumulated_content = accumulated_content[:function_end]
accumulated_content = fix_incomplete_tool_call(accumulated_content)
tool_invocations = parse_tool_invocations(accumulated_content) tool_invocations = parse_tool_invocations(accumulated_content)
# Extract thinking blocks from the complete response if available # Extract thinking blocks from the complete response if available

View File

@@ -18,7 +18,7 @@ def _truncate_to_first_function(content: str) -> str:
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None: def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
content = _fix_stopword(content) content = fix_incomplete_tool_call(content)
tool_invocations: list[dict[str, Any]] = [] tool_invocations: list[dict[str, Any]] = []
@@ -46,16 +46,15 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
return tool_invocations if tool_invocations else None return tool_invocations if tool_invocations else None
def fix_incomplete_tool_call(content: str) -> str:
    """Close a single tool call whose ``</function>`` tag is missing or cut off.

    ``content`` is returned unchanged unless it contains exactly one
    ``<function=`` opener and no complete ``</function>`` tag. In that case
    trailing whitespace is stripped and the closing tag is completed:

    * if the text ends with a partial closing tag (``</``, ``</f``, ...,
      ``</function``), only the missing suffix is appended;
    * otherwise ``"\\n</function>"`` is appended.

    Args:
        content: Raw model output that may contain one tool call.

    Returns:
        The content with a complete closing tag, or the input as-is when no
        repair applies.
    """
    closing_tag = "</function>"
    if content.count("<function=") != 1 or closing_tag in content:
        return content

    content = content.rstrip()
    # Longest partial tag first ("</function" down to "</") so that only the
    # missing suffix is appended. A bare "<" is deliberately not treated as a
    # partial tag: it is too ambiguous to assume it starts "</function>".
    for cut in range(len(closing_tag) - 1, 1, -1):
        if content.endswith(closing_tag[:cut]):
            return content + closing_tag[cut:]
    return content + "\n" + closing_tag
@@ -74,7 +73,7 @@ def clean_content(content: str) -> str:
if not content: if not content:
return "" return ""
content = _fix_stopword(content) content = fix_incomplete_tool_call(content)
tool_pattern = r"<function=[^>]+>.*?</function>" tool_pattern = r"<function=[^>]+>.*?</function>"
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL) cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)

View File

@@ -1,11 +1,14 @@
from pathlib import Path
from jinja2 import Environment from jinja2 import Environment
from strix.utils.resource_paths import get_strix_resource_path
def get_available_skills() -> dict[str, list[str]]: def get_available_skills() -> dict[str, list[str]]:
skills_dir = Path(__file__).parent skills_dir = get_strix_resource_path("skills")
available_skills = {} available_skills: dict[str, list[str]] = {}
if not skills_dir.exists():
return available_skills
for category_dir in skills_dir.iterdir(): for category_dir in skills_dir.iterdir():
if category_dir.is_dir() and not category_dir.name.startswith("__"): if category_dir.is_dir() and not category_dir.name.startswith("__"):
@@ -72,7 +75,7 @@ def load_skills(skill_names: list[str], jinja_env: Environment) -> dict[str, str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
skill_content = {} skill_content = {}
skills_dir = Path(__file__).parent skills_dir = get_strix_resource_path("skills")
available_skills = get_available_skills() available_skills = get_available_skills()

View File

@@ -14,6 +14,7 @@ from .argument_parser import convert_arguments
from .registry import ( from .registry import (
get_tool_by_name, get_tool_by_name,
get_tool_names, get_tool_names,
get_tool_param_schema,
needs_agent_state, needs_agent_state,
should_execute_in_sandbox, should_execute_in_sandbox,
) )
@@ -110,14 +111,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwarg
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
    """Report whether ``tool_name`` refers to a registered tool.

    Returns:
        ``(True, "")`` when the tool is registered; otherwise ``(False, msg)``
        where ``msg`` explains the problem and lists the registered tool names
        in sorted order.
    """
    registered = get_tool_names()
    if tool_name is not None and tool_name in registered:
        return True, ""

    # Both failure messages share the same "Available tools" suffix.
    available = ", ".join(sorted(registered))
    if tool_name is None:
        return False, f"Tool name is missing. Available tools: {available}"
    return False, f"Tool '{tool_name}' is not available. Available tools: {available}"
def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None:
    """Check a tool call's keyword arguments against the tool's parameter schema.

    Args:
        tool_name: Name of the tool being invoked.
        kwargs: Arguments supplied for the call.

    Returns:
        An error message (followed by a hint listing the valid parameters) when
        the call passes a parameter the schema does not declare, or when a
        required parameter is absent or set to ``None`` / ``""``. ``None`` when
        the arguments are acceptable or no parameter schema is registered for
        the tool.
    """
    schema = get_tool_param_schema(tool_name)
    if not (schema and schema.get("has_params")):
        return None

    allowed: set[str] = schema.get("params", set())
    required: set[str] = schema.get("required", set())
    hint = _format_schema_hint(tool_name, required, allowed - required)

    # Unknown parameters are reported first; missing ones only when none are unknown.
    unexpected = sorted(set(kwargs) - allowed)
    if unexpected:
        return (
            f"Tool '{tool_name}' received unknown parameter(s): "
            f"{', '.join(unexpected)}\n{hint}"
        )

    # ``kwargs.get`` yields None for an absent key, so a single membership test
    # covers "not supplied", "supplied as None" and "supplied as empty string".
    absent = sorted(name for name in required if kwargs.get(name) in (None, ""))
    if absent:
        return (
            f"Tool '{tool_name}' missing required parameter(s): "
            f"{', '.join(absent)}\n{hint}"
        )

    return None
def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str:
parts = [f"Valid parameters for '{tool_name}':"]
if required:
parts.append(f" Required: {', '.join(sorted(required))}")
if optional:
parts.append(f" Optional: {', '.join(sorted(optional))}")
return "\n".join(parts)
async def execute_tool_with_validation( async def execute_tool_with_validation(
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
) -> Any: ) -> Any:
@@ -127,6 +165,10 @@ async def execute_tool_with_validation(
assert tool_name is not None assert tool_name is not None
arg_error = _validate_tool_arguments(tool_name, kwargs)
if arg_error:
return f"Error: {arg_error}"
try: try:
result = await execute_tool(tool_name, agent_state, **kwargs) result = await execute_tool(tool_name, agent_state, **kwargs)
except Exception as e: # noqa: BLE001 except Exception as e: # noqa: BLE001

View File

@@ -7,9 +7,14 @@ from inspect import signature
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import defusedxml.ElementTree as DefusedET
from strix.utils.resource_paths import get_strix_resource_path
tools: list[dict[str, Any]] = [] tools: list[dict[str, Any]] = []
_tools_by_name: dict[str, Callable[..., Any]] = {} _tools_by_name: dict[str, Callable[..., Any]] = {}
_tool_param_schemas: dict[str, dict[str, Any]] = {}
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any:
return tools_dict return tools_dict
def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
params: set[str] = set()
required: set[str] = set()
params_start = tool_xml.find("<parameters>")
params_end = tool_xml.find("</parameters>")
if params_start == -1 or params_end == -1:
return {"params": set(), "required": set(), "has_params": False}
params_section = tool_xml[params_start : params_end + len("</parameters>")]
try:
root = DefusedET.fromstring(params_section)
except DefusedET.ParseError:
return {"params": set(), "required": set(), "has_params": False}
for param in root.findall(".//parameter"):
name = param.attrib.get("name")
if not name:
continue
params.add(name)
if param.attrib.get("required", "false").lower() == "true":
required.add(name)
return {"params": params, "required": required, "has_params": bool(params or required)}
def _get_module_name(func: Callable[..., Any]) -> str: def _get_module_name(func: Callable[..., Any]) -> str:
module = inspect.getmodule(func) module = inspect.getmodule(func)
if not module: if not module:
@@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str:
return "unknown" return "unknown"
def _get_schema_path(func: Callable[..., Any]) -> Path | None:
module = inspect.getmodule(func)
if not module or not module.__name__:
return None
module_name = module.__name__
if ".tools." not in module_name:
return None
parts = module_name.split(".tools.")[-1].split(".")
if len(parts) < 2:
return None
folder = parts[0]
file_stem = parts[1]
schema_file = f"{file_stem}_schema.xml"
return get_strix_resource_path("tools", folder, schema_file)
def register_tool( def register_tool(
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
) -> Callable[..., Any]: ) -> Callable[..., Any]:
@@ -109,11 +163,8 @@ def register_tool(
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
if not sandbox_mode: if not sandbox_mode:
try: try:
module_path = Path(inspect.getfile(f)) schema_path = _get_schema_path(f)
schema_file_name = f"{module_path.stem}_schema.xml" xml_tools = _load_xml_schema(schema_path) if schema_path else None
schema_path = module_path.parent / schema_file_name
xml_tools = _load_xml_schema(schema_path)
if xml_tools is not None and f.__name__ in xml_tools: if xml_tools is not None and f.__name__ in xml_tools:
func_dict["xml_schema"] = xml_tools[f.__name__] func_dict["xml_schema"] = xml_tools[f.__name__]
@@ -131,6 +182,11 @@ def register_tool(
"</tool>" "</tool>"
) )
if not sandbox_mode:
xml_schema = func_dict.get("xml_schema")
param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "")
_tool_param_schemas[str(func_dict["name"])] = param_schema
tools.append(func_dict) tools.append(func_dict)
_tools_by_name[str(func_dict["name"])] = f _tools_by_name[str(func_dict["name"])] = f
@@ -153,6 +209,10 @@ def get_tool_names() -> list[str]:
return list(_tools_by_name.keys()) return list(_tools_by_name.keys())
def get_tool_param_schema(name: str) -> dict[str, Any] | None:
    """Return the parameter schema recorded for tool ``name``, or ``None`` if there is none."""
    if name in _tool_param_schemas:
        return _tool_param_schemas[name]
    return None
def needs_agent_state(tool_name: str) -> bool: def needs_agent_state(tool_name: str) -> bool:
tool_func = get_tool_by_name(tool_name) tool_func = get_tool_by_name(tool_name)
if not tool_func: if not tool_func:
@@ -194,3 +254,4 @@ def get_tools_prompt() -> str:
def clear_registry() -> None:
    """Empty the tool list, the name-to-function map and the parameter schemas in place."""
    for registry in (tools, _tools_by_name, _tool_param_schemas):
        registry.clear()

0
strix/utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,13 @@
import sys
from pathlib import Path
def get_strix_resource_path(*parts: str) -> Path:
    """Resolve ``parts`` against the directory that holds strix's resource files.

    When ``sys._MEIPASS`` is set (the attribute PyInstaller-frozen binaries
    expose) and ``<_MEIPASS>/strix`` exists, that directory is the root.
    Otherwise the root is the directory two levels above this file.

    Args:
        *parts: Path segments to append to the root, e.g. ``("agents", name)``.

    Returns:
        The joined path; its existence is not checked.
    """
    bundle_dir = getattr(sys, "_MEIPASS", None)
    bundled_root = Path(bundle_dir) / "strix" if bundle_dir else None

    if bundled_root is not None and bundled_root.exists():
        root = bundled_root
    else:
        # Fall back to the source checkout / installed package layout.
        root = Path(__file__).resolve().parent.parent
    return root.joinpath(*parts)