diff --git a/argparse_dataclass.py b/argparse_dataclass.py index c530f09..3564800 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -239,6 +239,7 @@ """ import argparse +from collections import namedtuple from argparse import BooleanOptionalAction from typing import ( TypeVar, @@ -274,7 +275,7 @@ def parse_args(options_class: Type[OptionsType], args: ArgsType = None) -> OptionsType: """Parse arguments and return as the dataclass type.""" parser = argparse.ArgumentParser() - _add_dataclass_options(options_class, parser) + add_dataclass_options(options_class, parser) kwargs = _get_kwargs(parser.parse_args(args)) return options_class(**kwargs) @@ -286,102 +287,113 @@ def parse_known_args( and list of remaining arguments. """ parser = argparse.ArgumentParser() - _add_dataclass_options(options_class, parser) + add_dataclass_options(options_class, parser) namespace, others = parser.parse_known_args(args=args) kwargs = _get_kwargs(namespace) return options_class(**kwargs), others -def _add_dataclass_options( - options_class: Type[OptionsType], parser: argparse.ArgumentParser -) -> None: - if not is_dataclass(options_class): - raise TypeError("cls must be a dataclass") - - for field in fields(options_class): - args = field.metadata.get("args", [f"--{_get_arg_name(field)}"]) - positional = not args[0].startswith("-") - kwargs = { - "type": field.metadata.get("type", field.type), - "help": field.metadata.get("help", None), - } - - if field.metadata.get("args") and not positional: - # We want to ensure that we store the argument based on the - # name of the field and not whatever flag name was provided - kwargs["dest"] = field.name +def extract_argparse_kwargs(field: Field[Any]) -> Tuple[List[str], Dict[str, Any]]: + """Extract kwargs of ArgumentParser.add_argument from a dataclass field. + Returns pair of (args, kwargs) to be passed to ArgumentParser.add_argument. + """ + args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) + positional = not args[0].startswith("-") + kwargs = { + "type": field.metadata.get("type", field.type), + "help": field.metadata.get("help", None), + } + + if field.metadata.get("args") and not positional: + # We want to ensure that we store the argument based on the + # name of the field and not whatever flag name was provided + kwargs["dest"] = field.name + + if field.metadata.get("choices") is not None: + kwargs["choices"] = field.metadata["choices"] + + # Support Literal types as an alternative means of specifying choices. + if get_origin(field.type) is Literal: + # Prohibit a potential collision with the choices field if field.metadata.get("choices") is not None: - kwargs["choices"] = field.metadata["choices"] - - # Support Literal types as an alternative means of specifying choices. - if get_origin(field.type) is Literal: - # Prohibit a potential collision with the choices field - if field.metadata.get("choices") is not None: + raise ValueError( + f"Cannot infer type of items in field: {field.name}. " + "Literal type arguments should not be combined with choices in the metadata. " + "Remove the redundant choices field from the metadata." + ) + + # Get the types of the arguments of the Literal + types = [type(arg) for arg in get_args(field.type)] + + # Make sure just a single type has been used + if len(set(types)) > 1: + raise ValueError( + f"Cannot infer type of items in field: {field.name}. " + "Literal type arguments should contain choices of a single type. " + f"Instead, {len(set(types))} types where found: " + + ", ".join([type_.__name__ for type_ in set(types)]) + + "." + ) + + # Overwrite the type kwarg + kwargs["type"] = types[0] + # Use the literal arguments as choices + kwargs["choices"] = get_args(field.type) + + if field.metadata.get("metavar") is not None: + kwargs["metavar"] = field.metadata["metavar"] + + if field.metadata.get("nargs") is not None: + kwargs["nargs"] = field.metadata["nargs"] + if field.metadata.get("type") is None: + # When nargs is specified, field.type should be a list, + # or something equivalent, like typing.List. + # Using it would most likely result in an error, so if the user + # did not specify the type of the elements within the list, we + # try to infer it: + try: + kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple + except IndexError: + # get_args returned an empty tuple, type cannot be inferred raise ValueError( f"Cannot infer type of items in field: {field.name}. " - "Literal type arguments should not be combined with choices in the metadata. " - "Remove the redundant choices field from the metadata." + "Try using a parameterized type hint, or " + "specifying the type explicitly using metadata['type']" ) - # Get the types of the arguments of the Literal - types = [type(arg) for arg in get_args(field.type)] - - # Make sure just a single type has been used - if len(set(types)) > 1: - raise ValueError( - f"Cannot infer type of items in field: {field.name}. " - "Literal type arguments should contain choices of a single type. " - f"Instead, {len(set(types))} types where found: " - + ", ".join([type_.__name__ for type_ in set(types)]) - + "." + if field.default == field.default_factory == MISSING and not positional: + kwargs["required"] = True + else: + kwargs["default"] = MISSING + + if field.type is bool: + _handle_bool_type(field, args, kwargs) + elif get_origin(field.type) is Union: + if field.metadata.get("type") is None: + # Optional[X] is equivalent to Union[X, None]. + f_args = get_args(field.type) + if len(f_args) == 2 and NoneType in f_args: + arg = next(a for a in f_args if a is not NoneType) + kwargs["type"] = arg + else: + raise TypeError( + "For Union types other than 'Optional', a custom 'type' must be specified using " + "'metadata'." ) - # Overwrite the type kwarg - kwargs["type"] = types[0] - # Use the literal arguments as choices - kwargs["choices"] = get_args(field.type) - - if field.metadata.get("metavar") is not None: - kwargs["metavar"] = field.metadata["metavar"] - - if field.metadata.get("nargs") is not None: - kwargs["nargs"] = field.metadata["nargs"] - if field.metadata.get("type") is None: - # When nargs is specified, field.type should be a list, - # or something equivalent, like typing.List. - # Using it would most likely result in an error, so if the user - # did not specify the type of the elements within the list, we - # try to infer it: - try: - kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple - except IndexError: - # get_args returned an empty tuple, type cannot be inferred - raise ValueError( - f"Cannot infer type of items in field: {field.name}. " - "Try using a parameterized type hint, or " - "specifying the type explicitly using metadata['type']" - ) - - if field.default == field.default_factory == MISSING and not positional: - kwargs["required"] = True - else: - kwargs["default"] = MISSING - - if field.type is bool: - _handle_bool_type(field, args, kwargs) - elif get_origin(field.type) is Union: - if field.metadata.get("type") is None: - # Optional[X] is equivalent to Union[X, None]. - f_args = get_args(field.type) - if len(f_args) == 2 and NoneType in f_args: - arg = next(a for a in f_args if a is not NoneType) - kwargs["type"] = arg - else: - raise TypeError( - "For Union types other than 'Optional', a custom 'type' must be specified using " - "'metadata'." - ) + return args, kwargs + + +def add_dataclass_options( + options_class: Type[OptionsType], parser: argparse.ArgumentParser +) -> None: + """Adds options given as dataclass fields to the parser.""" + if not is_dataclass(options_class): + raise TypeError("cls must be a dataclass") + + for field in fields(options_class): + args, kwargs = extract_argparse_kwargs(field) if "group" in field.metadata: _handle_argument_group(parser, field, args, kwargs) @@ -465,7 +477,7 @@ class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]): def __init__(self, options_class: Type[OptionsType], *args, **kwargs): super().__init__(*args, **kwargs) self._options_type: Type[OptionsType] = options_class - _add_dataclass_options(options_class, self) + add_dataclass_options(options_class, self) def parse_args(self, args: ArgsType = None, namespace=None) -> OptionsType: """Parse arguments and return as the dataclass type.""" diff --git a/tests/test_functional.py b/tests/test_functional.py index ec41f9f..5c41be4 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,3 +1,4 @@ +import argparse import sys import unittest import datetime as dt @@ -5,7 +6,7 @@ from typing import Optional, Union -from argparse_dataclass import parse_args, parse_known_args +from argparse_dataclass import add_dataclass_options, parse_args, parse_known_args class NegativeTestHelper: @@ -53,6 +54,21 @@ class Opt: self.assertRaises(TypeError, parse_args, Opt, []) + def test_add_dataclass_options(self): + @dataclass + class Opt: + x: int = 42 + y: bool = False + argpument_parser = argparse.ArgumentParser() + add_dataclass_options(Opt, argpument_parser) + params = argpument_parser.parse_args([]) + print(params) + self.assertEqual(42, params.x) + self.assertEqual(False, params.y) + params = argpument_parser.parse_args(["--x=10", "--y"]) + self.assertEqual(10, params.x) + self.assertEqual(True, params.y) + def test_bool_no_default(self): @dataclass class Opt: