feat: add centralized Config class with auto-save to ~/.strix/cli-config.json
- Add Config class with all env var defaults in one place - Auto-load saved config on startup (env vars take precedence) - Auto-save config after successful LLM warm-up - Replace scattered os.getenv() calls with Config.get() Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
12
strix/config/__init__.py
Normal file
12
strix/config/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from strix.config.config import (
|
||||||
|
Config,
|
||||||
|
apply_saved_config,
|
||||||
|
save_current_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Config",
|
||||||
|
"apply_saved_config",
|
||||||
|
"save_current_config",
|
||||||
|
]
|
||||||
119
strix/config/config.py
Normal file
119
strix/config/config.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration Manager for Strix."""
|
||||||
|
|
||||||
|
# LLM Configuration
|
||||||
|
strix_llm = "openai/gpt-5"
|
||||||
|
llm_api_key = None
|
||||||
|
llm_api_base = None
|
||||||
|
openai_api_base = None
|
||||||
|
litellm_base_url = None
|
||||||
|
ollama_api_base = None
|
||||||
|
strix_reasoning_effort = "high"
|
||||||
|
llm_timeout = "300"
|
||||||
|
llm_rate_limit_delay = "4.0"
|
||||||
|
llm_rate_limit_concurrent = "1"
|
||||||
|
|
||||||
|
# Tool & Feature Configuration
|
||||||
|
perplexity_api_key = None
|
||||||
|
strix_disable_browser = "false"
|
||||||
|
|
||||||
|
# Runtime Configuration
|
||||||
|
strix_image = "ghcr.io/usestrix/strix-sandbox:0.1.10"
|
||||||
|
strix_runtime_backend = "docker"
|
||||||
|
strix_sandbox_execution_timeout = "500"
|
||||||
|
strix_sandbox_connect_timeout = "10"
|
||||||
|
|
||||||
|
# Telemetry
|
||||||
|
strix_telemetry = "1"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _tracked_names(cls) -> list[str]:
|
||||||
|
return [
|
||||||
|
k
|
||||||
|
for k, v in vars(cls).items()
|
||||||
|
if not k.startswith("_") and k[0].islower() and (v is None or isinstance(v, str))
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tracked_vars(cls) -> list[str]:
|
||||||
|
return [name.upper() for name in cls._tracked_names()]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, name: str) -> str | None:
|
||||||
|
env_name = name.upper()
|
||||||
|
default = getattr(cls, name, None)
|
||||||
|
return os.getenv(env_name, default)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_dir(cls) -> Path:
|
||||||
|
return Path.home() / ".strix"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_file(cls) -> Path:
|
||||||
|
return cls.config_dir() / "cli-config.json"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls) -> dict[str, Any]:
|
||||||
|
path = cls.config_file()
|
||||||
|
if not path.exists():
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
data: dict[str, Any] = json.load(f)
|
||||||
|
return data
|
||||||
|
except (json.JSONDecodeError, OSError):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save(cls, config: dict[str, Any]) -> bool:
|
||||||
|
try:
|
||||||
|
cls.config_dir().mkdir(parents=True, exist_ok=True)
|
||||||
|
with cls.config_file().open("w", encoding="utf-8") as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def apply_saved(cls) -> dict[str, str]:
|
||||||
|
saved = cls.load()
|
||||||
|
env_vars = saved.get("env", {})
|
||||||
|
applied = {}
|
||||||
|
|
||||||
|
for var_name, var_value in env_vars.items():
|
||||||
|
if var_name in cls.tracked_vars() and not os.getenv(var_name):
|
||||||
|
os.environ[var_name] = var_value
|
||||||
|
applied[var_name] = var_value
|
||||||
|
|
||||||
|
return applied
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def capture_current(cls) -> dict[str, Any]:
|
||||||
|
env_vars = {}
|
||||||
|
for var_name in cls.tracked_vars():
|
||||||
|
value = os.getenv(var_name)
|
||||||
|
if value:
|
||||||
|
env_vars[var_name] = value
|
||||||
|
return {"env": env_vars}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_current(cls) -> bool:
|
||||||
|
existing = cls.load().get("env", {})
|
||||||
|
current = cls.capture_current().get("env", {})
|
||||||
|
merged = {**existing, **current}
|
||||||
|
return cls.save({"env": merged})
|
||||||
|
|
||||||
|
|
||||||
|
def apply_saved_config() -> dict[str, str]:
|
||||||
|
return Config.apply_saved()
|
||||||
|
|
||||||
|
|
||||||
|
def save_current_config() -> bool:
|
||||||
|
return Config.save_current()
|
||||||
@@ -18,6 +18,7 @@ from rich.console import Console
|
|||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
|
from strix.config import Config, apply_saved_config, save_current_config
|
||||||
from strix.interface.cli import run_cli
|
from strix.interface.cli import run_cli
|
||||||
from strix.interface.tui import run_tui
|
from strix.interface.tui import run_tui
|
||||||
from strix.interface.utils import (
|
from strix.interface.utils import (
|
||||||
@@ -198,13 +199,13 @@ async def warm_up_llm() -> None:
|
|||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
|
model_name = Config.get("strix_llm")
|
||||||
api_key = os.getenv("LLM_API_KEY")
|
api_key = Config.get("llm_api_key")
|
||||||
api_base = (
|
api_base = (
|
||||||
os.getenv("LLM_API_BASE")
|
Config.get("llm_api_base")
|
||||||
or os.getenv("OPENAI_API_BASE")
|
or Config.get("openai_api_base")
|
||||||
or os.getenv("LITELLM_BASE_URL")
|
or Config.get("litellm_base_url")
|
||||||
or os.getenv("OLLAMA_API_BASE")
|
or Config.get("ollama_api_base")
|
||||||
)
|
)
|
||||||
|
|
||||||
test_messages = [
|
test_messages = [
|
||||||
@@ -212,7 +213,7 @@ async def warm_up_llm() -> None:
|
|||||||
{"role": "user", "content": "Reply with just 'OK'."},
|
{"role": "user", "content": "Reply with just 'OK'."},
|
||||||
]
|
]
|
||||||
|
|
||||||
llm_timeout = int(os.getenv("LLM_TIMEOUT", "600"))
|
llm_timeout = int(Config.get("llm_timeout")) # type: ignore[arg-type]
|
||||||
|
|
||||||
completion_kwargs: dict[str, Any] = {
|
completion_kwargs: dict[str, Any] = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
@@ -512,6 +513,8 @@ def main() -> None:
|
|||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
|
apply_saved_config()
|
||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
check_docker_installed()
|
check_docker_installed()
|
||||||
@@ -520,6 +523,8 @@ def main() -> None:
|
|||||||
validate_environment()
|
validate_environment()
|
||||||
asyncio.run(warm_up_llm())
|
asyncio.run(warm_up_llm())
|
||||||
|
|
||||||
|
save_current_config()
|
||||||
|
|
||||||
args.run_name = generate_run_name(args.targets_info)
|
args.run_name = generate_run_name(args.targets_info)
|
||||||
|
|
||||||
for target_info in args.targets_info:
|
for target_info in args.targets_info:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import os
|
from strix.config import Config
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig:
|
class LLMConfig:
|
||||||
@@ -10,7 +10,7 @@ class LLMConfig:
|
|||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
scan_mode: str = "deep",
|
scan_mode: str = "deep",
|
||||||
):
|
):
|
||||||
self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5")
|
self.model_name = model_name or Config.get("strix_llm")
|
||||||
|
|
||||||
if not self.model_name:
|
if not self.model_name:
|
||||||
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
||||||
@@ -18,6 +18,6 @@ class LLMConfig:
|
|||||||
self.enable_prompt_caching = enable_prompt_caching
|
self.enable_prompt_caching = enable_prompt_caching
|
||||||
self.skills = skills or []
|
self.skills = skills or []
|
||||||
|
|
||||||
self.timeout = timeout or int(os.getenv("LLM_TIMEOUT", "300"))
|
self.timeout = timeout or int(Config.get("llm_timeout")) # type: ignore[arg-type]
|
||||||
|
|
||||||
self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep"
|
self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep"
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -154,13 +155,13 @@ def check_duplicate(
|
|||||||
|
|
||||||
comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
|
comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}
|
||||||
|
|
||||||
model_name = os.getenv("STRIX_LLM", "openai/gpt-5")
|
model_name = Config.get("strix_llm")
|
||||||
api_key = os.getenv("LLM_API_KEY")
|
api_key = Config.get("llm_api_key")
|
||||||
api_base = (
|
api_base = (
|
||||||
os.getenv("LLM_API_BASE")
|
Config.get("llm_api_base")
|
||||||
or os.getenv("OPENAI_API_BASE")
|
or Config.get("openai_api_base")
|
||||||
or os.getenv("LITELLM_BASE_URL")
|
or Config.get("litellm_base_url")
|
||||||
or os.getenv("OLLAMA_API_BASE")
|
or Config.get("ollama_api_base")
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -16,6 +15,7 @@ from jinja2 import (
|
|||||||
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
|
from litellm import completion_cost, stream_chunk_builder, supports_reasoning
|
||||||
from litellm.utils import supports_prompt_caching, supports_vision
|
from litellm.utils import supports_prompt_caching, supports_vision
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
from strix.llm.config import LLMConfig
|
from strix.llm.config import LLMConfig
|
||||||
from strix.llm.memory_compressor import MemoryCompressor
|
from strix.llm.memory_compressor import MemoryCompressor
|
||||||
from strix.llm.request_queue import get_global_queue
|
from strix.llm.request_queue import get_global_queue
|
||||||
@@ -46,16 +46,14 @@ logger = logging.getLogger(__name__)
|
|||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
litellm.modify_params = True
|
litellm.modify_params = True
|
||||||
|
|
||||||
_LLM_API_KEY = os.getenv("LLM_API_KEY")
|
_LLM_API_KEY = Config.get("llm_api_key")
|
||||||
_LLM_API_BASE = (
|
_LLM_API_BASE = (
|
||||||
os.getenv("LLM_API_BASE")
|
Config.get("llm_api_base")
|
||||||
or os.getenv("OPENAI_API_BASE")
|
or Config.get("openai_api_base")
|
||||||
or os.getenv("LITELLM_BASE_URL")
|
or Config.get("litellm_base_url")
|
||||||
or os.getenv("OLLAMA_API_BASE")
|
or Config.get("ollama_api_base")
|
||||||
)
|
)
|
||||||
_STRIX_REASONING_EFFORT = os.getenv(
|
_STRIX_REASONING_EFFORT = Config.get("strix_reasoning_effort")
|
||||||
"STRIX_REASONING_EFFORT"
|
|
||||||
) # "none", "minimal", "low", "medium", "high", or "xhigh"
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestFailedError(Exception):
|
class LLMRequestFailedError(Exception):
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -150,7 +151,7 @@ class MemoryCompressor:
|
|||||||
timeout: int = 600,
|
timeout: int = 600,
|
||||||
):
|
):
|
||||||
self.max_images = max_images
|
self.max_images = max_images
|
||||||
self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5")
|
self.model_name = model_name or Config.get("strix_llm")
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
if not self.model_name:
|
if not self.model_name:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
@@ -8,16 +7,13 @@ from typing import Any
|
|||||||
from litellm import acompletion
|
from litellm import acompletion
|
||||||
from litellm.types.utils import ModelResponseStream
|
from litellm.types.utils import ModelResponseStream
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestQueue:
|
class LLMRequestQueue:
|
||||||
def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0):
|
def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0):
|
||||||
rate_limit_delay = os.getenv("LLM_RATE_LIMIT_DELAY")
|
delay_between_requests = float(Config.get("llm_rate_limit_delay")) # type: ignore[arg-type]
|
||||||
if rate_limit_delay:
|
max_concurrent = int(Config.get("llm_rate_limit_concurrent")) # type: ignore[arg-type]
|
||||||
delay_between_requests = float(rate_limit_delay)
|
|
||||||
|
|
||||||
rate_limit_concurrent = os.getenv("LLM_RATE_LIMIT_CONCURRENT")
|
|
||||||
if rate_limit_concurrent:
|
|
||||||
max_concurrent = int(rate_limit_concurrent)
|
|
||||||
|
|
||||||
self.max_concurrent = max_concurrent
|
self.max_concurrent = max_concurrent
|
||||||
self.delay_between_requests = delay_between_requests
|
self.delay_between_requests = delay_between_requests
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import os
|
from strix.config import Config
|
||||||
|
|
||||||
from .runtime import AbstractRuntime
|
from .runtime import AbstractRuntime
|
||||||
|
|
||||||
@@ -13,7 +13,7 @@ class SandboxInitializationError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
def get_runtime() -> AbstractRuntime:
|
def get_runtime() -> AbstractRuntime:
|
||||||
runtime_backend = os.getenv("STRIX_RUNTIME_BACKEND", "docker")
|
runtime_backend = Config.get("strix_runtime_backend")
|
||||||
|
|
||||||
if runtime_backend == "docker":
|
if runtime_backend == "docker":
|
||||||
from .docker_runtime import DockerRuntime
|
from .docker_runtime import DockerRuntime
|
||||||
|
|||||||
@@ -15,11 +15,13 @@ from docker.models.containers import Container
|
|||||||
from requests.exceptions import ConnectionError as RequestsConnectionError
|
from requests.exceptions import ConnectionError as RequestsConnectionError
|
||||||
from requests.exceptions import Timeout as RequestsTimeout
|
from requests.exceptions import Timeout as RequestsTimeout
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
from . import SandboxInitializationError
|
from . import SandboxInitializationError
|
||||||
from .runtime import AbstractRuntime, SandboxInfo
|
from .runtime import AbstractRuntime, SandboxInfo
|
||||||
|
|
||||||
|
|
||||||
STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10")
|
STRIX_IMAGE: str = Config.get("strix_image") # type: ignore[assignment]
|
||||||
HOST_GATEWAY_HOSTNAME = "host.docker.internal"
|
HOST_GATEWAY_HOSTNAME = "host.docker.internal"
|
||||||
DOCKER_TIMEOUT = 60 # seconds
|
DOCKER_TIMEOUT = 60 # seconds
|
||||||
TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request
|
TOOL_SERVER_HEALTH_REQUEST_TIMEOUT = 5 # seconds per health check request
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
import urllib.request
|
import urllib.request
|
||||||
@@ -7,6 +6,8 @@ from pathlib import Path
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from strix.telemetry.tracer import Tracer
|
from strix.telemetry.tracer import Tracer
|
||||||
@@ -18,7 +19,7 @@ _SESSION_ID = uuid4().hex[:16]
|
|||||||
|
|
||||||
|
|
||||||
def _is_enabled() -> bool:
|
def _is_enabled() -> bool:
|
||||||
return os.getenv("STRIX_TELEMETRY", "1").lower() not in ("0", "false", "no", "off")
|
return (Config.get("strix_telemetry") or "1").lower() not in ("0", "false", "no", "off")
|
||||||
|
|
||||||
|
|
||||||
def _is_first_run() -> bool:
|
def _is_first_run() -> bool:
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
from .executor import (
|
from .executor import (
|
||||||
execute_tool,
|
execute_tool,
|
||||||
execute_tool_invocation,
|
execute_tool_invocation,
|
||||||
@@ -22,9 +24,9 @@ from .registry import (
|
|||||||
|
|
||||||
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
SANDBOX_MODE = os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "true"
|
||||||
|
|
||||||
HAS_PERPLEXITY_API = bool(os.getenv("PERPLEXITY_API_KEY"))
|
HAS_PERPLEXITY_API = bool(Config.get("perplexity_api_key"))
|
||||||
|
|
||||||
DISABLE_BROWSER = os.getenv("STRIX_DISABLE_BROWSER", "false").lower() == "true"
|
DISABLE_BROWSER = (Config.get("strix_disable_browser") or "false").lower() == "true"
|
||||||
|
|
||||||
if not SANDBOX_MODE:
|
if not SANDBOX_MODE:
|
||||||
from .agents_graph import * # noqa: F403
|
from .agents_graph import * # noqa: F403
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from typing import Any
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from strix.config import Config
|
||||||
|
|
||||||
|
|
||||||
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
|
if os.getenv("STRIX_SANDBOX_MODE", "false").lower() == "false":
|
||||||
from strix.runtime import get_runtime
|
from strix.runtime import get_runtime
|
||||||
@@ -17,8 +19,8 @@ from .registry import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SANDBOX_EXECUTION_TIMEOUT = float(os.getenv("STRIX_SANDBOX_EXECUTION_TIMEOUT", "500"))
|
SANDBOX_EXECUTION_TIMEOUT = float(Config.get("strix_sandbox_execution_timeout")) # type: ignore[arg-type]
|
||||||
SANDBOX_CONNECT_TIMEOUT = float(os.getenv("STRIX_SANDBOX_CONNECT_TIMEOUT", "10"))
|
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout")) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any:
|
async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any:
|
||||||
|
|||||||
Reference in New Issue
Block a user