Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 97 additions & 85 deletions argparse_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@
"""

import argparse
from collections import namedtuple
from argparse import BooleanOptionalAction
from typing import (
TypeVar,
Expand Down Expand Up @@ -274,7 +275,7 @@
def parse_args(options_class: Type[OptionsType], args: ArgsType = None) -> OptionsType:
"""Parse arguments and return as the dataclass type."""
parser = argparse.ArgumentParser()
_add_dataclass_options(options_class, parser)
add_dataclass_options(options_class, parser)
kwargs = _get_kwargs(parser.parse_args(args))
return options_class(**kwargs)

Expand All @@ -286,102 +287,113 @@ def parse_known_args(
and list of remaining arguments.
"""
parser = argparse.ArgumentParser()
_add_dataclass_options(options_class, parser)
add_dataclass_options(options_class, parser)
namespace, others = parser.parse_known_args(args=args)
kwargs = _get_kwargs(namespace)
return options_class(**kwargs), others


def _add_dataclass_options(
options_class: Type[OptionsType], parser: argparse.ArgumentParser
) -> None:
if not is_dataclass(options_class):
raise TypeError("cls must be a dataclass")

for field in fields(options_class):
args = field.metadata.get("args", [f"--{_get_arg_name(field)}"])
positional = not args[0].startswith("-")
kwargs = {
"type": field.metadata.get("type", field.type),
"help": field.metadata.get("help", None),
}

if field.metadata.get("args") and not positional:
# We want to ensure that we store the argument based on the
# name of the field and not whatever flag name was provided
kwargs["dest"] = field.name
def extract_argparse_kwargs(field: Field[Any]) -> Tuple[List[str], Dict[str, Any]]:
"""Extract kwargs of ArgumentParser.add_argument from a dataclass field.

Returns pair of (args, kwargs) to be passed to ArgumentParser.add_argument.
"""
args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"])
positional = not args[0].startswith("-")
kwargs = {
"type": field.metadata.get("type", field.type),
"help": field.metadata.get("help", None),
}

if field.metadata.get("args") and not positional:
# We want to ensure that we store the argument based on the
# name of the field and not whatever flag name was provided
kwargs["dest"] = field.name

if field.metadata.get("choices") is not None:
kwargs["choices"] = field.metadata["choices"]

# Support Literal types as an alternative means of specifying choices.
if get_origin(field.type) is Literal:
# Prohibit a potential collision with the choices field
if field.metadata.get("choices") is not None:
kwargs["choices"] = field.metadata["choices"]

# Support Literal types as an alternative means of specifying choices.
if get_origin(field.type) is Literal:
# Prohibit a potential collision with the choices field
if field.metadata.get("choices") is not None:
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Literal type arguments should not be combined with choices in the metadata. "
"Remove the redundant choices field from the metadata."
)

# Get the types of the arguments of the Literal
types = [type(arg) for arg in get_args(field.type)]

# Make sure just a single type has been used
if len(set(types)) > 1:
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Literal type arguments should contain choices of a single type. "
f"Instead, {len(set(types))} types where found: "
+ ", ".join([type_.__name__ for type_ in set(types)])
+ "."
)

# Overwrite the type kwarg
kwargs["type"] = types[0]
# Use the literal arguments as choices
kwargs["choices"] = get_args(field.type)

if field.metadata.get("metavar") is not None:
kwargs["metavar"] = field.metadata["metavar"]

if field.metadata.get("nargs") is not None:
kwargs["nargs"] = field.metadata["nargs"]
if field.metadata.get("type") is None:
# When nargs is specified, field.type should be a list,
# or something equivalent, like typing.List.
# Using it would most likely result in an error, so if the user
# did not specify the type of the elements within the list, we
# try to infer it:
try:
kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple
except IndexError:
# get_args returned an empty tuple, type cannot be inferred
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Literal type arguments should not be combined with choices in the metadata. "
"Remove the redundant choices field from the metadata."
"Try using a parameterized type hint, or "
"specifying the type explicitly using metadata['type']"
)

# Get the types of the arguments of the Literal
types = [type(arg) for arg in get_args(field.type)]

