diff --git a/src/nomadic/util/cli.py b/src/nomadic/util/cli.py index 9bf770d..da57032 100644 --- a/src/nomadic/util/cli.py +++ b/src/nomadic/util/cli.py @@ -11,7 +11,12 @@ find_workspace_root, check_if_workspace_root, ) -from nomadic.util.config import load_config, default_config_path +from nomadic.util.config import ( + InvalidConfigError, + get_command_defaults, + load_config, + default_config_path, +) WORKSPACE_OPTION_KEY = "workspace" @@ -122,8 +127,24 @@ def load_defaults_from_config(ctx: click.Context, command: Optional[str] = None) config_path = os.path.join(workspace.path, default_config_path) if os.path.isfile(config_path): config = load_config(config_path) - defaults = config.get("defaults", {}) - defaults = defaults | config.get(command, {}).get("defaults", {}) + if not config: + # When empty dict or non, the config is empty + return + if not isinstance(config, dict): + if not ctx.resilient_parsing: + raise click.UsageError( + f"Invalid config at {config_path}: config is not a dict." + ) + else: + return + try: + defaults = get_command_defaults(config, command) + except InvalidConfigError as e: + if not ctx.resilient_parsing: + raise click.UsageError(f"Invalid config at {config_path}: {e}") + else: + return + if defaults: if not ctx.resilient_parsing: # Don't print defaults if parsing is resilient, as this is used for shell completion diff --git a/src/nomadic/util/config.py b/src/nomadic/util/config.py index 6489e52..f83d1bb 100644 --- a/src/nomadic/util/config.py +++ b/src/nomadic/util/config.py @@ -1,4 +1,5 @@ from functools import reduce +from typing import Optional from yaml import dump, load @@ -11,6 +12,12 @@ default_config_path = ".config.yaml" +class InvalidConfigError(Exception): + """Raised when the config is invalid""" + + pass + + def load_config(config_path: str) -> dict: """ Load configuration from a YAML file. @@ -62,3 +69,25 @@ def get_config_value(d: dict, keys: list, default=None): reduce(lambda acc, k: acc.get(k, {}) if isinstance(acc, dict) else {}, keys, d) or default ) + + +def get_command_defaults(config: dict, command: Optional[str]) -> dict: + defaults = must_get_dict(config, "defaults", "defaults should be a dict") + if command is not None: + command_config = must_get_dict( + config, command, f"{command} config should be a dict" + ) + command_defaults = must_get_dict( + command_config, "defaults", f"{command} defaults should be a dict" + ) + defaults = defaults | command_defaults + return defaults + + +def must_get_dict(d: dict, key: str, message: str) -> dict: + value = d.get(key, {}) + if value is None: + value = {} + if not isinstance(value, dict): + raise InvalidConfigError(message) + return value diff --git a/src/nomadic/util/config_test.py b/src/nomadic/util/config_test.py index c5243e5..0cd0b48 100644 --- a/src/nomadic/util/config_test.py +++ b/src/nomadic/util/config_test.py @@ -1,6 +1,11 @@ import pytest -from nomadic.util.config import get_config_value, set_config_value +from nomadic.util.config import ( + InvalidConfigError, + get_command_defaults, + get_config_value, + set_config_value, +) @pytest.mark.parametrize( @@ -28,3 +33,33 @@ def test_set_config_value_table(keys, value, expected): set_config_value(d, keys, value) assert get_config_value(d, keys) == value assert d == expected + + +@pytest.mark.parametrize( + "config, command, expected", + [ + ({}, None, {}), + ({}, "test", {}), + ({"defaults": None}, "test", {}), + ({"defaults": {"foo": "bar"}}, "test", {"foo": "bar"}), + ( + {"defaults": {"foo": "bar"}, "test": {"defaults": {"this": "that"}}}, + "test", + {"foo": "bar", "this": "that"}, + ), + ], +) +def test_get_command_defaults(config, command, expected): + assert get_command_defaults(config, command) == expected + + +@pytest.mark.parametrize( + "config, command", + [ + ({"defaults": 5}, None), + ({"defaults": {}, "test": 5}, "test"), + ], +) +def test_get_command_defaults_invalid_configs(config, command): + with pytest.raises(InvalidConfigError): + get_command_defaults(config, command)