feat: add centralized Config class with auto-save to ~/.strix/cli-config.json
- Add Config class with all env var defaults in one place - Auto-load saved config on startup (env vars take precedence) - Auto-save config after successful LLM warm-up - Replace scattered os.getenv() calls with Config.get() Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
12
strix/config/__init__.py
Normal file
12
strix/config/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from strix.config.config import (
|
||||
Config,
|
||||
apply_saved_config,
|
||||
save_current_config,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Config",
|
||||
"apply_saved_config",
|
||||
"save_current_config",
|
||||
]
|
||||
119
strix/config/config.py
Normal file
119
strix/config/config.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Config:
|
||||
"""Configuration Manager for Strix."""
|
||||
|
||||
# LLM Configuration
|
||||
strix_llm = "openai/gpt-5"
|
||||
llm_api_key = None
|
||||
llm_api_base = None
|
||||
openai_api_base = None
|
||||
litellm_base_url = None
|
||||
ollama_api_base = None
|
||||
strix_reasoning_effort = "high"
|
||||
llm_timeout = "300"
|
||||
llm_rate_limit_delay = "4.0"
|
||||
llm_rate_limit_concurrent = "1"
|
||||
|
||||
# Tool & Feature Configuration
|
||||
perplexity_api_key = None
|
||||
strix_disable_browser = "false"
|
||||
|
||||
# Runtime Configuration
|
||||
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
|
||||
strix_runtime_backend = "docker"
|
||||
strix_sandbox_execution_timeout = "500"
|
||||
strix_sandbox_connect_timeout = "10"
|
||||
|
||||
# Telemetry
|
||||
strix_telemetry = "1"
|
||||
|
||||
@classmethod
|
||||
def _tracked_names(cls) -> list[str]:
|
||||
return [
|
||||
k
|
||||
for k, v in vars(cls).items()
|
||||
if not k.startswith("_") and k[0].islower() and (v is None or isinstance(v, str))
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def tracked_vars(cls) -> list[str]:
|
||||
return [name.upper() for name in cls._tracked_names()]
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> str | None:
|
||||
env_name = name.upper()
|
||||
default = getattr(cls, name, None)
|
||||
return os.getenv(env_name, default)
|
||||
|
||||
@classmethod
|
||||
def config_dir(cls) -> Path:
|
||||
return Path.home() / ".strix"
|
||||
|
||||
@classmethod
|
||||
def config_file(cls) -> Path:
|
||||
return cls.config_dir() / "cli-config.json"
|
||||
|
||||
@classmethod
|
||||
def load(cls) -> dict[str, Any]:
|
||||
path = cls.config_file()
|
||||
if not path.exists():
|
||||
return {}
|
||||
try:
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data: dict[str, Any] = json.load(f)
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def save(cls, config: dict[str, Any]) -> bool:
|
||||
try:
|
||||
cls.config_dir().mkdir(parents=True, exist_ok=True)
|
||||
with cls.config_file().open("w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
except OSError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def apply_saved(cls) -> dict[str, str]:
|
||||
saved = cls.load()
|
||||
env_vars = saved.get("env", {})
|
||||
applied = {}
|
||||
|
||||
for var_name, var_value in env_vars.items():
|
||||
if var_name in cls.tracked_vars() and not os.getenv(var_name):
|
||||
os.environ[var_name] = var_value
|
||||
applied[var_name] = var_value
|
||||
|
||||
return applied
|
||||
|
||||
@classmethod
|
||||
def capture_current(cls) -> dict[str, Any]:
|
||||
env_vars = {}
|
||||
for var_name in cls.tracked_vars():
|
||||
value = os.getenv(var_name)
|
||||
if value:
|
||||
env_vars[var_name] = value
|
||||
return {"env": env_vars}
|
||||
|
||||
@classmethod
|
||||
def save_current(cls) -> bool:
|
||||
existing = cls.load().get("env", {})
|
||||
current = cls.capture_current().get("env", {})
|
||||
merged = {**existing, **current}
|
||||
return cls.save({"env": merged})
|
||||
|
||||
|
||||
def apply_saved_config() -> dict[str, str]:
|
||||
return Config.apply_saved()
|
||||
|
||||
|
||||
def save_current_config() -> bool:
|
||||
return Config.save_current()
|
||||
@@ -18,6 +18,7 @@ from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from strix.config import Config, apply_saved_config, save_current_config
|
||||
from strix.interface.cli import run_cli
|
||||
from strix.interface.tui import run_tui
|
||||
from strix.interface.utils import (
|
||||
@@ -198,13 +199,13 @@ async def warm_up_llm() -> None:
|
||||
console = Console()
|
||||
|
||||
try:
|
||||
model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
|
||||
api_key = os.getenv("LLM_API_KEY")
|
||||
model_name = Config.get("strix_llm")
|
||||
api_key = Config.get("llm_api_key")
|
||||
api_base = (
|
||||
os.getenv("LLM_API_BASE")
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or os.getenv("LITELLM_BASE_URL")
|
||||
or os.getenv("OLLAMA_API_BASE")
|
||||
Config.get("llm_api_base")
|
||||
or Config.get("openai_api_base")
|
||||
or Config.get("litellm_base_url")
|
||||
or Config.get("ollama_api_base")
|
||||
)
|
||||
|
||||
test_messages = [
|
||||
@@ -212,7 +213,7 @@ async def warm_up_llm() -> None:
|
||||
{"role": "user", "content": "Reply with just 'OK'."},
|
||||
]
|
||||
|
||||
llm_timeout = int(os.getenv("LLM_TIMEOUT", "600"))
|
||||
llm_timeout = int(Config.get("llm_timeout")) # type: ignore[arg-type]
|
||||
|
||||
completion_kwargs: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
@@ -512,6 +513,8 @@ def main() -> None:
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
apply_saved_config()
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
check_docker_installed()
|
||||
@@ -520,6 +523,8 @@ def main() -> None:
|
||||
validate_environment()
|
||||
asyncio.run(warm_up_llm())
|
||||
|
||||
save_current_config()
|
||||
|
||||
args.run_name = generate_run_name(args.targets_info)
|
||||
|
||||
for target_info in args.targets_info:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import os
|
||||
from strix.config import Config
|
||||
|
||||
|
||||
class LLMConfig:
|
||||
@@ -10,7 +10,7 @@ class LLMConfig:
|
||||
timeout: int | None = None,
|
||||
scan_mode: str = "deep",
|
||||
):
|
||||
self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5")
|
||||
self.model_name = model_name or Config.get("strix_llm")
|
||||
|
||||
if not self.model_name:
|
||||
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
||||
@@ -18,6 +18,6 @@ class LLMConfig:
|
||||
self.enable_prompt_caching = enable_prompt_caching
|
||||
self.skills = skills or []
|
||||
|
||||
self.timeout = timeout or int(os.getenv("LLM_TIMEOUT", "300"))
|
||||
self.timeout = timeout or int(Config.get("llm_timeout")) # type: ignore[arg-type]
|
||||
|
||||
self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep"
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
from strix.config import Config
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -154,13 +155,13 @@ def check_duplicate(
|
||||
|
||||
comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
|
||||
|
||||
model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
|
||||
api_key = os.getenv("LLM_API_KEY")
|
||||
model_name = Config.get("strix_llm")
|
||||
api_key = Config.get("llm_api_key")
|
||||
api_base = (
|
||||
os.getenv("LLM_API_BASE")
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or os.getenv("LITELLM_BASE_URL")
|
||||
or os.getenv("OLLAMA_API_BASE")
|
||||
Config.get("llm_api_base")
|
||||
or Config.get("openai_api_base")
|
||||
or Config.get("litellm_base_url")
|
||||
or Config.get("ollama_api_base")
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -16,6 +15,7 @@ from jinja2 import (
|
||||
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
|
||||
from litellm.utils import supports_prompt_caching, supports_vision
|
||||
|
||||
from strix.config import Config
|
||||
from strix.llm.config import LLMConfig
|
||||
from strix.llm.memory_compressor import MemoryCompressor
|
||||
from strix.llm.request_queue import get_global_queue
|
||||
@@ -46,16 +46,14 @@ logger = logging.getLogger(__name__)
|
||||
litellm.drop_params = True
|
||||
litellm.modify_params = True
|
||||
|
||||
_LLM_API_KEY = os.getenv("LLM_API_KEY")
|
||||
_LLM_API_KEY = Config.get("llm_api_key")
|
||||
_LLM_API_BASE = (
|
||||
os.getenv("LLM_API_BASE")
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or os.getenv("LITELLM_BASE_URL")
|
||||
or os.getenv("OLLAMA_API_BASE")
|
||||
Config.get("llm_api_base")
|
||||
or Config.get("openai_api_base")
|
||||
or Config.get("litellm_base_url")
|
||||
or Config.get("ollama_api_base")
|
||||
)
|
||||
_STRIX_REASONING_EFFORT = os.getenv(
|
||||
"STRIX_REASONING_EFFORT"
|
||||
) # "none", "minimal", "low", "medium", "high", or "xhigh"
|
||||
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
|
||||
|
||||
|
||||
class LLMRequestFailedError(Exception):
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
from strix.config import Config
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -150,7 +151,7 @@ class MemoryCompressor:
|
||||
timeout: int = 600,
|
||||
):
|
||||
self.max_images = max_images
|
||||
self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5")
|
||||
self.model_name = model_name or Config.get("strix_llm")
|
||||
self.timeout = timeout
|
||||
|
||||
if not self.model_name:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
@@ -8,16 +7,13 @@ from typing import Any
|
||||
from litellm import acompletion
|
||||
from litellm.types.utils import ModelResponseStream
|
||||
|
||||
from strix.config import Config
|
||||
|
||||
|
||||
class LLMRequestQueue:
|
||||
def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0):
|
||||
rate_limit_delay = os.getenv("LLM_RATE_LIMIT_DELAY")
|
||||
if rate_limit_delay:
|
||||
delay_between_requests = float(rate_limit_delay)
|
||||
|
||||
rate_limit_concurrent = os.getenv("LLM_RATE_LIMIT_CONCURRENT")
|
||||
if rate_limit_concurrent:
|
||||
max_concurrent = int(rate_limit_concurrent)
|
||||
delay_between_requests = float(Config.get("llm_rate_limit_delay")) # type: ignore[arg-type]
|
||||
max_concurrent = int(Config.get("llm_rate_limit_concurrent")) # type: ignore[arg-type]
|
||||
|
||||
self.max_concurrent = max_concurrent
|
||||
self.delay_between_requests = delay_between_requests
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import os
|
||||
from strix.config import Config
|
||||
|
||||
from .runtime import AbstractRuntime
|
||||
|
||||
@@ -13,7 +13,7 @@ class SandboxInitializationError(Exception):
|
||||
|
||||
|
||||
def get_runtime() -> AbstractRuntime:
|
||||
runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker")
|
||||
runtime_backend = Config.get("strix_runtime_backend")
|
||||
|
||||
if runtime_backend == "docker":
|
||||
from .docker_runtime import DockerRuntime
|
||||
|
||||
@@ -15,11 +15,13 @@ from docker.models.containers import Container
|
||||
from requests.exceptions import ConnectionError as RequestsConnectionError
|
||||
from requests.exceptions import Timeout as RequestsTimeout
|
||||
|
||||
from strix.config import Config
|
||||
|
||||
from . import SandboxInitializationError
|
||||
from .runtime import AbstractRuntime, SandboxInfo
|
||||
|
||||
|
||||
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
|
||||
STRIX_IMAGE: str = Config.get("strix_image") # type: ignore[assignment]
|
||||
HOST_GATEWAY_HOSTNAME = "host.docker.internal"
|
||||
DOCKER_TIMEOUT = 60 # seconds
|
||||
TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import urllib.request
|
||||
@@ -7,6 +6,8 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import uuid4
|
||||
|
||||
from strix.config import Config
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from strix.telemetry.tracer import Tracer
|
||||
@@ -18,7 +19,7 @@ _SESSION_ID = uuid4().hex[:16]
|
||||
|
||||
|
||||
def _is_enabled() -> bool:
|
||||
return os.getenv("STRIX_TELEMETRY", "1").lower() not in ("0", "false", "no", "off")
|
||||
return (Config.get("strix_telemetry") or "1").lower() not in ("0", "false", "no", "off")
|
||||
|
||||
|
||||
def _is_first_run() -> bool:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
from strix.config import Config
|
||||
|
||||
from .executor import (
|
||||
execute_tool,
|
||||
execute_tool_invocation,
|
||||
@@ -22,9 +24,9 @@ from .registry import (
|
||||
|
||||
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
||||
|
||||
HAS_PERPLEXITY_API = bool(os.getenv("PERPLEXITY_API_KEY"))
|
||||
HAS_PERPLEXITY_API = bool(Config.get("perplexity_api_key"))
|
||||
|
||||
DISABLE_BROWSER = os.getenv("STRIX_DISABLE_BROWSER", "false").lower() == "true"
|
||||
DISABLE_BROWSER = (Config.get("strix_disable_browser") or "false").lower() == "true"
|
||||
|
||||
if not SANDBOX_MODE:
|
||||
from .agents_graph import * # noqa: F403
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from strix.config import Config
|
||||
|
||||
|
||||
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
|
||||
from strix.runtime import get_runtime
|
||||
@@ -17,8 +19,8 @@ from .registry import (
|
||||
)
|
||||
|
||||
|
||||
SANDBOX_EXECUTION_TIMEOUT = float(os.getenv("STRIX_SANDBOX_EXECUTION_TIMEOUT", "500"))
|
||||
SANDBOX_CONNECT_TIMEOUT = float(os.getenv("STRIX_SANDBOX_CONNECT_TIMEOUT", "10"))
|
||||
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout")) # type: ignore[arg-type]
|
||||
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout")) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any:
|
||||
|
||||
Reference in New Issue
Block a user