# Make sure just a single type has been used
if len(set(types)) > 1:
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Literal type arguments should contain choices of a single type. "
f"Instead, {len(set(types))} types where found: "
+ ", ".join([type_.__name__ for type_ in set(types)])
+ "."
if field.default == field.default_factory == MISSING and not positional:
kwargs["required"] = True
else:
kwargs["default"] = MISSING

if field.type is bool:
_handle_bool_type(field, args, kwargs)
elif get_origin(field.type) is Union:
if field.metadata.get("type") is None:
# Optional[X] is equivalent to Union[X, None].
f_args = get_args(field.type)
if len(f_args) == 2 and NoneType in f_args:
arg = next(a for a in f_args if a is not NoneType)
kwargs["type"] = arg
else:
raise TypeError(
"For Union types other than 'Optional', a custom 'type' must be specified using "
"'metadata'."
)

# Overwrite the type kwarg
kwargs["type"] = types[0]
# Use the literal arguments as choices
kwargs["choices"] = get_args(field.type)

if field.metadata.get("metavar") is not None:
kwargs["metavar"] = field.metadata["metavar"]

if field.metadata.get("nargs") is not None:
kwargs["nargs"] = field.metadata["nargs"]
if field.metadata.get("type") is None:
# When nargs is specified, field.type should be a list,
# or something equivalent, like typing.List.
# Using it would most likely result in an error, so if the user
# did not specify the type of the elements within the list, we
# try to infer it:
try:
kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple
except IndexError:
# get_args returned an empty tuple, type cannot be inferred
raise ValueError(
f"Cannot infer type of items in field: {field.name}. "
"Try using a parameterized type hint, or "
"specifying the type explicitly using metadata['type']"
)

if field.default == field.default_factory == MISSING and not positional:
kwargs["required"] = True
else:
kwargs["default"] = MISSING

if field.type is bool:
_handle_bool_type(field, args, kwargs)
elif get_origin(field.type) is Union:
if field.metadata.get("type") is None:
# Optional[X] is equivalent to Union[X, None].
f_args = get_args(field.type)
if len(f_args) == 2 and NoneType in f_args:
arg = next(a for a in f_args if a is not NoneType)
kwargs["type"] = arg
else:
raise TypeError(
"For Union types other than 'Optional', a custom 'type' must be specified using "
"'metadata'."
)
return args, kwargs


def add_dataclass_options(
options_class: Type[OptionsType], parser: argparse.ArgumentParser
) -> None:
"""Adds options given as dataclass fields to the parser."""
if not is_dataclass(options_class):
raise TypeError("cls must be a dataclass")

for field in fields(options_class):
args, kwargs = extract_argparse_kwargs(field)

if "group" in field.metadata:
_handle_argument_group(parser, field, args, kwargs)
Expand Down Expand Up @@ -465,7 +477,7 @@ class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]):
def __init__(self, options_class: Type[OptionsType], *args, **kwargs):
super().__init__(*args, **kwargs)
self._options_type: Type[OptionsType] = options_class
_add_dataclass_options(options_class, self)
add_dataclass_options(options_class, self)

def parse_args(self, args: ArgsType = None, namespace=None) -> OptionsType:
"""Parse arguments and return as the dataclass type."""
Expand Down
18 changes: 17 additions & 1 deletion tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import argparse
import sys
import unittest
import datetime as dt
from dataclasses import dataclass, field

from typing import Optional, Union

from argparse_dataclass import parse_args, parse_known_args
from argparse_dataclass import add_dataclass_options, parse_args, parse_known_args


class NegativeTestHelper:
Expand Down Expand Up @@ -53,6 +54,21 @@ class Opt:

self.assertRaises(TypeError, parse_args, Opt, [])

def test_add_dataclass_options(self):
@dataclass
class Opt:
x: int = 42
y: bool = False
argpument_parser = argparse.ArgumentParser()
add_dataclass_options(Opt, argpument_parser)
params = argpument_parser.parse_args([])
print(params)
self.assertEqual(42, params.x)
self.assertEqual(False, params.y)
params = argpument_parser.parse_args(["--x=10", "--y"])
self.assertEqual(10, params.x)
self.assertEqual(True, params.y)

def test_bool_no_default(self):
@dataclass
class Opt:
Expand Down
Loading