diff --git a/strix/config/config.py b/strix/config/config.py index a602658..636c8bf 100644 --- a/strix/config/config.py +++ b/strix/config/config.py @@ -19,6 +19,18 @@ class Config: strix_llm_max_retries = "5" strix_memory_compressor_timeout = "30" llm_timeout = "300" + _LLM_CANONICAL_NAMES = ( + "strix_llm", + "llm_api_key", + "llm_api_base", + "openai_api_base", + "litellm_base_url", + "ollama_api_base", + "strix_reasoning_effort", + "strix_llm_max_retries", + "strix_memory_compressor_timeout", + "llm_timeout", + ) # Tool & Feature Configuration perplexity_api_key = None @@ -45,6 +57,20 @@ class Config: def tracked_vars(cls) -> list[str]: return [name.upper() for name in cls._tracked_names()] + @classmethod + def _llm_env_vars(cls) -> set[str]: + return {name.upper() for name in cls._LLM_CANONICAL_NAMES} + + @classmethod + def _llm_env_changed(cls, saved_env: dict[str, Any]) -> bool: + for var_name in cls._llm_env_vars(): + current = os.getenv(var_name) + if current is None: + continue + if saved_env.get(var_name) != current: + return True + return False + @classmethod def get(cls, name: str) -> str | None: env_name = name.upper() @@ -88,10 +114,24 @@ class Config: def apply_saved(cls) -> dict[str, str]: saved = cls.load() env_vars = saved.get("env", {}) + if not isinstance(env_vars, dict): + env_vars = {} + cleared_vars = { + var_name + for var_name in cls.tracked_vars() + if var_name in os.environ and os.environ.get(var_name) == "" + } + if cleared_vars: + for var_name in cleared_vars: + env_vars.pop(var_name, None) + cls.save({"env": env_vars}) + if cls._llm_env_changed(env_vars): + cls.save({"env": {}}) + return {} applied = {} for var_name, var_value in env_vars.items(): - if var_name in cls.tracked_vars() and not os.getenv(var_name): + if var_name in cls.tracked_vars() and var_name not in os.environ: os.environ[var_name] = var_value applied[var_name] = var_value