refactor: simplify --config implementation to reuse existing config system

- Reuse apply_saved() instead of custom override logic
- Add force parameter to override existing env vars
- Move validation to utils.py
- Prevent saving when using custom config (one-time override)
- Fix: don't modify ~/.strix/cli-config.json when --config is used

Co-Authored-By: FeedClogger <feedclogger@users.noreply.github.com>
This commit is contained in:
0xallam
2026-01-20 16:39:35 -08:00
committed by Ahmed Allam
parent 4ab9af6e47
commit 165887798d
3 changed files with 57 additions and 32 deletions

View File

@@ -45,6 +45,9 @@ class Config:
# Telemetry # Telemetry
strix_telemetry = "1" strix_telemetry = "1"
# Config file override (set via --config CLI arg)
_config_file_override: Path | None = None
@classmethod @classmethod
def _tracked_names(cls) -> list[str]: def _tracked_names(cls) -> list[str]:
return [ return [
@@ -83,6 +86,8 @@ class Config:
@classmethod @classmethod
def config_file(cls) -> Path: def config_file(cls) -> Path:
if cls._config_file_override is not None:
return cls._config_file_override
return cls.config_dir() / "cli-config.json" return cls.config_dir() / "cli-config.json"
@classmethod @classmethod
@@ -101,7 +106,7 @@ class Config:
def save(cls, config: dict[str, Any]) -> bool: def save(cls, config: dict[str, Any]) -> bool:
try: try:
cls.config_dir().mkdir(parents=True, exist_ok=True) cls.config_dir().mkdir(parents=True, exist_ok=True)
config_path = cls.config_file() config_path = cls.config_dir() / "cli-config.json"
with config_path.open("w", encoding="utf-8") as f: with config_path.open("w", encoding="utf-8") as f:
json.dump(config, f, indent=2) json.dump(config, f, indent=2)
except OSError: except OSError:
@@ -111,7 +116,7 @@ class Config:
return True return True
@classmethod @classmethod
def apply_saved(cls) -> dict[str, str]: def apply_saved(cls, force: bool = False) -> dict[str, str]:
saved = cls.load() saved = cls.load()
env_vars = saved.get("env", {}) env_vars = saved.get("env", {})
if not isinstance(env_vars, dict): if not isinstance(env_vars, dict):
@@ -124,15 +129,17 @@ class Config:
if cleared_vars: if cleared_vars:
for var_name in cleared_vars: for var_name in cleared_vars:
env_vars.pop(var_name, None) env_vars.pop(var_name, None)
cls.save({"env": env_vars}) if cls._config_file_override is None:
cls.save({"env": env_vars})
if cls._llm_env_changed(env_vars): if cls._llm_env_changed(env_vars):
for var_name in cls._llm_env_vars(): for var_name in cls._llm_env_vars():
env_vars.pop(var_name, None) env_vars.pop(var_name, None)
cls.save({"env": env_vars}) if cls._config_file_override is None:
cls.save({"env": env_vars})
applied = {} applied = {}
for var_name, var_value in env_vars.items(): for var_name, var_value in env_vars.items():
if var_name in cls.tracked_vars() and var_name not in os.environ: if var_name in cls.tracked_vars() and (force or var_name not in os.environ):
os.environ[var_name] = var_value os.environ[var_name] = var_value
applied[var_name] = var_value applied[var_name] = var_value
@@ -163,17 +170,9 @@ class Config:
return cls.save({"env": merged}) return cls.save({"env": merged})
@classmethod
def override(cls, key: str, value: str) -> None:
"""Override a configuration variable dynamically."""
if hasattr(cls, key):
setattr(cls, key, value)
else:
os.environ[key] = value
def apply_saved_config(force: bool = False) -> dict[str, str]:
def apply_saved_config() -> dict[str, str]: return Config.apply_saved(force=force)
return Config.apply_saved()
def save_current_config() -> bool: def save_current_config() -> bool:

View File

@@ -5,7 +5,6 @@ Strix Agent Interface
import argparse import argparse
import asyncio import asyncio
import json
import logging import logging
import shutil import shutil
import sys import sys
@@ -36,6 +35,7 @@ from strix.interface.utils import ( # noqa: E402
infer_target_type, infer_target_type,
process_pull_line, process_pull_line,
rewrite_localhost_targets, rewrite_localhost_targets,
validate_config_file,
validate_llm_response, validate_llm_response,
) )
from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402 from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402
@@ -360,7 +360,7 @@ Examples:
parser.add_argument( parser.add_argument(
"--config", "--config",
type=str, type=str,
help="Path to the configuration file (.json)", help="Path to a custom config file (JSON) to use instead of ~/.strix/cli-config.json",
) )
args = parser.parse_args() args = parser.parse_args()
@@ -507,19 +507,14 @@ def pull_docker_image() -> None:
console.print() console.print()
def load_config_file(config_path: str): def apply_config_override(config_path: str) -> None:
if config_path.endswith(".json"): Config._config_file_override = validate_config_file(config_path)
try: apply_saved_config(force=True)
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
for key, value in config_data.items(): def persist_config() -> None:
Config.override(key, str(value)) # Use Config class to override variables if Config._config_file_override is None:
except (json.JSONDecodeError, OSError) as e: save_current_config()
print(f"Error loading JSON config file: {e}")
sys.exit(1)
else:
print("Unsupported config file format. Use .json.")
sys.exit(1)
def main() -> None: def main() -> None:
@@ -529,7 +524,7 @@ def main() -> None:
args = parse_arguments() args = parse_arguments()
if args.config: if args.config:
load_config_file(args.config) apply_config_override(args.config)
check_docker_installed() check_docker_installed()
pull_docker_image() pull_docker_image()
@@ -537,7 +532,7 @@ def main() -> None:
validate_environment() validate_environment()
asyncio.run(warm_up_llm()) asyncio.run(warm_up_llm())
save_current_config() persist_config()
args.run_name = generate_run_name(args.targets_info) args.run_name = generate_run_name(args.targets_info)

View File

@@ -1,4 +1,5 @@
import ipaddress import ipaddress
import json
import re import re
import secrets import secrets
import shutil import shutil
@@ -789,3 +790,33 @@ def process_pull_line(
def validate_llm_response(response: Any) -> None: def validate_llm_response(response: Any) -> None:
if not response or not response.choices or not response.choices[0].message.content: if not response or not response.choices or not response.choices[0].message.content:
raise RuntimeError("Invalid response from LLM") raise RuntimeError("Invalid response from LLM")
def validate_config_file(config_path: str) -> Path:
console = Console()
path = Path(config_path)
if not path.exists():
console.print(f"[bold red]Error:[/] Config file not found: {config_path}")
sys.exit(1)
if path.suffix != ".json":
console.print("[bold red]Error:[/] Config file must be a .json file")
sys.exit(1)
try:
with path.open("r", encoding="utf-8") as f:
data = json.load(f)
except json.JSONDecodeError as e:
console.print(f"[bold red]Error:[/] Invalid JSON in config file: {e}")
sys.exit(1)
if not isinstance(data, dict):
console.print("[bold red]Error:[/] Config file must contain a JSON object")
sys.exit(1)
if "env" not in data or not isinstance(data.get("env"), dict):
console.print("[bold red]Error:[/] Config file must have an 'env' object")
sys.exit(1)
return path