From d17ba52595b17ab8da2da80cb902778ffa30968f Mon Sep 17 00:00:00 2001 From: Alexander Date: Sat, 18 Oct 2025 06:15:37 +0800 Subject: [PATCH] feat: LazyProvider --- docs/experimental/lazy.md | 52 ++++++ mkdocs.yml | 2 + pyproject.toml | 2 +- tests/experimental/__init__.py | 0 tests/experimental/test_container_1.py | 15 ++ tests/experimental/test_container_2.py | 145 ++++++++++++++++ tests/experimental/test_lazy_provider.py | 30 ++++ that_depends/experimental/__init__.py | 8 + that_depends/experimental/providers.py | 179 ++++++++++++++++++++ that_depends/meta.py | 2 +- that_depends/providers/context_resources.py | 3 +- 11 files changed, 434 insertions(+), 4 deletions(-) create mode 100644 docs/experimental/lazy.md create mode 100644 tests/experimental/__init__.py create mode 100644 tests/experimental/test_container_1.py create mode 100644 tests/experimental/test_container_2.py create mode 100644 tests/experimental/test_lazy_provider.py create mode 100644 that_depends/experimental/__init__.py create mode 100644 that_depends/experimental/providers.py diff --git a/docs/experimental/lazy.md b/docs/experimental/lazy.md new file mode 100644 index 00000000..1cb24520 --- /dev/null +++ b/docs/experimental/lazy.md @@ -0,0 +1,52 @@ +# Lazy Provider + +The `LazyProvider` enables you to reference other providers without explicitly +importing them into your module. + +This can be helpful if you have a circular dependency between providers in +multiple containers. + + +## Creating a Lazy Provider + +=== "Single import string" + ```python + from that_depends.experimental import LazyProvider + + lazy_p = LazyProvider("full.import.string.including.attributes") + ``` +=== "Separate module and provider" + ```python + from that_depends.experimental import LazyProvider + + lazy_p = LazyProvider(module_string="my.module", provider_string="attribute.path") + ``` + + +## Usage + +You can use the lazy provider in exactly the same way as you would use the referenced provider. + +```python +# first_container.py +from that_depends import BaseContainer, providers, ContextScopes + +def my_creator(): + yield 42 + +class FirstContainer(BaseContainer): + value_provider = providers.ContextResource(my_creator).with_config(scope=ContextScopes.APP) +``` + +You can lazily import this provider: +```python +# second_container.py +from that_depends.experimental import LazyProvider +from that_depends import BaseContainer, providers +class SecondContainer(BaseContainer): + lazy_value = LazyProvider("first_container.FirstContainer.value_provider") + + +with SecondContainer.lazy_value.context_sync(force=True): + SecondContainer.lazy_value.resolve_sync() # 42 +``` diff --git a/mkdocs.yml b/mkdocs.yml index 7ce31b69..0226005e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -23,6 +23,8 @@ nav: - Selector: providers/selector.md - Singletons: providers/singleton.md - State: providers/state.md + - Experimental Features: + - Lazy Provider: experimental/lazy.md - Integrations: - FastAPI: integrations/fastapi.md diff --git a/pyproject.toml b/pyproject.toml index 398e1a1c..c68b78a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ fastapi = [ "fastapi", ] faststream = [ - "faststream" + "faststream<0.6.0" ] [project.urls] diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/experimental/test_container_1.py b/tests/experimental/test_container_1.py new file mode 100644 index 00000000..8ecc0de0 --- /dev/null +++ b/tests/experimental/test_container_1.py @@ -0,0 +1,15 @@ +from tests.experimental.test_container_2 import Container2 +from that_depends import BaseContainer, providers +from that_depends.experimental import LazyProvider + + +class Container1(BaseContainer): + """Test Container 1.""" + + alias = "container_1" + obj_1 = providers.Object(1) + obj_2 = LazyProvider(module_string="tests.experimental.test_container_2", provider_string="Container2.obj_2") + + +def test_lazy_provider_resolution_sync() -> None: + assert Container2.obj_2.resolve_sync() == 2 # noqa: PLR2004 diff --git a/tests/experimental/test_container_2.py b/tests/experimental/test_container_2.py new file mode 100644 index 00000000..3d50c221 --- /dev/null +++ b/tests/experimental/test_container_2.py @@ -0,0 +1,145 @@ +import random +from collections.abc import AsyncIterator, Iterator + +import pytest +import typing_extensions + +from that_depends import BaseContainer, ContextScopes, container_context, providers +from that_depends.experimental import LazyProvider + + +class _RandomWrapper: + def __init__(self) -> None: + self.value = random.random() + + @typing_extensions.override + def __eq__(self, other: object) -> bool: + if isinstance(other, _RandomWrapper): + return self.value == other.value + return False # pragma: nocover + + def __hash__(self) -> int: + return 0 # pragma: nocover + + +async def _async_creator() -> AsyncIterator[float]: + yield random.random() + + +def _sync_creator() -> Iterator[_RandomWrapper]: + yield _RandomWrapper() + + +class Container2(BaseContainer): + """Test Container 2.""" + + alias = "container_2" + default_scope = ContextScopes.APP + obj_1 = LazyProvider("tests.experimental.test_container_1.Container1.obj_1") + obj_2 = providers.Object(2) + async_context_provider = providers.ContextResource(_async_creator) + sync_context_provider = providers.ContextResource(_sync_creator) + singleton_provider = providers.Singleton(lambda: random.random()) + + +async def test_lazy_provider_resolution_async() -> None: + assert await Container2.obj_1.resolve() == 1 + + +def test_lazy_provider_override_sync() -> None: + override_value = 42 + Container2.obj_1.override_sync(override_value) + assert Container2.obj_1.resolve_sync() == override_value + Container2.obj_1.reset_override_sync() + assert Container2.obj_1.resolve_sync() == 1 + + +async def test_lazy_provider_override_async() -> None: + override_value = 42 + await Container2.obj_1.override(override_value) + assert await Container2.obj_1.resolve() == override_value + await Container2.obj_1.reset_override() + assert await Container2.obj_1.resolve() == 1 + + +def test_lazy_provider_invalid_state() -> None: + lazy_provider = LazyProvider( + module_string="tests.experimental.test_container_2", provider_string="Container2.sync_context_provider" + ) + lazy_provider._module_string = None + with pytest.raises(RuntimeError): + lazy_provider.resolve_sync() + + +async def test_lazy_provider_context_resource_async() -> None: + lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.async_context_provider") + async with lazy_provider.context_async(force=True): + assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve() + async with Container2.async_context_provider.context_async(force=True): + assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve() + + with pytest.raises(RuntimeError): + await lazy_provider.resolve() + + async with container_context(Container2, scope=ContextScopes.APP): + assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve() + + assert lazy_provider.get_scope() == ContextScopes.APP + + assert lazy_provider.supports_context_sync() is False + + +def test_lazy_provider_context_resource_sync() -> None: + lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider") + with lazy_provider.context_sync(force=True): + assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync() + with Container2.sync_context_provider.context_sync(force=True): + assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync() + + with pytest.raises(RuntimeError): + lazy_provider.resolve_sync() + + with container_context(Container2, scope=ContextScopes.APP): + assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync() + + assert lazy_provider.get_scope() == ContextScopes.APP + + assert lazy_provider.supports_context_sync() is True + + +async def test_lazy_provider_tear_down_async() -> None: + lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.singleton_provider") + assert lazy_provider.resolve_sync() == Container2.singleton_provider.resolve_sync() + + await lazy_provider.tear_down() + + assert await lazy_provider.resolve() == Container2.singleton_provider.resolve_sync() + + +def test_lazy_provider_tear_down_sync() -> None: + lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.singleton_provider") + assert lazy_provider.resolve_sync() == Container2.singleton_provider.resolve_sync() + + lazy_provider.tear_down_sync() + + assert lazy_provider.resolve_sync() == Container2.singleton_provider.resolve_sync() + + +async def test_lazy_provider_not_implemented() -> None: + lazy_provider = Container2.obj_1 + with pytest.raises(NotImplementedError): + lazy_provider.get_scope() + with pytest.raises(NotImplementedError): + lazy_provider.context_sync() + with pytest.raises(NotImplementedError): + lazy_provider.context_async() + with pytest.raises(NotImplementedError): + lazy_provider.tear_down_sync() + with pytest.raises(NotImplementedError): + await lazy_provider.tear_down() + + +def test_lazy_provider_attr_getter() -> None: + lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider") + with lazy_provider.context_sync(force=True): + assert isinstance(lazy_provider.value.resolve_sync(), float) diff --git a/tests/experimental/test_lazy_provider.py b/tests/experimental/test_lazy_provider.py new file mode 100644 index 00000000..083f97fe --- /dev/null +++ b/tests/experimental/test_lazy_provider.py @@ -0,0 +1,30 @@ +import pytest + +from that_depends.experimental import LazyProvider + + +def test_lazy_provider_incorrect_initialization() -> None: + with pytest.raises( + ValueError, + match=r"You must provide either import_string " + "OR both module_string AND provider_string, but not both or neither.", + ): + LazyProvider(module_string="3213") # type: ignore[call-overload] + + with pytest.raises(ValueError, match=r"Invalid import_string ''"): + LazyProvider("") + + with pytest.raises(ValueError, match=r"Invalid provider_string ''"): + LazyProvider(module_string="some.module", provider_string="") + + with pytest.raises(ValueError, match=r"Invalid module_string '.'"): + LazyProvider(module_string=".", provider_string="SomeProvider") + + with pytest.raises(ValueError, match=r"Invalid import_string 'import.'"): + LazyProvider("import.") + + +def test_lazy_provider_incorrect_import_string() -> None: + p = LazyProvider("some.random.path") + with pytest.raises(ImportError): + p.resolve_sync() diff --git a/that_depends/experimental/__init__.py b/that_depends/experimental/__init__.py new file mode 100644 index 00000000..10d7d24d --- /dev/null +++ b/that_depends/experimental/__init__.py @@ -0,0 +1,8 @@ +"""Experimental features.""" + +from that_depends.experimental.providers import LazyProvider + + +__all__ = [ + "LazyProvider", +] diff --git a/that_depends/experimental/providers.py b/that_depends/experimental/providers.py new file mode 100644 index 00000000..c79becf6 --- /dev/null +++ b/that_depends/experimental/providers.py @@ -0,0 +1,179 @@ +import importlib +import re +import typing +from typing import Any, TypeVar, cast, overload + +import typing_extensions + +from that_depends import ContextScope +from that_depends.providers import AbstractProvider +from that_depends.providers.context_resources import CT, SupportsContext +from that_depends.providers.mixin import SupportsTeardown + + +T_co = TypeVar("T_co", covariant=True) + +_IMPORT_STRING_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$") + + +class LazyProvider(SupportsTeardown, SupportsContext[Any], AbstractProvider[Any]): + """Lazily imports and provides a provider from a module.""" + + @overload + def __init__(self, *, module_string: str, provider_string: str) -> None: ... + + @overload + def __init__(self, import_string: str, /) -> None: ... + + def __init__( + self, + import_string: str | None = None, + module_string: str | None = None, + provider_string: str | None = None, + ) -> None: + """Initialize a LazyProvider instance. + + Args: + module_string: path to module to import from. + provider_string: path to provider within module. + import_string: path to provider including module and attributes. + + """ + super().__init__() + if (import_string is not None) == (module_string is not None and provider_string is not None): + msg = ( + "You must provide either import_string OR both module_string AND provider_string, " + "but not both or neither." + ) + raise ValueError(msg) + + self._module_string = module_string + self._provider_string = provider_string + self._import_string = import_string + self._check_strings() + self._provider: AbstractProvider[Any] | None = None + + @typing_extensions.override + def get_scope(self) -> ContextScope | None: + provider = self._get_provider() + if isinstance(provider, SupportsContext): + return provider.get_scope() + msg = "Underlying provider does not support context scopes" + raise NotImplementedError(msg) + + @typing_extensions.override + def context_async(self, force: bool = False) -> typing.AsyncContextManager[CT]: + provider = self._get_provider() + if isinstance(provider, SupportsContext): + return provider.context_async(force) + msg = "Underlying provider does not support context management" + raise NotImplementedError(msg) + + @typing_extensions.override + def context_sync(self, force: bool = False) -> typing.ContextManager[CT]: + provider = self._get_provider() + if isinstance(provider, SupportsContext): + return provider.context_sync(force) + msg = "Underlying provider does not support context management" + raise NotImplementedError(msg) + + @typing_extensions.override + def supports_context_sync(self) -> bool: + provider = self._get_provider() + return isinstance(provider, SupportsContext) and provider.supports_context_sync() + + @typing_extensions.override + async def tear_down(self, propagate: bool = True) -> None: + provider = self._get_provider() + if isinstance(provider, SupportsTeardown): + await provider.tear_down(propagate=propagate) + else: + msg = "Underlying provider does not support tear down." + raise NotImplementedError(msg) + + @typing_extensions.override + def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: + provider = self._get_provider() + if isinstance(provider, SupportsTeardown): + provider.tear_down_sync(propagate=propagate, raise_on_async=raise_on_async) + else: + msg = "Underlying provider does not support tear down." + raise NotImplementedError(msg) + + def _check_strings(self) -> None: + if self._import_string is not None and not _IMPORT_STRING_REGEX.match(self._import_string): + msg = f"Invalid import_string '{self._import_string}'" + raise ValueError(msg) + if self._module_string is not None and not _IMPORT_STRING_REGEX.match(self._module_string): + msg = f"Invalid module_string '{self._module_string}'" + raise ValueError(msg) + if self._provider_string is not None and not _IMPORT_STRING_REGEX.match(self._provider_string): + msg = f"Invalid provider_string '{self._provider_string}'" + raise ValueError(msg) + + def _get_provider(self) -> AbstractProvider[Any]: + if self._provider: + return self._provider + if self._import_string is not None: + parts = self._import_string.split(".") + for i in range(len(parts), 0, -1): + module_name = ".".join(parts[:i]) + try: + module = importlib.import_module(module_name) + attrs = parts[i:] + break + except ImportError: + continue + else: + msg = f"Cannot import any module from '{self._import_string}'" + raise ImportError(msg) + else: + if self._module_string is None or self._provider_string is None: + msg = "Invalid state: module_string and provider_string must be set" + raise RuntimeError(msg) + module = importlib.import_module(self._module_string) + attrs = self._provider_string.split(".") + provider = module + for attr in attrs: + provider = getattr(provider, attr) + self._provider = cast(AbstractProvider[Any], provider) + return self._provider + + @typing_extensions.override + async def resolve(self) -> Any: + provider = self._get_provider() + return await provider.resolve() + + @typing_extensions.override + def resolve_sync(self) -> Any: + provider = self._get_provider() + return provider.resolve_sync() + + @typing_extensions.override + def override_sync( + self, mock: object, tear_down_children: bool = False, propagate: bool = True, raise_on_async: bool = False + ) -> None: + provider = self._get_provider() + provider.override_sync(mock, tear_down_children, propagate, raise_on_async) + + @typing_extensions.override + async def override(self, mock: object, tear_down_children: bool = False, propagate: bool = True) -> None: + provider = self._get_provider() + await provider.override(mock, tear_down_children, propagate) + + @typing_extensions.override + async def reset_override(self, tear_down_children: bool = False, propagate: bool = True) -> None: + provider = self._get_provider() + await provider.reset_override(tear_down_children, propagate) + + @typing_extensions.override + def reset_override_sync( + self, tear_down_children: bool = False, propagate: bool = True, raise_on_async: bool = False + ) -> None: + provider = self._get_provider() + provider.reset_override_sync(tear_down_children, propagate, raise_on_async) + + @typing_extensions.override + def __getattr__(self, attr_name: str) -> typing.Any: + provider = self._get_provider() + return getattr(provider, attr_name) diff --git a/that_depends/meta.py b/that_depends/meta.py index 02944c32..21507a49 100644 --- a/that_depends/meta.py +++ b/that_depends/meta.py @@ -43,7 +43,7 @@ def __setitem__(self, key: str, value: typing.Any) -> None: super().__setitem__(key, value) -class BaseContainerMeta(SupportsContext[None], abc.ABCMeta): +class BaseContainerMeta(abc.ABCMeta, SupportsContext[None]): """Metaclass for BaseContainer.""" @override diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 52820ddc..217dc570 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -1,4 +1,3 @@ -import abc import asyncio import contextlib import inspect @@ -106,7 +105,7 @@ def _enter_named_scope(scope: ContextScope) -> typing.Iterator[ContextScope]: CT = typing.TypeVar("CT") -class SupportsContext(abc.ABC, typing.Generic[CT]): +class SupportsContext(typing.Generic[CT]): """Interface for resources that support context initialization. This interface defines methods to create synchronous and asynchronous