refactor: simplify --config implementation to reuse existing config system
- Reuse apply_saved() instead of custom override logic - Add force parameter to override existing env vars - Move validation to utils.py - Prevent saving when using custom config (one-time override) - Fix: don't modify ~/.strix/cli-config.json when --config is used Co-Authored-By: FeedClogger <feedclogger@users.noreply.github.com>
This commit is contained in:
@@ -45,6 +45,9 @@ class Config:
|
|||||||
# Telemetry
|
# Telemetry
|
||||||
strix_telemetry = "1"
|
strix_telemetry = "1"
|
||||||
|
|
||||||
|
# Config file override (set via --config CLI arg)
|
||||||
|
_config_file_override: Path | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _tracked_names(cls) -> list[str]:
|
def _tracked_names(cls) -> list[str]:
|
||||||
return [
|
return [
|
||||||
@@ -83,6 +86,8 @@ class Config:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def config_file(cls) -> Path:
|
def config_file(cls) -> Path:
|
||||||
|
if cls._config_file_override is not None:
|
||||||
|
return cls._config_file_override
|
||||||
return cls.config_dir() / "cli-config.json"
|
return cls.config_dir() / "cli-config.json"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -101,7 +106,7 @@ class Config:
|
|||||||
def save(cls, config: dict[str, Any]) -> bool:
|
def save(cls, config: dict[str, Any]) -> bool:
|
||||||
try:
|
try:
|
||||||
cls.config_dir().mkdir(parents=True, exist_ok=True)
|
cls.config_dir().mkdir(parents=True, exist_ok=True)
|
||||||
config_path = cls.config_file()
|
config_path = cls.config_dir() / "cli-config.json"
|
||||||
with config_path.open("w", encoding="utf-8") as f:
|
with config_path.open("w", encoding="utf-8") as f:
|
||||||
json.dump(config, f, indent=2)
|
json.dump(config, f, indent=2)
|
||||||
except OSError:
|
except OSError:
|
||||||
@@ -111,7 +116,7 @@ class Config:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply_saved(cls) -> dict[str, str]:
|
def apply_saved(cls, force: bool = False) -> dict[str, str]:
|
||||||
saved = cls.load()
|
saved = cls.load()
|
||||||
env_vars = saved.get("env", {})
|
env_vars = saved.get("env", {})
|
||||||
if not isinstance(env_vars, dict):
|
if not isinstance(env_vars, dict):
|
||||||
@@ -124,15 +129,17 @@ class Config:
|
|||||||
if cleared_vars:
|
if cleared_vars:
|
||||||
for var_name in cleared_vars:
|
for var_name in cleared_vars:
|
||||||
env_vars.pop(var_name, None)
|
env_vars.pop(var_name, None)
|
||||||
cls.save({"env": env_vars})
|
if cls._config_file_override is None:
|
||||||
|
cls.save({"env": env_vars})
|
||||||
if cls._llm_env_changed(env_vars):
|
if cls._llm_env_changed(env_vars):
|
||||||
for var_name in cls._llm_env_vars():
|
for var_name in cls._llm_env_vars():
|
||||||
env_vars.pop(var_name, None)
|
env_vars.pop(var_name, None)
|
||||||
cls.save({"env": env_vars})
|
if cls._config_file_override is None:
|
||||||
|
cls.save({"env": env_vars})
|
||||||
applied = {}
|
applied = {}
|
||||||
|
|
||||||
for var_name, var_value in env_vars.items():
|
for var_name, var_value in env_vars.items():
|
||||||
if var_name in cls.tracked_vars() and var_name not in os.environ:
|
if var_name in cls.tracked_vars() and (force or var_name not in os.environ):
|
||||||
os.environ[var_name] = var_value
|
os.environ[var_name] = var_value
|
||||||
applied[var_name] = var_value
|
applied[var_name] = var_value
|
||||||
|
|
||||||
@@ -163,17 +170,9 @@ class Config:
|
|||||||
|
|
||||||
return cls.save({"env": merged})
|
return cls.save({"env": merged})
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def override(cls, key: str, value: str) -> None:
|
|
||||||
"""Override a configuration variable dynamically."""
|
|
||||||
if hasattr(cls, key):
|
|
||||||
setattr(cls, key, value)
|
|
||||||
else:
|
|
||||||
os.environ[key] = value
|
|
||||||
|
|
||||||
|
def apply_saved_config(force: bool = False) -> dict[str, str]:
|
||||||
def apply_saved_config() -> dict[str, str]:
|
return Config.apply_saved(force=force)
|
||||||
return Config.apply_saved()
|
|
||||||
|
|
||||||
|
|
||||||
def save_current_config() -> bool:
|
def save_current_config() -> bool:
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ Strix Agent Interface
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
@@ -36,6 +35,7 @@ from strix.interface.utils import ( # noqa: E402
|
|||||||
infer_target_type,
|
infer_target_type,
|
||||||
process_pull_line,
|
process_pull_line,
|
||||||
rewrite_localhost_targets,
|
rewrite_localhost_targets,
|
||||||
|
validate_config_file,
|
||||||
validate_llm_response,
|
validate_llm_response,
|
||||||
)
|
)
|
||||||
from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402
|
from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402
|
||||||
@@ -360,7 +360,7 @@ Examples:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config",
|
"--config",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to the configuration file (.json)",
|
help="Path to a custom config file (JSON) to use instead of ~/.strix/cli-config.json",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -507,19 +507,14 @@ def pull_docker_image() -> None:
|
|||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
|
|
||||||
def load_config_file(config_path: str):
|
def apply_config_override(config_path: str) -> None:
|
||||||
if config_path.endswith(".json"):
|
Config._config_file_override = validate_config_file(config_path)
|
||||||
try:
|
apply_saved_config(force=True)
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = json.load(f)
|
|
||||||
for key, value in config_data.items():
|
def persist_config() -> None:
|
||||||
Config.override(key, str(value)) # Use Config class to override variables
|
if Config._config_file_override is None:
|
||||||
except (json.JSONDecodeError, OSError) as e:
|
save_current_config()
|
||||||
print(f"Error loading JSON config file: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
|
||||||
print("Unsupported config file format. Use .json.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
@@ -529,7 +524,7 @@ def main() -> None:
|
|||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
if args.config:
|
if args.config:
|
||||||
load_config_file(args.config)
|
apply_config_override(args.config)
|
||||||
|
|
||||||
check_docker_installed()
|
check_docker_installed()
|
||||||
pull_docker_image()
|
pull_docker_image()
|
||||||
@@ -537,7 +532,7 @@ def main() -> None:
|
|||||||
validate_environment()
|
validate_environment()
|
||||||
asyncio.run(warm_up_llm())
|
asyncio.run(warm_up_llm())
|
||||||
|
|
||||||
save_current_config()
|
persist_config()
|
||||||
|
|
||||||
args.run_name = generate_run_name(args.targets_info)
|
args.run_name = generate_run_name(args.targets_info)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import ipaddress
|
import ipaddress
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import shutil
|
import shutil
|
||||||
@@ -789,3 +790,33 @@ def process_pull_line(
|
|||||||
def validate_llm_response(response: Any) -> None:
|
def validate_llm_response(response: Any) -> None:
|
||||||
if not response or not response.choices or not response.choices[0].message.content:
|
if not response or not response.choices or not response.choices[0].message.content:
|
||||||
raise RuntimeError("Invalid response from LLM")
|
raise RuntimeError("Invalid response from LLM")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_config_file(config_path: str) -> Path:
|
||||||
|
console = Console()
|
||||||
|
path = Path(config_path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
console.print(f"[bold red]Error:[/] Config file not found: {config_path}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if path.suffix != ".json":
|
||||||
|
console.print("[bold red]Error:[/] Config file must be a .json file")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
console.print(f"[bold red]Error:[/] Invalid JSON in config file: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
console.print("[bold red]Error:[/] Config file must contain a JSON object")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if "env" not in data or not isinstance(data.get("env"), dict):
|
||||||
|
console.print("[bold red]Error:[/] Config file must have an 'env' object")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
return path
|
||||||
|
|||||||
Reference in New Issue
Block a user