refactor: simplify --config implementation to reuse existing config system
- Reuse apply_saved() instead of custom override logic - Add force parameter to override existing env vars - Move validation to utils.py - Prevent saving when using custom config (one-time override) - Fix: don't modify ~/.strix/cli-config.json when --config is used Co-Authored-By: FeedClogger <feedclogger@users.noreply.github.com>
This commit is contained in:
@@ -45,6 +45,9 @@ class Config:
|
||||
# Telemetry
|
||||
strix_telemetry = "1"
|
||||
|
||||
# Config file override (set via --config CLI arg)
|
||||
_config_file_override: Path | None = None
|
||||
|
||||
@classmethod
|
||||
def _tracked_names(cls) -> list[str]:
|
||||
return [
|
||||
@@ -83,6 +86,8 @@ class Config:
|
||||
|
||||
@classmethod
|
||||
def config_file(cls) -> Path:
|
||||
if cls._config_file_override is not None:
|
||||
return cls._config_file_override
|
||||
return cls.config_dir() / "cli-config.json"
|
||||
|
||||
@classmethod
|
||||
@@ -101,7 +106,7 @@ class Config:
|
||||
def save(cls, config: dict[str, Any]) -> bool:
|
||||
try:
|
||||
cls.config_dir().mkdir(parents=True, exist_ok=True)
|
||||
config_path = cls.config_file()
|
||||
config_path = cls.config_dir() / "cli-config.json"
|
||||
with config_path.open("w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
except OSError:
|
||||
@@ -111,7 +116,7 @@ class Config:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def apply_saved(cls) -> dict[str, str]:
|
||||
def apply_saved(cls, force: bool = False) -> dict[str, str]:
|
||||
saved = cls.load()
|
||||
env_vars = saved.get("env", {})
|
||||
if not isinstance(env_vars, dict):
|
||||
@@ -124,15 +129,17 @@ class Config:
|
||||
if cleared_vars:
|
||||
for var_name in cleared_vars:
|
||||
env_vars.pop(var_name, None)
|
||||
if cls._config_file_override is None:
|
||||
cls.save({"env": env_vars})
|
||||
if cls._llm_env_changed(env_vars):
|
||||
for var_name in cls._llm_env_vars():
|
||||
env_vars.pop(var_name, None)
|
||||
if cls._config_file_override is None:
|
||||
cls.save({"env": env_vars})
|
||||
applied = {}
|
||||
|
||||
for var_name, var_value in env_vars.items():
|
||||
if var_name in cls.tracked_vars() and var_name not in os.environ:
|
||||
if var_name in cls.tracked_vars() and (force or var_name not in os.environ):
|
||||
os.environ[var_name] = var_value
|
||||
applied[var_name] = var_value
|
||||
|
||||
@@ -163,17 +170,9 @@ class Config:
|
||||
|
||||
return cls.save({"env": merged})
|
||||
|
||||
@classmethod
|
||||
def override(cls, key: str, value: str) -> None:
|
||||
"""Override a configuration variable dynamically."""
|
||||
if hasattr(cls, key):
|
||||
setattr(cls, key, value)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
def apply_saved_config() -> dict[str, str]:
|
||||
return Config.apply_saved()
|
||||
def apply_saved_config(force: bool = False) -> dict[str, str]:
|
||||
return Config.apply_saved(force=force)
|
||||
|
||||
|
||||
def save_current_config() -> bool:
|
||||
|
||||
@@ -5,7 +5,6 @@ Strix Agent Interface
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
@@ -36,6 +35,7 @@ from strix.interface.utils import ( # noqa: E402
|
||||
infer_target_type,
|
||||
process_pull_line,
|
||||
rewrite_localhost_targets,
|
||||
validate_config_file,
|
||||
validate_llm_response,
|
||||
)
|
||||
from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402
|
||||
@@ -360,7 +360,7 @@ Examples:
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="Path to the configuration file (.json)",
|
||||
help="Path to a custom config file (JSON) to use instead of ~/.strix/cli-config.json",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -507,19 +507,14 @@ def pull_docker_image() -> None:
|
||||
console.print()
|
||||
|
||||
|
||||
def load_config_file(config_path: str):
|
||||
if config_path.endswith(".json"):
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
for key, value in config_data.items():
|
||||
Config.override(key, str(value)) # Use Config class to override variables
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"Error loading JSON config file: {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("Unsupported config file format. Use .json.")
|
||||
sys.exit(1)
|
||||
def apply_config_override(config_path: str) -> None:
|
||||
Config._config_file_override = validate_config_file(config_path)
|
||||
apply_saved_config(force=True)
|
||||
|
||||
|
||||
def persist_config() -> None:
|
||||
if Config._config_file_override is None:
|
||||
save_current_config()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@@ -529,7 +524,7 @@ def main() -> None:
|
||||
args = parse_arguments()
|
||||
|
||||
if args.config:
|
||||
load_config_file(args.config)
|
||||
apply_config_override(args.config)
|
||||
|
||||
check_docker_installed()
|
||||
pull_docker_image()
|
||||
@@ -537,7 +532,7 @@ def main() -> None:
|
||||
validate_environment()
|
||||
asyncio.run(warm_up_llm())
|
||||
|
||||
save_current_config()
|
||||
persist_config()
|
||||
|
||||
args.run_name = generate_run_name(args.targets_info)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import ipaddress
|
||||
import json
|
||||
import re
|
||||
import secrets
|
||||
import shutil
|
||||
@@ -789,3 +790,33 @@ def process_pull_line(
|
||||
def validate_llm_response(response: Any) -> None:
|
||||
if not response or not response.choices or not response.choices[0].message.content:
|
||||
raise RuntimeError("Invalid response from LLM")
|
||||
|
||||
|
||||
def validate_config_file(config_path: str) -> Path:
|
||||
console = Console()
|
||||
path = Path(config_path)
|
||||
|
||||
if not path.exists():
|
||||
console.print(f"[bold red]Error:[/] Config file not found: {config_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if path.suffix != ".json":
|
||||
console.print("[bold red]Error:[/] Config file must be a .json file")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
console.print(f"[bold red]Error:[/] Invalid JSON in config file: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
console.print("[bold red]Error:[/] Config file must contain a JSON object")
|
||||
sys.exit(1)
|
||||
|
||||
if "env" not in data or not isinstance(data.get("env"), dict):
|
||||
console.print("[bold red]Error:[/] Config file must have an 'env' object")
|
||||
sys.exit(1)
|
||||
|
||||
return path
|
||||
|
||||
Reference in New Issue
Block a user