diff --git a/injector/__init__.py b/injector/__init__.py index 62d7d99..4f7d580 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -1355,9 +1355,8 @@ def _recreate_annotated_origin(annotated_type: Any) -> Any: if only_explicit_bindings and _inject_marker not in metadata or _noinject_marker in metadata: del bindings[k] elif _is_specialization(v, Union) or _is_new_union_type(v): - # We don't treat Optional parameters in any special way at the moment. union_members = v.__args__ - new_members = tuple(set(union_members) - {type(None)}) + new_members = tuple(set(union_members)) # mypy stared complaining about this line for some reason: # error: Variable "new_members" is not valid as a type new_union = Union[new_members] # type: ignore diff --git a/injector_test.py b/injector_test.py index 4591d86..c9336e3 100644 --- a/injector_test.py +++ b/injector_test.py @@ -2031,7 +2031,7 @@ def test_get_bindings_for_pep_604(): def function1(a: int | None) -> None: pass - assert get_bindings(function1) == {'a': int} + assert get_bindings(function1) == {'a': Union[int, None]} @inject def function1(a: int | str) -> None: @@ -2246,3 +2246,54 @@ def provide_second(self) -> Annotated[str, 'second']: injector = Injector(module) assert injector.get(Annotated[str, 'first']) == 'Bob' assert injector.get(Annotated[str, 'second']) == 'Iger' + + +@pytest.mark.parametrize( + "provide", + [True, False] +) +def test_optional_parameter_optional_provider(provide: bool) -> None: + class Dependency: + pass + + class Dependent: + @inject + def __init__(self, dependency: Dependency | None = None) -> None: + self.dependency = dependency + + class MyModule(Module): + @provider + def optional(self) -> Dependency | None: + return Dependency() if provide else None + + injector = Injector([MyModule()]) + instance = injector.get(Dependent) + assert instance.dependency or not provide + + +@pytest.mark.parametrize( + "provide", + [True, False] +) +def test_required_parameter_optional_provider(provide: bool) -> None: + class Dependency: + def __init__(self, transitive: object = object()) -> None: + self.dependency = transitive + + class Dependent: + @inject + def __init__(self, dependency: Dependency) -> None: + self.dependency = dependency + + class MyModule(Module): + @provider + def optional(self) -> Dependency | None: + return Dependency(transitive="from provider") if provide else None + + injector = Injector([MyModule()]) + instance = injector.get(Dependent) + # the Dependency | None provider should never be used, and instead + # the default behavior of providing instances of Dependency should be + # used meaning in both provider branches this will be true + assert instance.dependency is not None + assert instance.dependency.dependency != "from provider" \ No newline at end of file