From 83efe3816fd384fc98718ac7a32a2d567de71013 Mon Sep 17 00:00:00 2001 From: 0xallam Date: Fri, 9 Jan 2026 21:24:08 -0800 Subject: [PATCH] feat: add centralized Config class with auto-save to ~/.strix/cli-config.json - Add Config class with all env var defaults in one place - Auto-load saved config on startup (env vars take precedence) - Auto-save config after successful LLM warm-up - Replace scattered os.getenv() calls with Config.get() Co-Authored-By: Claude Opus 4.5 --- strix/config/__init__.py | 12 ++++ strix/config/config.py | 119 ++++++++++++++++++++++++++++++++ strix/interface/main.py | 19 +++-- strix/llm/config.py | 6 +- strix/llm/dedupe.py | 15 ++-- strix/llm/llm.py | 16 ++--- strix/llm/memory_compressor.py | 5 +- strix/llm/request_queue.py | 12 ++-- strix/runtime/__init__.py | 4 +- strix/runtime/docker_runtime.py | 4 +- strix/telemetry/posthog.py | 5 +- strix/tools/__init__.py | 6 +- strix/tools/executor.py | 6 +- 13 files changed, 184 insertions(+), 45 deletions(-) create mode 100644 strix/config/__init__.py create mode 100644 strix/config/config.py diff --git a/strix/config/__init__.py b/strix/config/__init__.py new file mode 100644 index 0000000..328c138 --- /dev/null +++ b/strix/config/__init__.py @@ -0,0 +1,12 @@ +from strix.config.config import ( + Config, + apply_saved_config, + save_current_config, +) + + +__all__ = [ + "Config", + "apply_saved_config", + "save_current_config", +] diff --git a/strix/config/config.py b/strix/config/config.py new file mode 100644 index 0000000..abed9d8 --- /dev/null +++ b/strix/config/config.py @@ -0,0 +1,119 @@ +import json +import os +from pathlib import Path +from typing import Any + + +class Config: + """Configuration Manager for Strix.""" + + # LLM Configuration + strix_llm = "openai/gpt-5" + llm_api_key = None + llm_api_base = None + openai_api_base = None + litellm_base_url = None + ollama_api_base = None + strix_reasoning_effort = "high" + llm_timeout = "300" + llm_rate_limit_delay = "4.0" + llm_rate_limit_concurrent = "1" + + # Tool & Feature Configuration + perplexity_api_key = None + strix_disable_browser = "false" + + # Runtime Configuration + strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10" + strix_runtime_backend = "docker" + strix_sandbox_execution_timeout = "500" + strix_sandbox_connect_timeout = "10" + + # Telemetry + strix_telemetry = "1" + + @classmethod + def _tracked_names(cls) -> list[str]: + return [ + k + for k, v in vars(cls).items() + if not k.startswith("_") and k[0].islower() and (v is None or isinstance(v, str)) + ] + + @classmethod + def tracked_vars(cls) -> list[str]: + return [name.upper() for name in cls._tracked_names()] + + @classmethod + def get(cls, name: str) -> str | None: + env_name = name.upper() + default = getattr(cls, name, None) + return os.getenv(env_name, default) + + @classmethod + def config_dir(cls) -> Path: + return Path.home() / ".strix" + + @classmethod + def config_file(cls) -> Path: + return cls.config_dir() / "cli-config.json" + + @classmethod + def load(cls) -> dict[str, Any]: + path = cls.config_file() + if not path.exists(): + return {} + try: + with path.open("r", encoding="utf-8") as f: + data: dict[str, Any] = json.load(f) + return data + except (json.JSONDecodeError, OSError): + return {} + + @classmethod + def save(cls, config: dict[str, Any]) -> bool: + try: + cls.config_dir().mkdir(parents=True, exist_ok=True) + with cls.config_file().open("w", encoding="utf-8") as f: + json.dump(config, f, indent=2) + except OSError: + return False + else: + return True + + @classmethod + def apply_saved(cls) -> dict[str, str]: + saved = cls.load() + env_vars = saved.get("env", {}) + applied = {} + + for var_name, var_value in env_vars.items(): + if var_name in cls.tracked_vars() and not os.getenv(var_name): + os.environ[var_name] = var_value + applied[var_name] = var_value + + return applied + + @classmethod + def capture_current(cls) -> dict[str, Any]: + env_vars = {} + for var_name in cls.tracked_vars(): + value = os.getenv(var_name) + if value: + env_vars[var_name] = value + return {"env": env_vars} + + @classmethod + def save_current(cls) -> bool: + existing = cls.load().get("env", {}) + current = cls.capture_current().get("env", {}) + merged = {**existing, **current} + return cls.save({"env": merged}) + + +def apply_saved_config() -> dict[str, str]: + return Config.apply_saved() + + +def save_current_config() -> bool: + return Config.save_current() diff --git a/strix/interface/main.py b/strix/interface/main.py index f40c5a2..2b567b9 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -18,6 +18,7 @@ from rich.console import Console from rich.panel import Panel from rich.text import Text +from strix.config import Config, apply_saved_config, save_current_config from strix.interface.cli import run_cli from strix.interface.tui import run_tui from strix.interface.utils import ( @@ -198,13 +199,13 @@ async def warm_up_llm() -> None: console = Console() try: - model_name = os.getenv("STRIX_LLM", "openai/gpt-5") - api_key = os.getenv("LLM_API_KEY") + model_name = Config.get("strix_llm") + api_key = Config.get("llm_api_key") api_base = ( - os.getenv("LLM_API_BASE") - or os.getenv("OPENAI_API_BASE") - or os.getenv("LITELLM_BASE_URL") - or os.getenv("OLLAMA_API_BASE") + Config.get("llm_api_base") + or Config.get("openai_api_base") + or Config.get("litellm_base_url") + or Config.get("ollama_api_base") ) test_messages = [ @@ -212,7 +213,7 @@ async def warm_up_llm() -> None: {"role": "user", "content": "Reply with just 'OK'."}, ] - llm_timeout = int(os.getenv("LLM_TIMEOUT", "600")) + llm_timeout = int(Config.get("llm_timeout")) # type: ignore[arg-type] completion_kwargs: dict[str, Any] = { "model": model_name, @@ -512,6 +513,8 @@ def main() -> None: if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + apply_saved_config() + args = parse_arguments() check_docker_installed() @@ -520,6 +523,8 @@ def main() -> None: validate_environment() asyncio.run(warm_up_llm()) + save_current_config() + args.run_name = generate_run_name(args.targets_info) for target_info in args.targets_info: diff --git a/strix/llm/config.py b/strix/llm/config.py index 9984a23..609b628 100644 --- a/strix/llm/config.py +++ b/strix/llm/config.py @@ -1,4 +1,4 @@ -import os +from strix.config import Config class LLMConfig: @@ -10,7 +10,7 @@ class LLMConfig: timeout: int | None = None, scan_mode: str = "deep", ): - self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5") + self.model_name = model_name or Config.get("strix_llm") if not self.model_name: raise ValueError("STRIX_LLM environment variable must be set and not empty") @@ -18,6 +18,6 @@ class LLMConfig: self.enable_prompt_caching = enable_prompt_caching self.skills = skills or [] - self.timeout = timeout or int(os.getenv("LLM_TIMEOUT", "300")) + self.timeout = timeout or int(Config.get("llm_timeout")) # type: ignore[arg-type] self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep" diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py index 99d3505..31aa74d 100644 --- a/strix/llm/dedupe.py +++ b/strix/llm/dedupe.py @@ -1,11 +1,12 @@ import json import logging -import os import re from typing import Any import litellm +from strix.config import Config + logger = logging.getLogger(__name__) @@ -154,13 +155,13 @@ def check_duplicate( comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned} - model_name = os.getenv("STRIX_LLM", "openai/gpt-5") - api_key = os.getenv("LLM_API_KEY") + model_name = Config.get("strix_llm") + api_key = Config.get("llm_api_key") api_base = ( - os.getenv("LLM_API_BASE") - or os.getenv("OPENAI_API_BASE") - or os.getenv("LITELLM_BASE_URL") - or os.getenv("OLLAMA_API_BASE") + Config.get("llm_api_base") + or Config.get("openai_api_base") + or Config.get("litellm_base_url") + or Config.get("ollama_api_base") ) messages = [ diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 816d6f2..3bbc27b 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -1,6 +1,5 @@ import asyncio import logging -import os from collections.abc import AsyncIterator from dataclasses import dataclass from enum import Enum @@ -16,6 +15,7 @@ from jinja2 import ( from litellm import completion_cost, stream_chunk_builder, supports_reasoning from litellm.utils import supports_prompt_caching, supports_vision +from strix.config import Config from strix.llm.config import LLMConfig from strix.llm.memory_compressor import MemoryCompressor from strix.llm.request_queue import get_global_queue @@ -46,16 +46,14 @@ logger = logging.getLogger(__name__) litellm.drop_params = True litellm.modify_params = True -_LLM_API_KEY = os.getenv("LLM_API_KEY") +_LLM_API_KEY = Config.get("llm_api_key") _LLM_API_BASE = ( - os.getenv("LLM_API_BASE") - or os.getenv("OPENAI_API_BASE") - or os.getenv("LITELLM_BASE_URL") - or os.getenv("OLLAMA_API_BASE") + Config.get("llm_api_base") + or Config.get("openai_api_base") + or Config.get("litellm_base_url") + or Config.get("ollama_api_base") ) -_STRIX_REASONING_EFFORT = os.getenv( - "STRIX_REASONING_EFFORT" -) # "none", "minimal", "low", "medium", "high", or "xhigh" +_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort") class LLMRequestFailedError(Exception): diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index b5779d8..bfc6480 100644 --- a/strix/llm/memory_compressor.py +++ b/strix/llm/memory_compressor.py @@ -1,9 +1,10 @@ import logging -import os from typing import Any import litellm +from strix.config import Config + logger = logging.getLogger(__name__) @@ -150,7 +151,7 @@ class MemoryCompressor: timeout: int = 600, ): self.max_images = max_images - self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5") + self.model_name = model_name or Config.get("strix_llm") self.timeout = timeout if not self.model_name: diff --git a/strix/llm/request_queue.py b/strix/llm/request_queue.py index 0b68737..c3ddc37 100644 --- a/strix/llm/request_queue.py +++ b/strix/llm/request_queue.py @@ -1,5 +1,4 @@ import asyncio -import os import threading import time from collections.abc import AsyncIterator @@ -8,16 +7,13 @@ from typing import Any from litellm import acompletion from litellm.types.utils import ModelResponseStream +from strix.config import Config + class LLMRequestQueue: def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0): - rate_limit_delay = os.getenv("LLM_RATE_LIMIT_DELAY") - if rate_limit_delay: - delay_between_requests = float(rate_limit_delay) - - rate_limit_concurrent = os.getenv("LLM_RATE_LIMIT_CONCURRENT") - if rate_limit_concurrent: - max_concurrent = int(rate_limit_concurrent) + delay_between_requests = float(Config.get("llm_rate_limit_delay")) # type: ignore[arg-type] + max_concurrent = int(Config.get("llm_rate_limit_concurrent")) # type: ignore[arg-type] self.max_concurrent = max_concurrent self.delay_between_requests = delay_between_requests diff --git a/strix/runtime/__init__.py b/strix/runtime/__init__.py index 49b83d9..ae01f38 100644 --- a/strix/runtime/__init__.py +++ b/strix/runtime/__init__.py @@ -1,4 +1,4 @@ -import os +from strix.config import Config from .runtime import AbstractRuntime @@ -13,7 +13,7 @@ class SandboxInitializationError(Exception): def get_runtime() -> AbstractRuntime: - runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker") + runtime_backend = Config.get("strix_runtime_backend") if runtime_backend == "docker": from .docker_runtime import DockerRuntime diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index 9eac17a..d912158 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -15,11 +15,13 @@ from docker.models.containers import Container from requests.exceptions import ConnectionError as RequestsConnectionError from requests.exceptions import Timeout as RequestsTimeout +from strix.config import Config + from . import SandboxInitializationError from .runtime import AbstractRuntime, SandboxInfo -STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10") +STRIX_IMAGE: str = Config.get("strix_image") # type: ignore[assignment] HOST_GATEWAY_HOSTNAME = "host.docker.internal" DOCKER_TIMEOUT = 60 # seconds TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request diff --git a/strix/telemetry/posthog.py b/strix/telemetry/posthog.py index 8d3c355..fd66bcc 100644 --- a/strix/telemetry/posthog.py +++ b/strix/telemetry/posthog.py @@ -1,5 +1,4 @@ import json -import os import platform import sys import urllib.request @@ -7,6 +6,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any from uuid import uuid4 +from strix.config import Config + if TYPE_CHECKING: from strix.telemetry.tracer import Tracer @@ -18,7 +19,7 @@ _SESSION_ID = uuid4().hex[:16] def _is_enabled() -> bool: - return os.getenv("STRIX_TELEMETRY", "1").lower() not in ("0", "false", "no", "off") + return (Config.get("strix_telemetry") or "1").lower() not in ("0", "false", "no", "off") def _is_first_run() -> bool: diff --git a/strix/tools/__init__.py b/strix/tools/__init__.py index 4193d41..1c49472 100644 --- a/strix/tools/__init__.py +++ b/strix/tools/__init__.py @@ -1,5 +1,7 @@ import os +from strix.config import Config + from .executor import ( execute_tool, execute_tool_invocation, @@ -22,9 +24,9 @@ from .registry import ( SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" -HAS_PERPLEXITY_API = bool(os.getenv("PERPLEXITY_API_KEY")) +HAS_PERPLEXITY_API = bool(Config.get("perplexity_api_key")) -DISABLE_BROWSER = os.getenv("STRIX_DISABLE_BROWSER", "false").lower() == "true" +DISABLE_BROWSER = (Config.get("strix_disable_browser") or "false").lower() == "true" if not SANDBOX_MODE: from .agents_graph import * # noqa: F403 diff --git a/strix/tools/executor.py b/strix/tools/executor.py index 3d36bc8..1ec0375 100644 --- a/strix/tools/executor.py +++ b/strix/tools/executor.py @@ -4,6 +4,8 @@ from typing import Any import httpx +from strix.config import Config + if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false": from strix.runtime import get_runtime @@ -17,8 +19,8 @@ from .registry import ( ) -SANDBOX_EXECUTION_TIMEOUT = float(os.getenv("STRIX_SANDBOX_EXECUTION_TIMEOUT", "500")) -SANDBOX_CONNECT_TIMEOUT = float(os.getenv("STRIX_SANDBOX_CONNECT_TIMEOUT", "10")) +SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout")) # type: ignore[arg-type] +SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout")) # type: ignore[arg-type] async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any: