feat: add centralized Config class with auto-save to ~/.strix/cli-config.json

- Add Config class with all env var defaults in one place
- Auto-load saved config on startup (env vars take precedence)
- Auto-save config after successful LLM warm-up
- Replace scattered os.getenv() calls with Config.get()
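
Example of the intended call pattern (a rough sketch; values illustrative):

    from strix.config import Config, apply_saved_config, save_current_config

    apply_saved_config()             # seed os.environ from ~/.strix/cli-config.json
    model = Config.get("strix_llm")  # STRIX_LLM if set, else the class default
    save_current_config()            # persist currently-set tracked env vars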

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: 0xallam (2026-01-09 21:24:08 -08:00), committed by Ahmed Allam
Parent: 52aa763d47
Commit: 83efe3816f
13 changed files with 184 additions and 45 deletions

strix/config/__init__.py (new file, +12 lines)

@@ -0,0 +1,12 @@
+from strix.config.config import (
+    Config,
+    apply_saved_config,
+    save_current_config,
+)
+
+__all__ = [
+    "Config",
+    "apply_saved_config",
+    "save_current_config",
+]

strix/config/config.py (new file, +119 lines)

@@ -0,0 +1,119 @@
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+
+class Config:
+    """Configuration manager for Strix."""
+
+    # LLM Configuration
+    strix_llm = "openai/gpt-5"
+    llm_api_key = None
+    llm_api_base = None
+    openai_api_base = None
+    litellm_base_url = None
+    ollama_api_base = None
+    strix_reasoning_effort = "high"
+    llm_timeout = "300"
+    llm_rate_limit_delay = "4.0"
+    llm_rate_limit_concurrent = "1"
+
+    # Tool & Feature Configuration
+    perplexity_api_key = None
+    strix_disable_browser = "false"
+
+    # Runtime Configuration
+    strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
+    strix_runtime_backend = "docker"
+    strix_sandbox_execution_timeout = "500"
+    strix_sandbox_connect_timeout = "10"
+
+    # Telemetry
+    strix_telemetry = "1"
+
+    @classmethod
+    def _tracked_names(cls) -> list[str]:
+        return [
+            k
+            for k, v in vars(cls).items()
+            if not k.startswith("_") and k[0].islower() and (v is None or isinstance(v, str))
+        ]
+
+    @classmethod
+    def tracked_vars(cls) -> list[str]:
+        return [name.upper() for name in cls._tracked_names()]
+
+    @classmethod
+    def get(cls, name: str) -> str | None:
+        env_name = name.upper()
+        default = getattr(cls, name, None)
+        return os.getenv(env_name, default)
+
+    @classmethod
+    def config_dir(cls) -> Path:
+        return Path.home() / ".strix"
+
+    @classmethod
+    def config_file(cls) -> Path:
+        return cls.config_dir() / "cli-config.json"
+
+    @classmethod
+    def load(cls) -> dict[str, Any]:
+        path = cls.config_file()
+        if not path.exists():
+            return {}
+        try:
+            with path.open("r", encoding="utf-8") as f:
+                data: dict[str, Any] = json.load(f)
+            return data
+        except (json.JSONDecodeError, OSError):
+            return {}
+
+    @classmethod
+    def save(cls, config: dict[str, Any]) -> bool:
+        try:
+            cls.config_dir().mkdir(parents=True, exist_ok=True)
+            with cls.config_file().open("w", encoding="utf-8") as f:
+                json.dump(config, f, indent=2)
+        except OSError:
+            return False
+        else:
+            return True
+
+    @classmethod
+    def apply_saved(cls) -> dict[str, str]:
+        saved = cls.load()
+        env_vars = saved.get("env", {})
+        applied = {}
+        for var_name, var_value in env_vars.items():
+            if var_name in cls.tracked_vars() and not os.getenv(var_name):
+                os.environ[var_name] = var_value
+                applied[var_name] = var_value
+        return applied
+
+    @classmethod
+    def capture_current(cls) -> dict[str, Any]:
+        env_vars = {}
+        for var_name in cls.tracked_vars():
+            value = os.getenv(var_name)
+            if value:
+                env_vars[var_name] = value
+        return {"env": env_vars}
+
+    @classmethod
+    def save_current(cls) -> bool:
+        existing = cls.load().get("env", {})
+        current = cls.capture_current().get("env", {})
+        merged = {**existing, **current}
+        return cls.save({"env": merged})
+
+
+def apply_saved_config() -> dict[str, str]:
+    return Config.apply_saved()
+
+
+def save_current_config() -> bool:
+    return Config.save_current()
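
How the precedence plays out (an illustrative sketch, not part of the diff;
assumes no other Strix variables are set in the environment):

    import os
    from strix.config import Config

    os.environ["LLM_TIMEOUT"] = "120"
    Config.get("llm_timeout")   # "120": a set env var always wins
    Config.get("strix_llm")     # "openai/gpt-5": falls back to the class default
    Config.get("llm_api_key")   # None: no env var set, and the default is None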

(CLI entrypoint module)

@@ -18,6 +18,7 @@ from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text

+from strix.config import Config, apply_saved_config, save_current_config
 from strix.interface.cli import run_cli
 from strix.interface.tui import run_tui
 from strix.interface.utils import (
@@ -198,13 +199,13 @@ async def warm_up_llm() -> None:
     console = Console()

     try:
-        model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
-        api_key = os.getenv("LLM_API_KEY")
+        model_name = Config.get("strix_llm")
+        api_key = Config.get("llm_api_key")
         api_base = (
-            os.getenv("LLM_API_BASE")
-            or os.getenv("OPENAI_API_BASE")
-            or os.getenv("LITELLM_BASE_URL")
-            or os.getenv("OLLAMA_API_BASE")
+            Config.get("llm_api_base")
+            or Config.get("openai_api_base")
+            or Config.get("litellm_base_url")
+            or Config.get("ollama_api_base")
         )

         test_messages = [
@@ -212,7 +213,7 @@
             {"role": "user", "content": "Reply with just 'OK'."},
         ]

-        llm_timeout = int(os.getenv("LLM_TIMEOUT", "600"))
+        llm_timeout = int(Config.get("llm_timeout"))  # type: ignore[arg-type]

         completion_kwargs: dict[str, Any] = {
             "model": model_name,
@@ -512,6 +513,8 @@ def main() -> None:
     if sys.platform == "win32":
         asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

+    apply_saved_config()
+
     args = parse_arguments()

     check_docker_installed()
@@ -520,6 +523,8 @@
     validate_environment()
     asyncio.run(warm_up_llm())

+    save_current_config()
+
     args.run_name = generate_run_name(args.targets_info)

     for target_info in args.targets_info:
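
Sequencing note (my reading of the two hunks above, not stated in the diff):
apply_saved_config() runs before parse_arguments(), so saved variables are
visible to everything downstream, and save_current_config() runs only after a
successful warm-up, so only a configuration that just produced a working LLM
call is persisted. The saved file holds a single top-level "env" object, e.g.
with illustrative values:

    {
      "env": {
        "STRIX_LLM": "openai/gpt-5",
        "LLM_TIMEOUT": "300"
      }
    }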

strix/llm/config.py

@@ -1,4 +1,4 @@
-import os
+from strix.config import Config


 class LLMConfig:
@@ -10,7 +10,7 @@ class LLMConfig:
         timeout: int | None = None,
         scan_mode: str = "deep",
     ):
-        self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5")
+        self.model_name = model_name or Config.get("strix_llm")

         if not self.model_name:
             raise ValueError("STRIX_LLM environment variable must be set and not empty")
@@ -18,6 +18,6 @@ class LLMConfig:
         self.enable_prompt_caching = enable_prompt_caching
         self.skills = skills or []
-        self.timeout = timeout or int(os.getenv("LLM_TIMEOUT", "300"))
+        self.timeout = timeout or int(Config.get("llm_timeout"))  # type: ignore[arg-type]

         self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep"
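
Why the bare int(...) casts are safe (my reading, not stated in the diff):
Config.get is typed str | None, but it can only return None for keys whose
class default is None; numerically parsed settings such as llm_timeout carry
string defaults, so a value is always present:

    int(Config.get("llm_timeout"))  # type: ignore[arg-type]  # "300" by default, never None
    Config.get("llm_api_key")       # may be None: its default is None, so callers must check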

(report deduplication module)

@@ -1,11 +1,12 @@
 import json
 import logging
-import os
 import re
 from typing import Any

 import litellm

+from strix.config import Config
+
 logger = logging.getLogger(__name__)
@@ -154,13 +155,13 @@ def check_duplicate(
     comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}

-    model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
-    api_key = os.getenv("LLM_API_KEY")
+    model_name = Config.get("strix_llm")
+    api_key = Config.get("llm_api_key")
     api_base = (
-        os.getenv("LLM_API_BASE")
-        or os.getenv("OPENAI_API_BASE")
-        or os.getenv("LITELLM_BASE_URL")
-        or os.getenv("OLLAMA_API_BASE")
+        Config.get("llm_api_base")
+        or Config.get("openai_api_base")
+        or Config.get("litellm_base_url")
+        or Config.get("ollama_api_base")
     )

     messages = [

(core LLM module)

@@ -1,6 +1,5 @@
 import asyncio
 import logging
-import os
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
 from enum import Enum
@@ -16,6 +15,7 @@ from jinja2 import (
 from litellm import completion_cost, stream_chunk_builder, supports_reasoning
 from litellm.utils import supports_prompt_caching, supports_vision

+from strix.config import Config
 from strix.llm.config import LLMConfig
 from strix.llm.memory_compressor import MemoryCompressor
 from strix.llm.request_queue import get_global_queue
@@ -46,16 +46,14 @@ logger = logging.getLogger(__name__)
 litellm.drop_params = True
 litellm.modify_params = True

-_LLM_API_KEY = os.getenv("LLM_API_KEY")
+_LLM_API_KEY = Config.get("llm_api_key")
 _LLM_API_BASE = (
-    os.getenv("LLM_API_BASE")
-    or os.getenv("OPENAI_API_BASE")
-    or os.getenv("LITELLM_BASE_URL")
-    or os.getenv("OLLAMA_API_BASE")
+    Config.get("llm_api_base")
+    or Config.get("openai_api_base")
+    or Config.get("litellm_base_url")
+    or Config.get("ollama_api_base")
 )
-_STRIX_REASONING_EFFORT = os.getenv(
-    "STRIX_REASONING_EFFORT"
-)  # "none", "minimal", "low", "medium", "high", or "xhigh"
+_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")


 class LLMRequestFailedError(Exception):

strix/llm/memory_compressor.py

@@ -1,9 +1,10 @@
 import logging
-import os
 from typing import Any

 import litellm

+from strix.config import Config
+
 logger = logging.getLogger(__name__)
@@ -150,7 +151,7 @@ class MemoryCompressor:
         timeout: int = 600,
     ):
         self.max_images = max_images
-        self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5")
+        self.model_name = model_name or Config.get("strix_llm")
         self.timeout = timeout

         if not self.model_name:

strix/llm/request_queue.py

@@ -1,5 +1,4 @@
 import asyncio
-import os
 import threading
 import time
 from collections.abc import AsyncIterator
@@ -8,16 +7,13 @@ from typing import Any
 from litellm import acompletion
 from litellm.types.utils import ModelResponseStream

+from strix.config import Config
+

 class LLMRequestQueue:
     def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0):
-        rate_limit_delay = os.getenv("LLM_RATE_LIMIT_DELAY")
-        if rate_limit_delay:
-            delay_between_requests = float(rate_limit_delay)
-
-        rate_limit_concurrent = os.getenv("LLM_RATE_LIMIT_CONCURRENT")
-        if rate_limit_concurrent:
-            max_concurrent = int(rate_limit_concurrent)
+        delay_between_requests = float(Config.get("llm_rate_limit_delay"))  # type: ignore[arg-type]
+        max_concurrent = int(Config.get("llm_rate_limit_concurrent"))  # type: ignore[arg-type]

         self.max_concurrent = max_concurrent
         self.delay_between_requests = delay_between_requests
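
One behavioral nuance here (my observation, not from the commit message):
previously the constructor arguments served as fallbacks when the env vars were
unset; now Config.get always returns at least the class defaults ("4.0" and
"1"), so the arguments are always overridden:

    queue = LLMRequestQueue(max_concurrent=8)
    queue.max_concurrent  # 1, unless LLM_RATE_LIMIT_CONCURRENT (or a saved config) says otherwise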

strix/runtime/__init__.py

@@ -1,4 +1,4 @@
-import os
+from strix.config import Config

 from .runtime import AbstractRuntime
@@ -13,7 +13,7 @@ class SandboxInitializationError(Exception):
 def get_runtime() -> AbstractRuntime:
-    runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker")
+    runtime_backend = Config.get("strix_runtime_backend")

     if runtime_backend == "docker":
         from .docker_runtime import DockerRuntime

strix/runtime/docker_runtime.py

@@ -15,11 +15,13 @@ from docker.models.containers import Container
 from requests.exceptions import ConnectionError as RequestsConnectionError
 from requests.exceptions import Timeout as RequestsTimeout

+from strix.config import Config
+
 from . import SandboxInitializationError
 from .runtime import AbstractRuntime, SandboxInfo

-STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
+STRIX_IMAGE: str = Config.get("strix_image")  # type: ignore[assignment]
 HOST_GATEWAY_HOSTNAME = "host.docker.internal"
 DOCKER_TIMEOUT = 60  # seconds
 TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5  # seconds per health check request

(telemetry module)

@@ -1,5 +1,4 @@
 import json
-import os
 import platform
 import sys
 import urllib.request
@@ -7,6 +6,8 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any
 from uuid import uuid4

+from strix.config import Config
+
 if TYPE_CHECKING:
     from strix.telemetry.tracer import Tracer
@@ -18,7 +19,7 @@ _SESSION_ID = uuid4().hex[:16]
 def _is_enabled() -> bool:
-    return os.getenv("STRIX_TELEMETRY", "1").lower() not in ("0", "false", "no", "off")
+    return (Config.get("strix_telemetry") or "1").lower() not in ("0", "false", "no", "off")


 def _is_first_run() -> bool:

strix/tools/__init__.py

@@ -1,5 +1,7 @@
 import os

+from strix.config import Config
+
 from .executor import (
     execute_tool,
     execute_tool_invocation,
@@ -22,9 +24,9 @@ from .registry import (
 SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"

-HAS_PERPLEXITY_API = bool(os.getenv("PERPLEXITY_API_KEY"))
+HAS_PERPLEXITY_API = bool(Config.get("perplexity_api_key"))

-DISABLE_BROWSER = os.getenv("STRIX_DISABLE_BROWSER", "false").lower() == "true"
+DISABLE_BROWSER = (Config.get("strix_disable_browser") or "false").lower() == "true"

 if not SANDBOX_MODE:
     from .agents_graph import *  # noqa: F403

strix/tools/executor.py

@@ -4,6 +4,8 @@ from typing import Any
 import httpx

+from strix.config import Config
+
 if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
     from strix.runtime import get_runtime
@@ -17,8 +19,8 @@ from .registry import (
 )

-SANDBOX_EXECUTION_TIMEOUT = float(os.getenv("STRIX_SANDBOX_EXECUTION_TIMEOUT", "500"))
-SANDBOX_CONNECT_TIMEOUT = float(os.getenv("STRIX_SANDBOX_CONNECT_TIMEOUT", "10"))
+SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout"))  # type: ignore[arg-type]
+SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout"))  # type: ignore[arg-type]


 async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any: