Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,9 +1355,8 @@ def _recreate_annotated_origin(annotated_type: Any) -> Any:
if only_explicit_bindings and _inject_marker not in metadata or _noinject_marker in metadata:
del bindings[k]
elif _is_specialization(v, Union) or _is_new_union_type(v):
# We don't treat Optional parameters in any special way at the moment.
union_members = v.__args__
new_members = tuple(set(union_members) - {type(None)})
new_members = tuple(set(union_members))
# mypy stared complaining about this line for some reason:
# error: Variable "new_members" is not valid as a type
new_union = Union[new_members] # type: ignore
Expand Down
53 changes: 52 additions & 1 deletion injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,7 +2031,7 @@ def test_get_bindings_for_pep_604():
def function1(a: int | None) -> None:
pass

assert get_bindings(function1) == {'a': int}
assert get_bindings(function1) == {'a': Union[int, None]}

@inject
def function1(a: int | str) -> None:
Expand Down Expand Up @@ -2246,3 +2246,54 @@ def provide_second(self) -> Annotated[str, 'second']:
injector = Injector(module)
assert injector.get(Annotated[str, 'first']) == 'Bob'
assert injector.get(Annotated[str, 'second']) == 'Iger'


@pytest.mark.parametrize(
"provide",
[True, False]
)
def test_optional_parameter_optional_provider(provide: bool) -> None:
class Dependency:
pass

class Dependent:
@inject
def __init__(self, dependency: Dependency | None = None) -> None:
self.dependency = dependency

class MyModule(Module):
@provider
def optional(self) -> Dependency | None:
return Dependency() if provide else None

injector = Injector([MyModule()])
instance = injector.get(Dependent)
assert instance.dependency or not provide


@pytest.mark.parametrize(
"provide",
[True, False]
)
def test_required_parameter_optional_provider(provide: bool) -> None:
class Dependency:
def __init__(self, transitive: object = object()) -> None:
self.dependency = transitive

class Dependent:
@inject
def __init__(self, dependency: Dependency) -> None:
self.dependency = dependency

class MyModule(Module):
@provider
def optional(self) -> Dependency | None:
return Dependency(transitive="from provider") if provide else None

injector = Injector([MyModule()])
instance = injector.get(Dependent)
# the Dependency | None provider should never be used, and instead
# the default behavior of providing instances of Dependency should be
# used meaning in both provider branches this will be true
assert instance.dependency is not None
assert instance.dependency.dependency != "from provider"
Loading