feat: add centralized Config class with auto-save to ~/.strix/cli-config.json

- Add Config class with all env var defaults in one place
- Auto-load saved config on startup (env vars take precedence)
- Auto-save config after successful LLM warm-up
- Replace scattered os.getenv() calls with Config.get()

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
0xallam
2026-01-09 21:24:08 -08:00
committed by Ahmed Allam
parent 52aa763d47
commit 83efe3816f
13 changed files with 184 additions and 45 deletions

12
strix/config/__init__.py Normal file
View File

@@ -0,0 +1,12 @@
from strix.config.config import (
Config,
apply_saved_config,
save_current_config,
)
__all__ = [
"Config",
"apply_saved_config",
"save_current_config",
]

119
strix/config/config.py Normal file
View File

@@ -0,0 +1,119 @@
import json
import os
from pathlib import Path
from typing import Any
class Config:
"""Configuration Manager for Strix."""
# LLM Configuration
strix_llm = "openai/gpt-5"
llm_api_key = None
llm_api_base = None
openai_api_base = None
litellm_base_url = None
ollama_api_base = None
strix_reasoning_effort = "high"
llm_timeout = "300"
llm_rate_limit_delay = "4.0"
llm_rate_limit_concurrent = "1"
# Tool & Feature Configuration
perplexity_api_key = None
strix_disable_browser = "false"
# Runtime Configuration
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
strix_runtime_backend = "docker"
strix_sandbox_execution_timeout = "500"
strix_sandbox_connect_timeout = "10"
# Telemetry
strix_telemetry = "1"
@classmethod
def _tracked_names(cls) -> list[str]:
return [
k
for k, v in vars(cls).items()
if not k.startswith("_") and k[0].islower() and (v is None or isinstance(v, str))
]
@classmethod
def tracked_vars(cls) -> list[str]:
return [name.upper() for name in cls._tracked_names()]
@classmethod
def get(cls, name: str) -> str | None:
env_name = name.upper()
default = getattr(cls, name, None)
return os.getenv(env_name, default)
@classmethod
def config_dir(cls) -> Path:
return Path.home() / ".strix"
@classmethod
def config_file(cls) -> Path:
return cls.config_dir() / "cli-config.json"
@classmethod
def load(cls) -> dict[str, Any]:
path = cls.config_file()
if not path.exists():
return {}
try:
with path.open("r", encoding="utf-8") as f:
data: dict[str, Any] = json.load(f)
return data
except (json.JSONDecodeError, OSError):
return {}
@classmethod
def save(cls, config: dict[str, Any]) -> bool:
try:
cls.config_dir().mkdir(parents=True, exist_ok=True)
with cls.config_file().open("w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
except OSError:
return False
else:
return True
@classmethod
def apply_saved(cls) -> dict[str, str]:
saved = cls.load()
env_vars = saved.get("env", {})
applied = {}
for var_name, var_value in env_vars.items():
if var_name in cls.tracked_vars() and not os.getenv(var_name):
os.environ[var_name] = var_value
applied[var_name] = var_value
return applied
@classmethod
def capture_current(cls) -> dict[str, Any]:
env_vars = {}
for var_name in cls.tracked_vars():
value = os.getenv(var_name)
if value:
env_vars[var_name] = value
return {"env": env_vars}
@classmethod
def save_current(cls) -> bool:
existing = cls.load().get("env", {})
current = cls.capture_current().get("env", {})
merged = {**existing, **current}
return cls.save({"env": merged})
def apply_saved_config() -> dict[str, str]:
return Config.apply_saved()
def save_current_config() -> bool:
return Config.save_current()

View File

@@ -18,6 +18,7 @@ from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.text import Text from rich.text import Text
from strix.config import Config, apply_saved_config, save_current_config
from strix.interface.cli import run_cli from strix.interface.cli import run_cli
from strix.interface.tui import run_tui from strix.interface.tui import run_tui
from strix.interface.utils import ( from strix.interface.utils import (
@@ -198,13 +199,13 @@ async def warm_up_llm() -> None:
console = Console() console = Console()
try: try:
model_name = os.getenv("STRIX_LLM", "openai/gpt-5") model_name = Config.get("strix_llm")
api_key = os.getenv("LLM_API_KEY") api_key = Config.get("llm_api_key")
api_base = ( api_base = (
os.getenv("LLM_API_BASE") Config.get("llm_api_base")
or os.getenv("OPENAI_API_BASE") or Config.get("openai_api_base")
or os.getenv("LITELLM_BASE_URL") or Config.get("litellm_base_url")
or os.getenv("OLLAMA_API_BASE") or Config.get("ollama_api_base")
) )
test_messages = [ test_messages = [
@@ -212,7 +213,7 @@ async def warm_up_llm() -> None:
{"role": "user", "content": "Reply with just 'OK'."}, {"role": "user", "content": "Reply with just 'OK'."},
] ]
llm_timeout = int(os.getenv("LLM_TIMEOUT", "600")) llm_timeout = int(Config.get("llm_timeout")) # type: ignore[arg-type]
completion_kwargs: dict[str, Any] = { completion_kwargs: dict[str, Any] = {
"model": model_name, "model": model_name,
@@ -512,6 +513,8 @@ def main() -> None:
if sys.platform == "win32": if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
apply_saved_config()
args = parse_arguments() args = parse_arguments()
check_docker_installed() check_docker_installed()
@@ -520,6 +523,8 @@ def main() -> None:
validate_environment() validate_environment()
asyncio.run(warm_up_llm()) asyncio.run(warm_up_llm())
save_current_config()
args.run_name = generate_run_name(args.targets_info) args.run_name = generate_run_name(args.targets_info)
for target_info in args.targets_info: for target_info in args.targets_info:

View File

@@ -1,4 +1,4 @@
import os from strix.config import Config
class LLMConfig: class LLMConfig:
@@ -10,7 +10,7 @@ class LLMConfig:
timeout: int | None = None, timeout: int | None = None,
scan_mode: str = "deep", scan_mode: str = "deep",
): ):
self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5") self.model_name = model_name or Config.get("strix_llm")
if not self.model_name: if not self.model_name:
raise ValueError("STRIX_LLM environment variable must be set and not empty") raise ValueError("STRIX_LLM environment variable must be set and not empty")
@@ -18,6 +18,6 @@ class LLMConfig:
self.enable_prompt_caching = enable_prompt_caching self.enable_prompt_caching = enable_prompt_caching
self.skills = skills or [] self.skills = skills or []
self.timeout = timeout or int(os.getenv("LLM_TIMEOUT", "300")) self.timeout = timeout or int(Config.get("llm_timeout")) # type: ignore[arg-type]
self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep" self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep"

View File

@@ -1,11 +1,12 @@
import json import json
import logging import logging
import os
import re import re
from typing import Any from typing import Any
import litellm import litellm
from strix.config import Config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -154,13 +155,13 @@ def check_duplicate(
comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned} comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
model_name = os.getenv("STRIX_LLM", "openai/gpt-5") model_name = Config.get("strix_llm")
api_key = os.getenv("LLM_API_KEY") api_key = Config.get("llm_api_key")
api_base = ( api_base = (
os.getenv("LLM_API_BASE") Config.get("llm_api_base")
or os.getenv("OPENAI_API_BASE") or Config.get("openai_api_base")
or os.getenv("LITELLM_BASE_URL") or Config.get("litellm_base_url")
or os.getenv("OLLAMA_API_BASE") or Config.get("ollama_api_base")
) )
messages = [ messages = [

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import logging import logging
import os
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
@@ -16,6 +15,7 @@ from jinja2 import (
from litellm import completion_cost, stream_chunk_builder, supports_reasoning from litellm import completion_cost, stream_chunk_builder, supports_reasoning
from litellm.utils import supports_prompt_caching, supports_vision from litellm.utils import supports_prompt_caching, supports_vision
from strix.config import Config
from strix.llm.config import LLMConfig from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue from strix.llm.request_queue import get_global_queue
@@ -46,16 +46,14 @@ logger = logging.getLogger(__name__)
litellm.drop_params = True litellm.drop_params = True
litellm.modify_params = True litellm.modify_params = True
_LLM_API_KEY = os.getenv("LLM_API_KEY") _LLM_API_KEY = Config.get("llm_api_key")
_LLM_API_BASE = ( _LLM_API_BASE = (
os.getenv("LLM_API_BASE") Config.get("llm_api_base")
or os.getenv("OPENAI_API_BASE") or Config.get("openai_api_base")
or os.getenv("LITELLM_BASE_URL") or Config.get("litellm_base_url")
or os.getenv("OLLAMA_API_BASE") or Config.get("ollama_api_base")
) )
_STRIX_REASONING_EFFORT = os.getenv( _STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
"STRIX_REASONING_EFFORT"
) # "none", "minimal", "low", "medium", "high", or "xhigh"
class LLMRequestFailedError(Exception): class LLMRequestFailedError(Exception):

View File

@@ -1,9 +1,10 @@
import logging import logging
import os
from typing import Any from typing import Any
import litellm import litellm
from strix.config import Config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -150,7 +151,7 @@ class MemoryCompressor:
timeout: int = 600, timeout: int = 600,
): ):
self.max_images = max_images self.max_images = max_images
self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5") self.model_name = model_name or Config.get("strix_llm")
self.timeout = timeout self.timeout = timeout
if not self.model_name: if not self.model_name:

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import os
import threading import threading
import time import time
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@@ -8,16 +7,13 @@ from typing import Any
from litellm import acompletion from litellm import acompletion
from litellm.types.utils import ModelResponseStream from litellm.types.utils import ModelResponseStream
from strix.config import Config
class LLMRequestQueue: class LLMRequestQueue:
def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0): def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0):
rate_limit_delay = os.getenv("LLM_RATE_LIMIT_DELAY") delay_between_requests = float(Config.get("llm_rate_limit_delay")) # type: ignore[arg-type]
if rate_limit_delay: max_concurrent = int(Config.get("llm_rate_limit_concurrent")) # type: ignore[arg-type]
delay_between_requests = float(rate_limit_delay)
rate_limit_concurrent = os.getenv("LLM_RATE_LIMIT_CONCURRENT")
if rate_limit_concurrent:
max_concurrent = int(rate_limit_concurrent)
self.max_concurrent = max_concurrent self.max_concurrent = max_concurrent
self.delay_between_requests = delay_between_requests self.delay_between_requests = delay_between_requests

View File

@@ -1,4 +1,4 @@
import os from strix.config import Config
from .runtime import AbstractRuntime from .runtime import AbstractRuntime
@@ -13,7 +13,7 @@ class SandboxInitializationError(Exception):
def get_runtime() -> AbstractRuntime: def get_runtime() -> AbstractRuntime:
runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker") runtime_backend = Config.get("strix_runtime_backend")
if runtime_backend == "docker": if runtime_backend == "docker":
from .docker_runtime import DockerRuntime from .docker_runtime import DockerRuntime

View File

@@ -15,11 +15,13 @@ from docker.models.containers import Container
from requests.exceptions import ConnectionError as RequestsConnectionError from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import Timeout as RequestsTimeout from requests.exceptions import Timeout as RequestsTimeout
from strix.config import Config
from . import SandboxInitializationError from . import SandboxInitializationError
from .runtime import AbstractRuntime, SandboxInfo from .runtime import AbstractRuntime, SandboxInfo
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10") STRIX_IMAGE: str = Config.get("strix_image") # type: ignore[assignment]
HOST_GATEWAY_HOSTNAME = "host.docker.internal" HOST_GATEWAY_HOSTNAME = "host.docker.internal"
DOCKER_TIMEOUT = 60 # seconds DOCKER_TIMEOUT = 60 # seconds
TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request

View File

@@ -1,5 +1,4 @@
import json import json
import os
import platform import platform
import sys import sys
import urllib.request import urllib.request
@@ -7,6 +6,8 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from uuid import uuid4 from uuid import uuid4
from strix.config import Config
if TYPE_CHECKING: if TYPE_CHECKING:
from strix.telemetry.tracer import Tracer from strix.telemetry.tracer import Tracer
@@ -18,7 +19,7 @@ _SESSION_ID = uuid4().hex[:16]
def _is_enabled() -> bool: def _is_enabled() -> bool:
return os.getenv("STRIX_TELEMETRY", "1").lower() not in ("0", "false", "no", "off") return (Config.get("strix_telemetry") or "1").lower() not in ("0", "false", "no", "off")
def _is_first_run() -> bool: def _is_first_run() -> bool:

View File

@@ -1,5 +1,7 @@
import os import os
from strix.config import Config
from .executor import ( from .executor import (
execute_tool, execute_tool,
execute_tool_invocation, execute_tool_invocation,
@@ -22,9 +24,9 @@ from .registry import (
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true" SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
HAS_PERPLEXITY_API = bool(os.getenv("PERPLEXITY_API_KEY")) HAS_PERPLEXITY_API = bool(Config.get("perplexity_api_key"))
DISABLE_BROWSER = os.getenv("STRIX_DISABLE_BROWSER", "false").lower() == "true" DISABLE_BROWSER = (Config.get("strix_disable_browser") or "false").lower() == "true"
if not SANDBOX_MODE: if not SANDBOX_MODE:
from .agents_graph import * # noqa: F403 from .agents_graph import * # noqa: F403

View File

@@ -4,6 +4,8 @@ from typing import Any
import httpx import httpx
from strix.config import Config
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false": if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
from strix.runtime import get_runtime from strix.runtime import get_runtime
@@ -17,8 +19,8 @@ from .registry import (
) )
SANDBOX_EXECUTION_TIMEOUT = float(os.getenv("STRIX_SANDBOX_EXECUTION_TIMEOUT", "500")) SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout")) # type: ignore[arg-type]
SANDBOX_CONNECT_TIMEOUT = float(os.getenv("STRIX_SANDBOX_CONNECT_TIMEOUT", "10")) SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout")) # type: ignore[arg-type]
async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any: async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any: