From ed18356b795c19b94cf04ddf1d245f43581cac33 Mon Sep 17 00:00:00 2001 From: Johannes Koester Date: Mon, 11 Sep 2023 09:24:16 +0200 Subject: [PATCH 1/6] feat: separate args and kwargs retrieval into helper functions and expose add_dataclass_options in the public API --- argparse_dataclass.py | 177 ++++++++++++++++++++++-------------------- 1 file changed, 94 insertions(+), 83 deletions(-) diff --git a/argparse_dataclass.py b/argparse_dataclass.py index 3da8556..d9b5f63 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -308,7 +308,7 @@ def format_usage(self): def parse_args(options_class: Type[OptionsType], args: ArgsType = None) -> OptionsType: """Parse arguments and return as the dataclass type.""" parser = argparse.ArgumentParser() - _add_dataclass_options(options_class, parser) + add_dataclass_options(options_class, parser) kwargs = _get_kwargs(parser.parse_args(args)) return options_class(**kwargs) @@ -320,102 +320,113 @@ def parse_known_args( and list of remaining arguments. """ parser = argparse.ArgumentParser() - _add_dataclass_options(options_class, parser) + add_dataclass_options(options_class, parser) namespace, others = parser.parse_known_args(args=args) kwargs = _get_kwargs(namespace) return options_class(**kwargs), others -def _add_dataclass_options( - options_class: Type[OptionsType], parser: argparse.ArgumentParser -) -> None: - if not is_dataclass(options_class): - raise TypeError("cls must be a dataclass") +def get_field_args(field: Field[Any]) -> List[str]: + return field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) - for field in fields(options_class): - args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) - positional = not args[0].startswith("-") - kwargs = { - "type": field.metadata.get("type", field.type), - "help": field.metadata.get("help", None), - } - - if field.metadata.get("args") and not positional: - # We want to ensure that we store the argument based on the - # name of the field and not whatever flag name was provided - kwargs["dest"] = field.name +def get_field_kwargs(field: Field[Any]) -> Dict[str, Any]: + args = get_field_args(field) + positional = not args[0].startswith("-") + kwargs = { + "type": field.metadata.get("type", field.type), + "help": field.metadata.get("help", None), + } + + if field.metadata.get("args") and not positional: + # We want to ensure that we store the argument based on the + # name of the field and not whatever flag name was provided + kwargs["dest"] = field.name + + if field.metadata.get("choices") is not None: + kwargs["choices"] = field.metadata["choices"] + + # Support Literal types as an alternative means of specifying choices. + if get_origin(field.type) is Literal: + # Prohibit a potential collision with the choices field if field.metadata.get("choices") is not None: - kwargs["choices"] = field.metadata["choices"] + raise ValueError( + f"Cannot infer type of items in field: {field.name}. " + "Literal type arguments should not be combined with choices in the metadata. " + "Remove the redundant choices field from the metadata." + ) + + # Get the types of the arguments of the Literal + types = [type(arg) for arg in get_args(field.type)] + + # Make sure just a single type has been used + if len(set(types)) > 1: + raise ValueError( + f"Cannot infer type of items in field: {field.name}. " + "Literal type arguments should contain choices of a single type. " + f"Instead, {len(set(types))} types where found: " + + ", ".join([type_.__name__ for type_ in set(types)]) + + "." + ) - # Support Literal types as an alternative means of specifying choices. - if get_origin(field.type) is Literal: - # Prohibit a potential collision with the choices field - if field.metadata.get("choices") is not None: + # Overwrite the type kwarg + kwargs["type"] = types[0] + # Use the literal arguments as choices + kwargs["choices"] = get_args(field.type) + + if field.metadata.get("metavar") is not None: + kwargs["metavar"] = field.metadata["metavar"] + + if field.metadata.get("nargs") is not None: + kwargs["nargs"] = field.metadata["nargs"] + if field.metadata.get("type") is None: + # When nargs is specified, field.type should be a list, + # or something equivalent, like typing.List. + # Using it would most likely result in an error, so if the user + # did not specify the type of the elements within the list, we + # try to infer it: + try: + kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple + except IndexError: + # get_args returned an empty tuple, type cannot be inferred raise ValueError( f"Cannot infer type of items in field: {field.name}. " - "Literal type arguments should not be combined with choices in the metadata. " - "Remove the redundant choices field from the metadata." + "Try using a parameterized type hint, or " + "specifying the type explicitly using metadata['type']" ) - # Get the types of the arguments of the Literal - types = [type(arg) for arg in get_args(field.type)] - - # Make sure just a single type has been used - if len(set(types)) > 1: - raise ValueError( - f"Cannot infer type of items in field: {field.name}. " - "Literal type arguments should contain choices of a single type. " - f"Instead, {len(set(types))} types where found: " - + ", ".join([type_.__name__ for type_ in set(types)]) - + "." + if field.default == field.default_factory == MISSING and not positional: + kwargs["required"] = True + else: + kwargs["default"] = MISSING + + if field.type is bool: + _handle_bool_type(field, args, kwargs) + elif get_origin(field.type) is Union: + if field.metadata.get("type") is None: + # Optional[X] is equivalent to Union[X, None]. + f_args = get_args(field.type) + if len(f_args) == 2 and NoneType in f_args: + arg = next(a for a in f_args if a is not NoneType) + kwargs["type"] = arg + else: + raise TypeError( + "For Union types other than 'Optional', a custom 'type' must be specified using " + "'metadata'." ) - # Overwrite the type kwarg - kwargs["type"] = types[0] - # Use the literal arguments as choices - kwargs["choices"] = get_args(field.type) - - if field.metadata.get("metavar") is not None: - kwargs["metavar"] = field.metadata["metavar"] - - if field.metadata.get("nargs") is not None: - kwargs["nargs"] = field.metadata["nargs"] - if field.metadata.get("type") is None: - # When nargs is specified, field.type should be a list, - # or something equivalent, like typing.List. - # Using it would most likely result in an error, so if the user - # did not specify the type of the elements within the list, we - # try to infer it: - try: - kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple - except IndexError: - # get_args returned an empty tuple, type cannot be inferred - raise ValueError( - f"Cannot infer type of items in field: {field.name}. " - "Try using a parameterized type hint, or " - "specifying the type explicitly using metadata['type']" - ) - - if field.default == field.default_factory == MISSING and not positional: - kwargs["required"] = True - else: - kwargs["default"] = MISSING - - if field.type is bool: - _handle_bool_type(field, args, kwargs) - elif get_origin(field.type) is Union: - if field.metadata.get("type") is None: - # Optional[X] is equivalent to Union[X, None]. - f_args = get_args(field.type) - if len(f_args) == 2 and NoneType in f_args: - arg = next(a for a in f_args if a is not NoneType) - kwargs["type"] = arg - else: - raise TypeError( - "For Union types other than 'Optional', a custom 'type' must be specified using " - "'metadata'." - ) + return kwargs + + +def add_dataclass_options( + options_class: Type[OptionsType], parser: argparse.ArgumentParser +) -> None: + if not is_dataclass(options_class): + raise TypeError("cls must be a dataclass") + + for field in fields(options_class): + args = get_field_args(field) + kwargs = get_field_kwargs(field) if "group" in field.metadata: _handle_argument_group(parser, field, args, kwargs) @@ -493,7 +504,7 @@ class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]): def __init__(self, options_class: Type[OptionsType], *args, **kwargs): super().__init__(*args, **kwargs) self._options_type: Type[OptionsType] = options_class - _add_dataclass_options(options_class, self) + add_dataclass_options(options_class, self) def parse_args(self, args: ArgsType = None, namespace=None) -> OptionsType: """Parse arguments and return as the dataclass type.""" From 15c8f4286b6939a3c5f5d8d789a8f5f20f5c0c55 Mon Sep 17 00:00:00 2001 From: Johannes Koester Date: Mon, 11 Sep 2023 09:28:27 +0200 Subject: [PATCH 2/6] docs: add docstrings --- argparse_dataclass.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/argparse_dataclass.py b/argparse_dataclass.py index d9b5f63..cad3e98 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -326,12 +326,14 @@ def parse_known_args( return options_class(**kwargs), others -def get_field_args(field: Field[Any]) -> List[str]: +def field_to_argument_args(field: Field[Any]) -> List[str]: + """Extract args of ArgumentParser.add_argument from a dataclass field.""" return field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) -def get_field_kwargs(field: Field[Any]) -> Dict[str, Any]: - args = get_field_args(field) +def field_to_argument_kwargs(field: Field[Any]) -> Dict[str, Any]: + """Extract kwargs of ArgumentParser.add_argument from a dataclass field.""" + args = field_to_argument_args(field) positional = not args[0].startswith("-") kwargs = { "type": field.metadata.get("type", field.type), @@ -421,12 +423,13 @@ def get_field_kwargs(field: Field[Any]) -> Dict[str, Any]: def add_dataclass_options( options_class: Type[OptionsType], parser: argparse.ArgumentParser ) -> None: + """Adds options given as dataclass fields to the parser.""" if not is_dataclass(options_class): raise TypeError("cls must be a dataclass") for field in fields(options_class): - args = get_field_args(field) - kwargs = get_field_kwargs(field) + args = field_to_argument_args(field) + kwargs = field_to_argument_kwargs(field) if "group" in field.metadata: _handle_argument_group(parser, field, args, kwargs) From 2561f0049a357df91ccbcdf0455a2f5f4867d2b4 Mon Sep 17 00:00:00 2001 From: Johannes Koester Date: Mon, 11 Sep 2023 10:00:19 +0200 Subject: [PATCH 3/6] trigger CI From 099b7d9f38f2477ea76dcd20ad0e5763e4981664 Mon Sep 17 00:00:00 2001 From: Johannes Koester Date: Mon, 11 Sep 2023 10:40:47 +0200 Subject: [PATCH 4/6] fixes --- argparse_dataclass.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/argparse_dataclass.py b/argparse_dataclass.py index cad3e98..c160f8b 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -224,6 +224,7 @@ """ import argparse +from collections import namedtuple from typing import ( TypeVar, Optional, @@ -326,14 +327,12 @@ def parse_known_args( return options_class(**kwargs), others -def field_to_argument_args(field: Field[Any]) -> List[str]: - """Extract args of ArgumentParser.add_argument from a dataclass field.""" - return field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) +def field_to_argument_args(field: Field[Any]) -> Tuple[List[str], Dict[str, Any]]: + """Extract kwargs of ArgumentParser.add_argument from a dataclass field. - -def field_to_argument_kwargs(field: Field[Any]) -> Dict[str, Any]: - """Extract kwargs of ArgumentParser.add_argument from a dataclass field.""" - args = field_to_argument_args(field) + Returns pair of (args, kwargs) to be passed to ArgumentParser.add_argument. + """ + args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) positional = not args[0].startswith("-") kwargs = { "type": field.metadata.get("type", field.type), @@ -417,7 +416,7 @@ def field_to_argument_kwargs(field: Field[Any]) -> Dict[str, Any]: "'metadata'." ) - return kwargs + return args, kwargs def add_dataclass_options( @@ -428,8 +427,7 @@ def add_dataclass_options( raise TypeError("cls must be a dataclass") for field in fields(options_class): - args = field_to_argument_args(field) - kwargs = field_to_argument_kwargs(field) + args, kwargs = field_to_argument_args(field) if "group" in field.metadata: _handle_argument_group(parser, field, args, kwargs) From e6d4b5a2c9d78a4192da1a7b376afeb71804a45f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20K=C3=B6ster?= Date: Fri, 15 Sep 2023 09:43:25 +0200 Subject: [PATCH 5/6] Rename function --- argparse_dataclass.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/argparse_dataclass.py b/argparse_dataclass.py index c160f8b..bc9b314 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -327,7 +327,7 @@ def parse_known_args( return options_class(**kwargs), others -def field_to_argument_args(field: Field[Any]) -> Tuple[List[str], Dict[str, Any]]: +def extract_argparse_kwargs(field: Field[Any]) -> Tuple[List[str], Dict[str, Any]]: """Extract kwargs of ArgumentParser.add_argument from a dataclass field. Returns pair of (args, kwargs) to be passed to ArgumentParser.add_argument. @@ -427,7 +427,7 @@ def add_dataclass_options( raise TypeError("cls must be a dataclass") for field in fields(options_class): - args, kwargs = field_to_argument_args(field) + args, kwargs = extract_argparse_kwargs(field) if "group" in field.metadata: _handle_argument_group(parser, field, args, kwargs) From 281adc320c1a8872d240bb844967d56cd77b1456 Mon Sep 17 00:00:00 2001 From: Johannes Koester Date: Wed, 19 Nov 2025 20:58:11 +0100 Subject: [PATCH 6/6] add testcase --- tests/test_functional.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 70ec48f..4bf67a2 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,3 +1,4 @@ +import argparse import sys import unittest import datetime as dt @@ -5,7 +6,7 @@ from typing import List, Optional, Union -from argparse_dataclass import parse_args, parse_known_args +from argparse_dataclass import add_dataclass_options, parse_args, parse_known_args class NegativeTestHelper: @@ -53,6 +54,21 @@ class Opt: self.assertRaises(TypeError, parse_args, Opt, []) + def test_add_dataclass_options(self): + @dataclass + class Opt: + x: int = 42 + y: bool = False + argpument_parser = argparse.ArgumentParser() + add_dataclass_options(Opt, argpument_parser) + params = argpument_parser.parse_args([]) + print(params) + self.assertEqual(42, params.x) + self.assertEqual(False, params.y) + params = argpument_parser.parse_args(["--x=10", "--y"]) + self.assertEqual(10, params.x) + self.assertEqual(True, params.y) + def test_bool_no_default(self): @dataclass class Opt: