refactor: move tool availability checks into registration
This commit is contained in:
@@ -1,7 +1,5 @@
|
|||||||
import os
|
from .agents_graph import * # noqa: F403
|
||||||
|
from .browser import * # noqa: F403
|
||||||
from strix.config import Config
|
|
||||||
|
|
||||||
from .executor import (
|
from .executor import (
|
||||||
execute_tool,
|
execute_tool,
|
||||||
execute_tool_invocation,
|
execute_tool_invocation,
|
||||||
@@ -11,6 +9,12 @@ from .executor import (
|
|||||||
remove_screenshot_from_result,
|
remove_screenshot_from_result,
|
||||||
validate_tool_availability,
|
validate_tool_availability,
|
||||||
)
|
)
|
||||||
|
from .file_edit import * # noqa: F403
|
||||||
|
from .finish import * # noqa: F403
|
||||||
|
from .load_skill import * # noqa: F403
|
||||||
|
from .notes import * # noqa: F403
|
||||||
|
from .proxy import * # noqa: F403
|
||||||
|
from .python import * # noqa: F403
|
||||||
from .registry import (
|
from .registry import (
|
||||||
ImplementedInClientSideOnlyError,
|
ImplementedInClientSideOnlyError,
|
||||||
get_tool_by_name,
|
get_tool_by_name,
|
||||||
@@ -20,53 +24,13 @@ from .registry import (
|
|||||||
register_tool,
|
register_tool,
|
||||||
tools,
|
tools,
|
||||||
)
|
)
|
||||||
|
from .reporting import * # noqa: F403
|
||||||
|
from .terminal import * # noqa: F403
|
||||||
|
from .thinking import * # noqa: F403
|
||||||
|
from .todo import * # noqa: F403
|
||||||
|
from .web_search import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
|
||||||
|
|
||||||
|
|
||||||
def _is_browser_disabled() -> bool:
|
|
||||||
if os.getenv("STRIX_DISABLE_BROWSER", "").lower() == "true":
|
|
||||||
return True
|
|
||||||
val: str = Config.load().get("env", {}).get("STRIX_DISABLE_BROWSER", "")
|
|
||||||
return str(val).lower() == "true"
|
|
||||||
|
|
||||||
|
|
||||||
DISABLE_BROWSER = _is_browser_disabled()
|
|
||||||
|
|
||||||
|
|
||||||
def _has_perplexity_api() -> bool:
|
|
||||||
if os.getenv("PERPLEXITY_API_KEY"):
|
|
||||||
return True
|
|
||||||
return bool(Config.load().get("env", {}).get("PERPLEXITY_API_KEY"))
|
|
||||||
|
|
||||||
|
|
||||||
if not SANDBOX_MODE:
|
|
||||||
from .agents_graph import * # noqa: F403
|
|
||||||
|
|
||||||
if not DISABLE_BROWSER:
|
|
||||||
from .browser import * # noqa: F403
|
|
||||||
from .file_edit import * # noqa: F403
|
|
||||||
from .finish import * # noqa: F403
|
|
||||||
from .load_skill import * # noqa: F403
|
|
||||||
from .notes import * # noqa: F403
|
|
||||||
from .proxy import * # noqa: F403
|
|
||||||
from .python import * # noqa: F403
|
|
||||||
from .reporting import * # noqa: F403
|
|
||||||
from .terminal import * # noqa: F403
|
|
||||||
from .thinking import * # noqa: F403
|
|
||||||
from .todo import * # noqa: F403
|
|
||||||
|
|
||||||
if _has_perplexity_api():
|
|
||||||
from .web_search import * # noqa: F403
|
|
||||||
else:
|
|
||||||
if not DISABLE_BROWSER:
|
|
||||||
from .browser import * # noqa: F403
|
|
||||||
from .file_edit import * # noqa: F403
|
|
||||||
from .proxy import * # noqa: F403
|
|
||||||
from .python import * # noqa: F403
|
|
||||||
from .terminal import * # noqa: F403
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ImplementedInClientSideOnlyError",
|
"ImplementedInClientSideOnlyError",
|
||||||
"execute_tool",
|
"execute_tool",
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ def _handle_utility_actions(
|
|||||||
raise ValueError(f"Unknown utility action: {action}")
|
raise ValueError(f"Unknown utility action: {action}")
|
||||||
|
|
||||||
|
|
||||||
@register_tool
|
@register_tool(requires_browser_mode=True)
|
||||||
def browser_action(
|
def browser_action(
|
||||||
action: BrowserAction,
|
action: BrowserAction,
|
||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
|
|||||||
@@ -149,10 +149,60 @@ def _get_schema_path(func: Callable[..., Any]) -> Path | None:
|
|||||||
return get_strix_resource_path("tools", folder, schema_file)
|
return get_strix_resource_path("tools", folder, schema_file)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_sandbox_mode() -> bool:
|
||||||
|
return os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_browser_disabled() -> bool:
|
||||||
|
if os.getenv("STRIX_DISABLE_BROWSER", "").lower() == "true":
|
||||||
|
return True
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
|
val: str = Config.load().get("env", {}).get("STRIX_DISABLE_BROWSER", "")
|
||||||
|
return str(val).lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
|
def _has_perplexity_api() -> bool:
|
||||||
|
if os.getenv("PERPLEXITY_API_KEY"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
|
return bool(Config.load().get("env", {}).get("PERPLEXITY_API_KEY"))
|
||||||
|
|
||||||
|
|
||||||
|
def _should_register_tool(
|
||||||
|
*,
|
||||||
|
sandbox_execution: bool,
|
||||||
|
requires_browser_mode: bool,
|
||||||
|
requires_web_search_mode: bool,
|
||||||
|
) -> bool:
|
||||||
|
sandbox_mode = _is_sandbox_mode()
|
||||||
|
|
||||||
|
if sandbox_mode and not sandbox_execution:
|
||||||
|
return False
|
||||||
|
if requires_browser_mode and _is_browser_disabled():
|
||||||
|
return False
|
||||||
|
return not (requires_web_search_mode and not _has_perplexity_api())
|
||||||
|
|
||||||
|
|
||||||
def register_tool(
|
def register_tool(
|
||||||
func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True
|
func: Callable[..., Any] | None = None,
|
||||||
|
*,
|
||||||
|
sandbox_execution: bool = True,
|
||||||
|
requires_browser_mode: bool = False,
|
||||||
|
requires_web_search_mode: bool = False,
|
||||||
) -> Callable[..., Any]:
|
) -> Callable[..., Any]:
|
||||||
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
|
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
sandbox_mode = _is_sandbox_mode()
|
||||||
|
if not _should_register_tool(
|
||||||
|
sandbox_execution=sandbox_execution,
|
||||||
|
requires_browser_mode=requires_browser_mode,
|
||||||
|
requires_web_search_mode=requires_web_search_mode,
|
||||||
|
):
|
||||||
|
return f
|
||||||
|
|
||||||
func_dict = {
|
func_dict = {
|
||||||
"name": f.__name__,
|
"name": f.__name__,
|
||||||
"function": f,
|
"function": f,
|
||||||
@@ -160,7 +210,6 @@ def register_tool(
|
|||||||
"sandbox_execution": sandbox_execution,
|
"sandbox_execution": sandbox_execution,
|
||||||
}
|
}
|
||||||
|
|
||||||
sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
|
||||||
if not sandbox_mode:
|
if not sandbox_mode:
|
||||||
try:
|
try:
|
||||||
schema_path = _get_schema_path(f)
|
schema_path = _get_schema_path(f)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ Structure your response to be comprehensive yet concise, emphasizing the most cr
|
|||||||
security implications and details."""
|
security implications and details."""
|
||||||
|
|
||||||
|
|
||||||
@register_tool(sandbox_execution=False)
|
@register_tool(sandbox_execution=False, requires_web_search_mode=True)
|
||||||
def web_search(query: str) -> dict[str, Any]:
|
def web_search(query: str) -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
api_key = os.getenv("PERPLEXITY_API_KEY")
|
api_key = os.getenv("PERPLEXITY_API_KEY")
|
||||||
|
|||||||
95
tests/tools/test_tool_registration_modes.py
Normal file
95
tests/tools/test_tool_registration_modes.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
from strix.tools.registry import clear_registry
|
||||||
|
|
||||||
|
|
||||||
|
def _empty_config_load(_cls: type[Config]) -> dict[str, dict[str, str]]:
|
||||||
|
return {"env": {}}
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_tools_module() -> ModuleType:
|
||||||
|
clear_registry()
|
||||||
|
|
||||||
|
for name in list(sys.modules):
|
||||||
|
if name == "strix.tools" or name.startswith("strix.tools."):
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
|
||||||
|
return importlib.import_module("strix.tools")
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_sandbox_registers_agents_graph_but_not_browser_or_web_search_when_disabled(
|
||||||
|
monkeypatch: Any,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setenv("STRIX_SANDBOX_MODE", "false")
|
||||||
|
monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true")
|
||||||
|
monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False)
|
||||||
|
monkeypatch.setattr(Config, "load", classmethod(_empty_config_load))
|
||||||
|
|
||||||
|
tools = _reload_tools_module()
|
||||||
|
names = set(tools.get_tool_names())
|
||||||
|
|
||||||
|
assert "create_agent" in names
|
||||||
|
assert "browser_action" not in names
|
||||||
|
assert "web_search" not in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_sandbox_registers_sandbox_tools_but_not_non_sandbox_tools(
|
||||||
|
monkeypatch: Any,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setenv("STRIX_SANDBOX_MODE", "true")
|
||||||
|
monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true")
|
||||||
|
monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False)
|
||||||
|
monkeypatch.setattr(Config, "load", classmethod(_empty_config_load))
|
||||||
|
|
||||||
|
tools = _reload_tools_module()
|
||||||
|
names = set(tools.get_tool_names())
|
||||||
|
|
||||||
|
assert "terminal_execute" in names
|
||||||
|
assert "python_action" in names
|
||||||
|
assert "list_requests" in names
|
||||||
|
assert "create_agent" not in names
|
||||||
|
assert "finish_scan" not in names
|
||||||
|
assert "load_skill" not in names
|
||||||
|
assert "browser_action" not in names
|
||||||
|
assert "web_search" not in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_skill_does_not_register_agents_graph_when_imported_directly(
|
||||||
|
monkeypatch: Any,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setenv("STRIX_SANDBOX_MODE", "true")
|
||||||
|
monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true")
|
||||||
|
monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False)
|
||||||
|
monkeypatch.setattr(Config, "load", classmethod(_empty_config_load))
|
||||||
|
|
||||||
|
clear_registry()
|
||||||
|
for name in list(sys.modules):
|
||||||
|
if name == "strix.tools" or name.startswith("strix.tools."):
|
||||||
|
sys.modules.pop(name, None)
|
||||||
|
|
||||||
|
load_skill_module = importlib.import_module("strix.tools.load_skill.load_skill_actions")
|
||||||
|
registry = importlib.import_module("strix.tools.registry")
|
||||||
|
|
||||||
|
names_before = set(registry.get_tool_names())
|
||||||
|
assert "python_action" in names_before
|
||||||
|
assert "load_skill" not in names_before
|
||||||
|
assert "create_agent" not in names_before
|
||||||
|
|
||||||
|
state_type = type(
|
||||||
|
"DummyState",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"agent_id": "agent_test",
|
||||||
|
"context": {},
|
||||||
|
"update_context": lambda self, key, value: self.context.__setitem__(key, value),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = load_skill_module.load_skill(state_type(), "nmap")
|
||||||
|
|
||||||
|
names_after = set(registry.get_tool_names())
|
||||||
|
assert "create_agent" not in names_after
|
||||||
|
assert result["success"] is False
|
||||||
Reference in New Issue
Block a user