Fix up loading/storing configs + unit tests

This commit is contained in:
Patrick Robertson
2025-01-23 20:32:19 +01:00
parent 65ef46d01e
commit b27bf8ffeb
5 changed files with 228 additions and 60 deletions

View File

@@ -6,12 +6,17 @@ flexible setup in various environments.
"""
import argparse
import yaml
from ruamel.yaml import YAML, CommentedMap
from ruamel.yaml.comments import CommentedMap
from dataclasses import dataclass, field
from collections import OrderedDict
from collections.abc import Iterable
from copy import deepcopy
from .loader import MODULE_TYPES
from typing import Any, List
# configurable_parents = [
# Feeder,
# Enricher,
@@ -50,21 +55,16 @@ from .loader import MODULE_TYPES
# parser.add_argument('--config', action='store', dest='config', help='the filename of the YAML configuration file (defaults to \'config.yaml\')', default='orchestration.yaml')
# parser.add_argument('--version', action='version', version=__version__)
EMPTY_CONFIG = {
EMPTY_CONFIG = CommentedMap(**{
"steps": dict((f"{module_type}s", []) for module_type in MODULE_TYPES)
}
class LoadFromFile (argparse.Action):
def __call__ (self, parser, namespace, values, option_string = None):
with values as f:
# parse arguments in the file and store them in the target namespace
parser.parse_args(f.read().split(), namespace)
})
def to_dot_notation(yaml_conf: str) -> argparse.ArgumentParser:
def to_dot_notation(yaml_conf: CommentedMap | dict) -> argparse.ArgumentParser:
dotdict = {}
def process_subdict(subdict, prefix=""):
for key, value in subdict.items():
if type(value) == dict:
if is_dict_type(value):
process_subdict(value, f"{prefix}{key}.")
else:
dotdict[f"{prefix}{key}"] = value
@@ -72,31 +72,64 @@ def to_dot_notation(yaml_conf: str) -> argparse.ArgumentParser:
process_subdict(yaml_conf)
return dotdict
def merge_dicts(dotdict, yaml_dict):
def process_subdict(subdict, prefix=""):
for key, value in subdict.items():
if "." in key:
keys = key.split(".")
subdict = yaml_dict
for k in keys[:-1]:
subdict = subdict.setdefault(k, {})
subdict[keys[-1]] = value
else:
yaml_dict[key] = value
def from_dot_notation(dotdict: dict) -> dict:
normal_dict = {}
def add_part(key, value, current_dict):
if "." in key:
key_parts = key.split(".")
current_dict.setdefault(key_parts[0], {})
add_part(".".join(key_parts[1:]), value, current_dict[key_parts[0]])
else:
current_dict[key] = value
for key, value in dotdict.items():
add_part(key, value, normal_dict)
return normal_dict
def is_list_type(value):
return isinstance(value, list) or isinstance(value, tuple) or isinstance(value, set)
def is_dict_type(value):
return isinstance(value, dict) or isinstance(value, CommentedMap)
def merge_dicts(dotdict: dict, yaml_dict: CommentedMap) -> CommentedMap:
yaml_dict: CommentedMap = deepcopy(yaml_dict)
# first deal with lists, since 'update' replaces lists from a in b, but we want to extend
def update_dict(subdict, yaml_subdict):
for key, value in yaml_subdict.items():
if not subdict.get(key):
continue
if is_dict_type(value):
update_dict(subdict[key], value)
elif is_list_type(value):
yaml_subdict[key].extend(s for s in subdict[key] if s not in yaml_subdict[key])
else:
yaml_subdict[key] = subdict[key]
update_dict(from_dot_notation(dotdict), yaml_dict)
process_subdict(dotdict)
return yaml_dict
def read_yaml(yaml_filename: str) -> dict:
yaml = YAML()
def read_yaml(yaml_filename: str) -> CommentedMap:
config = None
try:
with open(yaml_filename, "r", encoding="utf-8") as inf:
config = yaml.safe_load(inf)
config = yaml.load(inf)
except FileNotFoundError:
config = EMPTY_CONFIG
pass
if not config:
config = EMPTY_CONFIG
return config
def store_yaml(config: dict, yaml_filename: str):
def store_yaml(config: CommentedMap, yaml_filename: str):
with open(yaml_filename, "w", encoding="utf-8") as outf:
yaml.dump(config, outf, default_flow_style=False)
yaml.dump(config, outf)