From 8765b1895c344fe74890bc4403ad570a72815ce3 Mon Sep 17 00:00:00 2001 From: 0xallam Date: Thu, 19 Mar 2026 23:38:23 -0700 Subject: [PATCH] refactor: move tool availability checks into registration --- strix/tools/__init__.py | 62 +++---------- strix/tools/browser/browser_actions.py | 2 +- strix/tools/registry.py | 53 ++++++++++- strix/tools/web_search/web_search_actions.py | 2 +- tests/tools/test_tool_registration_modes.py | 95 ++++++++++++++++++++ 5 files changed, 161 insertions(+), 53 deletions(-) create mode 100644 tests/tools/test_tool_registration_modes.py diff --git a/strix/tools/__init__.py b/strix/tools/__init__.py index 9eff25e..17299d4 100644 --- a/strix/tools/__init__.py +++ b/strix/tools/__init__.py @@ -1,7 +1,5 @@ -import os - -from strix.config import Config - +from .agents_graph import * # noqa: F403 +from .browser import * # noqa: F403 from .executor import ( execute_tool, execute_tool_invocation, @@ -11,6 +9,12 @@ from .executor import ( remove_screenshot_from_result, validate_tool_availability, ) +from .file_edit import * # noqa: F403 +from .finish import * # noqa: F403 +from .load_skill import * # noqa: F403 +from .notes import * # noqa: F403 +from .proxy import * # noqa: F403 +from .python import * # noqa: F403 from .registry import ( ImplementedInClientSideOnlyError, get_tool_by_name, @@ -20,53 +24,13 @@ from .registry import ( register_tool, tools, ) +from .reporting import * # noqa: F403 +from .terminal import * # noqa: F403 +from .thinking import * # noqa: F403 +from .todo import * # noqa: F403 +from .web_search import * # noqa: F403 -SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" - - -def _is_browser_disabled() -> bool: - if os.getenv("STRIX_DISABLE_BROWSER", "").lower() == "true": - return True - val: str = Config.load().get("env", {}).get("STRIX_DISABLE_BROWSER", "") - return str(val).lower() == "true" - - -DISABLE_BROWSER = _is_browser_disabled() - - -def _has_perplexity_api() -> bool: - if os.getenv("PERPLEXITY_API_KEY"): - return True - return bool(Config.load().get("env", {}).get("PERPLEXITY_API_KEY")) - - -if not SANDBOX_MODE: - from .agents_graph import * # noqa: F403 - - if not DISABLE_BROWSER: - from .browser import * # noqa: F403 - from .file_edit import * # noqa: F403 - from .finish import * # noqa: F403 - from .load_skill import * # noqa: F403 - from .notes import * # noqa: F403 - from .proxy import * # noqa: F403 - from .python import * # noqa: F403 - from .reporting import * # noqa: F403 - from .terminal import * # noqa: F403 - from .thinking import * # noqa: F403 - from .todo import * # noqa: F403 - - if _has_perplexity_api(): - from .web_search import * # noqa: F403 -else: - if not DISABLE_BROWSER: - from .browser import * # noqa: F403 - from .file_edit import * # noqa: F403 - from .proxy import * # noqa: F403 - from .python import * # noqa: F403 - from .terminal import * # noqa: F403 - __all__ = [ "ImplementedInClientSideOnlyError", "execute_tool", diff --git a/strix/tools/browser/browser_actions.py b/strix/tools/browser/browser_actions.py index 5726df0..2a3c416 100644 --- a/strix/tools/browser/browser_actions.py +++ b/strix/tools/browser/browser_actions.py @@ -180,7 +180,7 @@ def _handle_utility_actions( raise ValueError(f"Unknown utility action: {action}") -@register_tool +@register_tool(requires_browser_mode=True) def browser_action( action: BrowserAction, url: str | None = None, diff --git a/strix/tools/registry.py b/strix/tools/registry.py index 7313bc3..329db13 100644 --- a/strix/tools/registry.py +++ b/strix/tools/registry.py @@ -149,10 +149,60 @@ def _get_schema_path(func: Callable[..., Any]) -> Path | None: return get_strix_resource_path("tools", folder, schema_file) +def _is_sandbox_mode() -> bool: + return os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" + + +def _is_browser_disabled() -> bool: + if os.getenv("STRIX_DISABLE_BROWSER", "").lower() == "true": + return True + + from strix.config import Config + + val: str = Config.load().get("env", {}).get("STRIX_DISABLE_BROWSER", "") + return str(val).lower() == "true" + + +def _has_perplexity_api() -> bool: + if os.getenv("PERPLEXITY_API_KEY"): + return True + + from strix.config import Config + + return bool(Config.load().get("env", {}).get("PERPLEXITY_API_KEY")) + + +def _should_register_tool( + *, + sandbox_execution: bool, + requires_browser_mode: bool, + requires_web_search_mode: bool, +) -> bool: + sandbox_mode = _is_sandbox_mode() + + if sandbox_mode and not sandbox_execution: + return False + if requires_browser_mode and _is_browser_disabled(): + return False + return not (requires_web_search_mode and not _has_perplexity_api()) + + def register_tool( - func: Callable[..., Any] | None = None, *, sandbox_execution: bool = True + func: Callable[..., Any] | None = None, + *, + sandbox_execution: bool = True, + requires_browser_mode: bool = False, + requires_web_search_mode: bool = False, ) -> Callable[..., Any]: def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + sandbox_mode = _is_sandbox_mode() + if not _should_register_tool( + sandbox_execution=sandbox_execution, + requires_browser_mode=requires_browser_mode, + requires_web_search_mode=requires_web_search_mode, + ): + return f + func_dict = { "name": f.__name__, "function": f, @@ -160,7 +210,6 @@ def register_tool( "sandbox_execution": sandbox_execution, } - sandbox_mode = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" if not sandbox_mode: try: schema_path = _get_schema_path(f) diff --git a/strix/tools/web_search/web_search_actions.py b/strix/tools/web_search/web_search_actions.py index f2b6fcf..e88eba7 100644 --- a/strix/tools/web_search/web_search_actions.py +++ b/strix/tools/web_search/web_search_actions.py @@ -31,7 +31,7 @@ Structure your response to be comprehensive yet concise, emphasizing the most cr security implications and details.""" -@register_tool(sandbox_execution=False) +@register_tool(sandbox_execution=False, requires_web_search_mode=True) def web_search(query: str) -> dict[str, Any]: try: api_key = os.getenv("PERPLEXITY_API_KEY") diff --git a/tests/tools/test_tool_registration_modes.py b/tests/tools/test_tool_registration_modes.py new file mode 100644 index 0000000..1336a48 --- /dev/null +++ b/tests/tools/test_tool_registration_modes.py @@ -0,0 +1,95 @@ +import importlib +import sys +from types import ModuleType +from typing import Any + +from strix.config import Config +from strix.tools.registry import clear_registry + + +def _empty_config_load(_cls: type[Config]) -> dict[str, dict[str, str]]: + return {"env": {}} + + +def _reload_tools_module() -> ModuleType: + clear_registry() + + for name in list(sys.modules): + if name == "strix.tools" or name.startswith("strix.tools."): + sys.modules.pop(name, None) + + return importlib.import_module("strix.tools") + + +def test_non_sandbox_registers_agents_graph_but_not_browser_or_web_search_when_disabled( + monkeypatch: Any, +) -> None: + monkeypatch.setenv("STRIX_SANDBOX_MODE", "false") + monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true") + monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) + monkeypatch.setattr(Config, "load", classmethod(_empty_config_load)) + + tools = _reload_tools_module() + names = set(tools.get_tool_names()) + + assert "create_agent" in names + assert "browser_action" not in names + assert "web_search" not in names + + +def test_sandbox_registers_sandbox_tools_but_not_non_sandbox_tools( + monkeypatch: Any, +) -> None: + monkeypatch.setenv("STRIX_SANDBOX_MODE", "true") + monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true") + monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) + monkeypatch.setattr(Config, "load", classmethod(_empty_config_load)) + + tools = _reload_tools_module() + names = set(tools.get_tool_names()) + + assert "terminal_execute" in names + assert "python_action" in names + assert "list_requests" in names + assert "create_agent" not in names + assert "finish_scan" not in names + assert "load_skill" not in names + assert "browser_action" not in names + assert "web_search" not in names + + +def test_load_skill_does_not_register_agents_graph_when_imported_directly( + monkeypatch: Any, +) -> None: + monkeypatch.setenv("STRIX_SANDBOX_MODE", "true") + monkeypatch.setenv("STRIX_DISABLE_BROWSER", "true") + monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) + monkeypatch.setattr(Config, "load", classmethod(_empty_config_load)) + + clear_registry() + for name in list(sys.modules): + if name == "strix.tools" or name.startswith("strix.tools."): + sys.modules.pop(name, None) + + load_skill_module = importlib.import_module("strix.tools.load_skill.load_skill_actions") + registry = importlib.import_module("strix.tools.registry") + + names_before = set(registry.get_tool_names()) + assert "python_action" in names_before + assert "load_skill" not in names_before + assert "create_agent" not in names_before + + state_type = type( + "DummyState", + (), + { + "agent_id": "agent_test", + "context": {}, + "update_context": lambda self, key, value: self.context.__setitem__(key, value), + }, + ) + result = load_skill_module.load_skill(state_type(), "nmap") + + names_after = set(registry.get_tool_names()) + assert "create_agent" not in names_after + assert result["success"] is False