fix(agent): fix tool schemas not retrieved on pyinstaller binary and validate tool call args
This commit is contained in:
@@ -310,7 +310,7 @@ CRITICAL RULES:
|
|||||||
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
|
0. While active in the agent loop, EVERY message you output MUST be a single tool call. Do not send plain text-only responses.
|
||||||
1. One tool call per message
|
1. One tool call per message
|
||||||
2. Tool call must be last in message
|
2. Tool call must be last in message
|
||||||
3. End response after </function> tag. It's your stop word. Do not continue after it.
|
3. EVERY tool call MUST end with </function>. This is MANDATORY. Never omit the closing tag. The </function> tag is your stop word - end your response immediately after it.
|
||||||
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
|
4. Use ONLY the exact XML format shown above. NEVER use JSON/YAML/INI or any other syntax for tools or parameters.
|
||||||
5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
|
5. Tool names must match exactly the tool "name" defined (no module prefixes, dots, or variants).
|
||||||
- Correct: <function=think> ... </function>
|
- Correct: <function=think> ... </function>
|
||||||
@@ -331,6 +331,8 @@ SPRAYING EXECUTION NOTE:
|
|||||||
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
|
- When performing large payload sprays or fuzzing, encapsulate the entire spraying loop inside a single python or terminal tool call (e.g., a Python script using asyncio/aiohttp). Do not issue one tool call per payload.
|
||||||
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial
|
- Favor batch-mode CLI tools (sqlmap, ffuf, nuclei, zaproxy, arjun) where appropriate and check traffic via the proxy when beneficial
|
||||||
|
|
||||||
|
REMINDER: Always close each tool call with </function> before going into the next. Incomplete tool calls will fail.
|
||||||
|
|
||||||
{{ get_tools_prompt() }}
|
{{ get_tools_prompt() }}
|
||||||
</tool_usage>
|
</tool_usage>
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
@@ -18,6 +17,7 @@ from strix.llm import LLM, LLMConfig, LLMRequestFailedError
|
|||||||
from strix.llm.utils import clean_content
|
from strix.llm.utils import clean_content
|
||||||
from strix.runtime import SandboxInitializationError
|
from strix.runtime import SandboxInitializationError
|
||||||
from strix.tools import process_tool_invocations
|
from strix.tools import process_tool_invocations
|
||||||
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
from .state import AgentState
|
from .state import AgentState
|
||||||
|
|
||||||
@@ -35,8 +35,7 @@ class AgentMeta(type):
|
|||||||
if name == "BaseAgent":
|
if name == "BaseAgent":
|
||||||
return new_cls
|
return new_cls
|
||||||
|
|
||||||
agents_dir = Path(__file__).parent
|
prompt_dir = get_strix_resource_path("agents", name)
|
||||||
prompt_dir = agents_dir / name
|
|
||||||
|
|
||||||
new_cls.agent_name = name
|
new_cls.agent_name = name
|
||||||
new_cls.jinja_env = Environment(
|
new_cls.jinja_env = Environment(
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import logging
|
|||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
@@ -19,9 +18,14 @@ from strix.config import Config
|
|||||||
from strix.llm.config import LLMConfig
|
from strix.llm.config import LLMConfig
|
||||||
from strix.llm.memory_compressor import MemoryCompressor
|
from strix.llm.memory_compressor import MemoryCompressor
|
||||||
from strix.llm.request_queue import get_global_queue
|
from strix.llm.request_queue import get_global_queue
|
||||||
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
|
from strix.llm.utils import (
|
||||||
|
_truncate_to_first_function,
|
||||||
|
fix_incomplete_tool_call,
|
||||||
|
parse_tool_invocations,
|
||||||
|
)
|
||||||
from strix.skills import load_skills
|
from strix.skills import load_skills
|
||||||
from strix.tools import get_tools_prompt
|
from strix.tools import get_tools_prompt
|
||||||
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
@@ -124,8 +128,8 @@ class LLM:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if agent_name:
|
if agent_name:
|
||||||
prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
|
prompt_dir = get_strix_resource_path("agents", agent_name)
|
||||||
skills_dir = Path(__file__).parent.parent / "skills"
|
skills_dir = get_strix_resource_path("skills")
|
||||||
|
|
||||||
loader = FileSystemLoader([prompt_dir, skills_dir])
|
loader = FileSystemLoader([prompt_dir, skills_dir])
|
||||||
self.jinja_env = Environment(
|
self.jinja_env = Environment(
|
||||||
@@ -298,6 +302,8 @@ class LLM:
|
|||||||
function_end = accumulated_content.find("</function>") + len("</function>")
|
function_end = accumulated_content.find("</function>") + len("</function>")
|
||||||
accumulated_content = accumulated_content[:function_end]
|
accumulated_content = accumulated_content[:function_end]
|
||||||
|
|
||||||
|
accumulated_content = fix_incomplete_tool_call(accumulated_content)
|
||||||
|
|
||||||
tool_invocations = parse_tool_invocations(accumulated_content)
|
tool_invocations = parse_tool_invocations(accumulated_content)
|
||||||
|
|
||||||
# Extract thinking blocks from the complete response if available
|
# Extract thinking blocks from the complete response if available
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ def _truncate_to_first_function(content: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
|
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
|
||||||
content = _fix_stopword(content)
|
content = fix_incomplete_tool_call(content)
|
||||||
|
|
||||||
tool_invocations: list[dict[str, Any]] = []
|
tool_invocations: list[dict[str, Any]] = []
|
||||||
|
|
||||||
@@ -46,16 +46,15 @@ def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
|
|||||||
return tool_invocations if tool_invocations else None
|
return tool_invocations if tool_invocations else None
|
||||||
|
|
||||||
|
|
||||||
def fix_incomplete_tool_call(content: str) -> str:
    """Close a single tool call whose ``</function>`` tag is missing or cut off.

    Only content holding exactly one ``<function=`` opener and no complete
    ``</function>`` closer is touched; anything else is returned unchanged.
    Trailing whitespace is stripped before the repair. If the text ends with
    a truncated closing tag (``</``, ``</func``, ``</function`` ...), only
    the missing characters are appended; otherwise a full ``</function>`` is
    added on a new line.

    Args:
        content: Raw model output that may contain a tool call.

    Returns:
        The content with a complete closing tag when a repair applied,
        otherwise the input exactly as given.
    """
    closing_tag = "</function>"
    if content.count("<function=") != 1 or closing_tag in content:
        return content

    trimmed = content.rstrip()
    # Longest truncated prefix first. Stop at "</": a lone "<" is too likely
    # to be ordinary payload text to treat as the start of the closing tag.
    for size in range(len(closing_tag) - 1, 1, -1):
        if trimmed.endswith(closing_tag[:size]):
            return trimmed + closing_tag[size:]
    return f"{trimmed}\n{closing_tag}"


# Previous private name, kept so call sites still using it keep working.
_fix_stopword = fix_incomplete_tool_call
|
||||||
|
|
||||||
|
|
||||||
@@ -74,7 +73,7 @@ def clean_content(content: str) -> str:
|
|||||||
if not content:
|
if not content:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
content = _fix_stopword(content)
|
content = fix_incomplete_tool_call(content)
|
||||||
|
|
||||||
tool_pattern = r"<function=[^>]+>.*?</function>"
|
tool_pattern = r"<function=[^>]+>.*?</function>"
|
||||||
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
|
cleaned = re.sub(tool_pattern, "", content, flags=re.DOTALL)
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from jinja2 import Environment
|
from jinja2 import Environment
|
||||||
|
|
||||||
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
|
|
||||||
def get_available_skills() -> dict[str, list[str]]:
|
def get_available_skills() -> dict[str, list[str]]:
|
||||||
skills_dir = Path(__file__).parent
|
skills_dir = get_strix_resource_path("skills")
|
||||||
available_skills = {}
|
available_skills: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
if not skills_dir.exists():
|
||||||
|
return available_skills
|
||||||
|
|
||||||
for category_dir in skills_dir.iterdir():
|
for category_dir in skills_dir.iterdir():
|
||||||
if category_dir.is_dir() and not category_dir.name.startswith("__"):
|
if category_dir.is_dir() and not category_dir.name.startswith("__"):
|
||||||
@@ -72,7 +75,7 @@ def load_skills(skill_names: list[str], jinja_env: Environment) -> dict[str, str
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
skill_content = {}
|
skill_content = {}
|
||||||
skills_dir = Path(__file__).parent
|
skills_dir = get_strix_resource_path("skills")
|
||||||
|
|
||||||
available_skills = get_available_skills()
|
available_skills = get_available_skills()
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from .argument_parser import convert_arguments
|
|||||||
from .registry import (
|
from .registry import (
|
||||||
get_tool_by_name,
|
get_tool_by_name,
|
||||||
get_tool_names,
|
get_tool_names,
|
||||||
|
get_tool_param_schema,
|
||||||
needs_agent_state,
|
needs_agent_state,
|
||||||
should_execute_in_sandbox,
|
should_execute_in_sandbox,
|
||||||
)
|
)
|
||||||
@@ -110,14 +111,51 @@ async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwarg
|
|||||||
|
|
||||||
def validate_tool_availability(tool_name: str | None) -> tuple[bool, str]:
    """Report whether ``tool_name`` refers to a registered tool.

    Args:
        tool_name: Name taken from the tool call, or ``None`` if absent.

    Returns:
        ``(True, "")`` when the tool is registered; otherwise ``False`` with
        a message that also lists the registered tool names in sorted order.
    """
    registered = get_tool_names()
    if tool_name is not None and tool_name in registered:
        return True, ""

    available = ", ".join(sorted(registered))
    if tool_name is None:
        return False, f"Tool name is missing. Available tools: {available}"
    return False, f"Tool '{tool_name}' is not available. Available tools: {available}"
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_tool_arguments(tool_name: str, kwargs: dict[str, Any]) -> str | None:
    """Check a tool call's arguments against the tool's parameter schema.

    Unknown parameter names are reported first; after that, required
    parameters that are absent, ``None`` or the empty string are reported.

    Args:
        tool_name: Name of the tool being invoked.
        kwargs: Arguments supplied for the call, keyed by parameter name.

    Returns:
        An error message followed by a hint listing the valid parameters,
        or ``None`` when the arguments pass or when no schema declaring
        parameters is recorded for the tool.
    """
    param_schema = get_tool_param_schema(tool_name)
    if not param_schema or not param_schema.get("has_params"):
        # Nothing recorded to validate against.
        return None

    allowed_params: set[str] = param_schema.get("params", set())
    required_params: set[str] = param_schema.get("required", set())

    def failure(problem: str, names: set[str]) -> str:
        # The hint is formatted here, on failure only, so valid calls do not
        # pay for building text that is never shown.
        hint = _format_schema_hint(
            tool_name, required_params, allowed_params - required_params
        )
        return f"Tool '{tool_name}' {problem}: {', '.join(sorted(names))}\n{hint}"

    unknown_params = set(kwargs) - allowed_params
    if unknown_params:
        return failure("received unknown parameter(s)", unknown_params)

    # A key that is absent reads as None through .get(), so one test covers
    # both "not supplied" and "supplied but empty".
    missing_required = {
        param for param in required_params if kwargs.get(param) in (None, "")
    }
    if missing_required:
        return failure("missing required parameter(s)", missing_required)

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _format_schema_hint(tool_name: str, required: set[str], optional: set[str]) -> str:
|
||||||
|
parts = [f"Valid parameters for '{tool_name}':"]
|
||||||
|
if required:
|
||||||
|
parts.append(f" Required: {', '.join(sorted(required))}")
|
||||||
|
if optional:
|
||||||
|
parts.append(f" Optional: {', '.join(sorted(optional))}")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool_with_validation(
|
async def execute_tool_with_validation(
|
||||||
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
|
tool_name: str | None, agent_state: Any | None = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@@ -127,6 +165,10 @@ async def execute_tool_with_validation(
|
|||||||
|
|
||||||
assert tool_name is not None
|
assert tool_name is not None
|
||||||
|
|
||||||
|
arg_error = _validate_tool_arguments(tool_name, kwargs)
|
||||||
|
if arg_error:
|
||||||
|
return f"Error: {arg_error}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await execute_tool(tool_name, agent_state, **kwargs)
|
result = await execute_tool(tool_name, agent_state, **kwargs)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e: # noqa: BLE001
|
||||||
|
|||||||
@@ -7,9 +7,14 @@ from inspect import signature
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import defusedxml.ElementTree as DefusedET
|
||||||
|
|
||||||
|
from strix.utils.resource_paths import get_strix_resource_path
|
||||||
|
|
||||||
|
|
||||||
tools: list[dict[str, Any]] = []
|
tools: list[dict[str, Any]] = []
|
||||||
_tools_by_name: dict[str, Callable[..., Any]] = {}
|
_tools_by_name: dict[str, Callable[..., Any]] = {}
|
||||||
|
_tool_param_schemas: dict[str, dict[str, Any]] = {}
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,6 +87,34 @@ def _load_xml_schema(path: Path) -> Any:
|
|||||||
return tools_dict
|
return tools_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_param_schema(tool_xml: str) -> dict[str, Any]:
|
||||||
|
params: set[str] = set()
|
||||||
|
required: set[str] = set()
|
||||||
|
|
||||||
|
params_start = tool_xml.find("<parameters>")
|
||||||
|
params_end = tool_xml.find("</parameters>")
|
||||||
|
|
||||||
|
if params_start == -1 or params_end == -1:
|
||||||
|
return {"params": set(), "required": set(), "has_params": False}
|
||||||
|
|
||||||
|
params_section = tool_xml[params_start : params_end + len("</parameters>")]
|
||||||
|
|
||||||
|
try:
|
||||||
|
root = DefusedET.fromstring(params_section)
|
||||||
|
except DefusedET.ParseError:
|
||||||
|
return {"params": set(), "required": set(), "has_params": False}
|
||||||
|
|
||||||
|
for param in root.findall(".//parameter"):
|
||||||
|
name = param.attrib.get("name")
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
params.add(name)
|
||||||
|
if param.attrib.get("required", "false").lower() == "true":
|
||||||
|
required.add(name)
|
||||||
|
|
||||||
|
return {"params": params, "required": required, "has_params": bool(params or required)}
|
||||||
|
|
||||||
|
|
||||||
def _get_module_name(func: Callable[..., Any]) -> str:
|
def _get_module_name(func: Callable[..., Any]) -> str:
|
||||||
module = inspect.getmodule(func)
|
module = inspect.getmodule(func)
|
||||||
if not module:
|
if not module:
|
||||||
@@ -95,6 +128,27 @@ def _get_module_name(func: Callable[..., Any]) -> str:
|
|||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_schema_path(func: Callable[..., Any]) -> Path | None:
    """Locate the XML schema file that belongs to a tool function's module.

    The location is derived from the module's dotted name rather than from
    its file on disk: everything after the last ``.tools.`` segment is read
    as ``<package>[.<subpackage>...].<module>`` and mapped to
    ``tools/<package>/.../<module>_schema.xml`` under the strix resource root.

    Args:
        func: The tool function being registered.

    Returns:
        The expected schema path (not checked for existence), or ``None``
        when the module cannot be determined, is not under a ``tools``
        package, or sits directly in ``tools`` with no sub-package.
    """
    module = inspect.getmodule(func)
    if not module or not module.__name__:
        return None

    _, marker, relative = module.__name__.rpartition(".tools.")
    if not marker:
        return None

    # The last component is the module itself; every component before it is
    # a directory. Keeping all of them puts the schema next to modules
    # nested more than one package deep as well.
    *folders, file_stem = relative.split(".")
    if not folders:
        return None

    return get_strix_resource_path("tools", *folders, f"{file_stem}_schema.xml")
|
||||||
|
|
||||||
|
|
||||||
def register_tool(
|
def register_tool(
|
||||||
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
|
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
|
||||||
) -> Callable[..., Any]:
|
) -> Callable[..., Any]:
|
||||||
@@ -109,11 +163,8 @@ def register_tool(
|
|||||||
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
||||||
if not sandbox_mode:
|
if not sandbox_mode:
|
||||||
try:
|
try:
|
||||||
module_path = Path(inspect.getfile(f))
|
schema_path = _get_schema_path(f)
|
||||||
schema_file_name = f"{module_path.stem}_schema.xml"
|
xml_tools = _load_xml_schema(schema_path) if schema_path else None
|
||||||
schema_path = module_path.parent / schema_file_name
|
|
||||||
|
|
||||||
xml_tools = _load_xml_schema(schema_path)
|
|
||||||
|
|
||||||
if xml_tools is not None and f.__name__ in xml_tools:
|
if xml_tools is not None and f.__name__ in xml_tools:
|
||||||
func_dict["xml_schema"] = xml_tools[f.__name__]
|
func_dict["xml_schema"] = xml_tools[f.__name__]
|
||||||
@@ -131,6 +182,11 @@ def register_tool(
|
|||||||
"</tool>"
|
"</tool>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not sandbox_mode:
|
||||||
|
xml_schema = func_dict.get("xml_schema")
|
||||||
|
param_schema = _parse_param_schema(xml_schema if isinstance(xml_schema, str) else "")
|
||||||
|
_tool_param_schemas[str(func_dict["name"])] = param_schema
|
||||||
|
|
||||||
tools.append(func_dict)
|
tools.append(func_dict)
|
||||||
_tools_by_name[str(func_dict["name"])] = f
|
_tools_by_name[str(func_dict["name"])] = f
|
||||||
|
|
||||||
@@ -153,6 +209,10 @@ def get_tool_names() -> list[str]:
|
|||||||
return list(_tools_by_name.keys())
|
return list(_tools_by_name.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_param_schema(name: str) -> dict[str, Any] | None:
    """Look up the parameter schema recorded for the tool called ``name``.

    Returns:
        The dict stored in the registry itself (not a copy), or ``None``
        when nothing is recorded under that name.
    """
    if name in _tool_param_schemas:
        return _tool_param_schemas[name]
    return None
|
||||||
|
|
||||||
|
|
||||||
def needs_agent_state(tool_name: str) -> bool:
|
def needs_agent_state(tool_name: str) -> bool:
|
||||||
tool_func = get_tool_by_name(tool_name)
|
tool_func = get_tool_by_name(tool_name)
|
||||||
if not tool_func:
|
if not tool_func:
|
||||||
@@ -194,3 +254,4 @@ def get_tools_prompt() -> str:
|
|||||||
def clear_registry() -> None:
    """Empty every module-level tool registry in place.

    The containers are cleared rather than rebound, so existing references
    to them observe the reset.
    """
    for registry in (tools, _tools_by_name, _tool_param_schemas):
        registry.clear()
|
||||||
|
|||||||
0
strix/utils/__init__.py
Normal file
0
strix/utils/__init__.py
Normal file
13
strix/utils/resource_paths.py
Normal file
13
strix/utils/resource_paths.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def get_strix_resource_path(*parts: str) -> Path:
    """Return the path of a resource inside the ``strix`` package tree.

    When ``sys._MEIPASS`` is set (PyInstaller bundle) and it contains a
    ``strix`` directory, the path resolves under that directory. Otherwise
    it resolves under the directory one level above this file's directory.

    Args:
        *parts: Path components below the package root,
            e.g. ``("agents", agent_name)``.

    Returns:
        The joined path. Its existence is not checked.
    """
    frozen_base = getattr(sys, "_MEIPASS", None)
    if frozen_base:
        bundled = Path(frozen_base) / "strix"
        # is_dir(), not exists(): the bundle may hold a regular *file* named
        # "strix" (for instance the executable), which cannot contain
        # resources and must not be taken for the package root.
        if bundled.is_dir():
            return bundled.joinpath(*parts)

    return Path(__file__).resolve().parent.parent.joinpath(*parts)
|
||||||
Reference in New Issue
Block a user