From 2292c3c1e73352c3991596a6e0798b2e4f2d1ac0 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Mon, 12 Jan 2026 16:39:34 +0300 Subject: [PATCH] implement factories finalization --- .../modern_di_fastapi/main.py | 24 +++++++++++- .../modern_di_faststream/main.py | 13 ++++--- .../modern_di_litestar/main.py | 23 ++++++++++-- packages/modern-di/modern_di/container.py | 6 +++ .../modern-di/modern_di/providers/factory.py | 9 ++++- .../modern_di/registries/cache_registry.py | 37 +++++++++++++++++++ .../tests_core/providers/test_singleton.py | 29 +++++++++++++-- 7 files changed, 125 insertions(+), 16 deletions(-) diff --git a/packages/modern-di-fastapi/modern_di_fastapi/main.py b/packages/modern-di-fastapi/modern_di_fastapi/main.py index c08abc5..eee7edb 100644 --- a/packages/modern-di-fastapi/modern_di_fastapi/main.py +++ b/packages/modern-di-fastapi/modern_di_fastapi/main.py @@ -1,7 +1,9 @@ +import contextlib import dataclasses import typing import fastapi +from fastapi.routing import _merge_lifespan_context from modern_di import Container, Scope, providers from starlette.requests import HTTPConnection @@ -17,13 +19,27 @@ def fetch_di_container(app_: fastapi.FastAPI) -> Container: return typing.cast(Container, app_.state.di_container) +@contextlib.asynccontextmanager +async def _lifespan_manager(app_: fastapi.FastAPI) -> typing.AsyncIterator[None]: + container = fetch_di_container(app_) + try: + yield + finally: + await container.close_async() + + def setup_di(app: fastapi.FastAPI, container: Container) -> Container: app.state.di_container = container container.providers_registry.add_providers(fastapi_request=fastapi_request, fastapi_websocket=fastapi_websocket) + old_lifespan_manager = app.router.lifespan_context + app.router.lifespan_context = _merge_lifespan_context( + old_lifespan_manager, + _lifespan_manager, + ) return container -async def build_di_container(connection: HTTPConnection) -> Container: +async def build_di_container(connection: HTTPConnection) -> typing.AsyncIterator[Container]: context: dict[type[typing.Any], typing.Any] = {} scope: Scope | None = None if isinstance(connection, fastapi.Request): @@ -32,7 +48,11 @@ async def build_di_container(connection: HTTPConnection) -> Container: elif isinstance(connection, fastapi.WebSocket): context[fastapi.WebSocket] = connection scope = Scope.SESSION - return fetch_di_container(connection.app).build_child_container(context=context, scope=scope) + container = fetch_di_container(connection.app).build_child_container(context=context, scope=scope) + try: + yield container + finally: + await container.close_async() @dataclasses.dataclass(slots=True, frozen=True) diff --git a/packages/modern-di-faststream/modern_di_faststream/main.py b/packages/modern-di-faststream/modern_di_faststream/main.py index 54139f6..0dee223 100644 --- a/packages/modern-di-faststream/modern_di_faststream/main.py +++ b/packages/modern-di-faststream/modern_di_faststream/main.py @@ -44,11 +44,14 @@ async def consume_scope( request_container = self.di_container.build_child_container( scope=modern_di.Scope.REQUEST, context={faststream.StreamMessage: msg} ) - with self.faststream_context.scope("request_container", request_container): - return typing.cast( - typing.AsyncIterator[DecodedMessage], - await call_next(msg), - ) + try: + with self.faststream_context.scope("request_container", request_container): + return typing.cast( + typing.AsyncIterator[DecodedMessage], + await call_next(msg), + ) + finally: + await request_container.close_async() if _OLD_MIDDLEWARES: # pragma: no cover diff --git a/packages/modern-di-litestar/modern_di_litestar/main.py b/packages/modern-di-litestar/modern_di_litestar/main.py index 835a8fa..4305afc 100644 --- a/packages/modern-di-litestar/modern_di_litestar/main.py +++ b/packages/modern-di-litestar/modern_di_litestar/main.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import typing @@ -22,6 +23,15 @@ def fetch_di_container(app_: litestar.Litestar) -> Container: return typing.cast(Container, app_.state.di_container) +@contextlib.asynccontextmanager +async def _lifespan_manager(app_: litestar.Litestar) -> typing.AsyncIterator[None]: + container = fetch_di_container(app_) + try: + yield + finally: + await container.close_async() + + class ModernDIPlugin(InitPlugin): __slots__ = ("container",) @@ -33,13 +43,14 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: litestar_request=litestar_request, litestar_websocket=litestar_websocket ) app_config.state.di_container = self.container - app_config.dependencies["di_container"] = Provide(build_di_container, sync_to_thread=False) + app_config.dependencies["di_container"] = Provide(build_di_container) + app_config.lifespan.append(_lifespan_manager) return app_config -def build_di_container( +async def build_di_container( request: litestar.Request[typing.Any, typing.Any, typing.Any], -) -> Container: +) -> typing.AsyncIterator[Container]: context: dict[type[typing.Any], typing.Any] = {} scope: DIScope | None if isinstance(request, litestar.WebSocket): @@ -48,7 +59,11 @@ def build_di_container( else: context[litestar.Request] = request scope = DIScope.REQUEST - return fetch_di_container(request.app).build_child_container(context=context, scope=scope) + container = fetch_di_container(request.app).build_child_container(context=context, scope=scope) + try: + yield container + finally: + await container.close_async() @dataclasses.dataclass(slots=True, frozen=True) diff --git a/packages/modern-di/modern_di/container.py b/packages/modern-di/modern_di/container.py index 9ad5c45..9f67006 100644 --- a/packages/modern-di/modern_di/container.py +++ b/packages/modern-di/modern_di/container.py @@ -88,6 +88,12 @@ def resolve(self, dependency_type: type[T_co] | None = None, *, dependency_name: return typing.cast(T_co, provider.resolve(self)) + async def close_async(self) -> None: + await self.cache_registry.close_async() + + def close_sync(self) -> None: + self.cache_registry.close_sync() + def resolve_provider(self, provider: "AbstractProvider[T_co]") -> T_co: return typing.cast(T_co, provider.resolve(self.find_container(provider.scope))) diff --git a/packages/modern-di/modern_di/providers/factory.py b/packages/modern-di/modern_di/providers/factory.py index afabbb7..7e55f20 100644 --- a/packages/modern-di/modern_di/providers/factory.py +++ b/packages/modern-di/modern_di/providers/factory.py @@ -1,4 +1,5 @@ import dataclasses +import inspect import typing from modern_di import types @@ -11,10 +12,14 @@ from modern_di import Container -@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +@dataclasses.dataclass(kw_only=True, slots=True) class CacheSettings(typing.Generic[types.T_co]): clear_cache: bool = True - finalizer: typing.Callable[[types.T_co], None | typing.Coroutine[None, None, None]] | None = None + finalizer: typing.Callable[[types.T_co], None | typing.Awaitable[None]] | None = None + is_async_finalizer: bool = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.is_async_finalizer = bool(self.finalizer) and inspect.iscoroutinefunction(self.finalizer) class Factory(AbstractProvider[types.T_co]): diff --git a/packages/modern-di/modern_di/registries/cache_registry.py b/packages/modern-di/modern_di/registries/cache_registry.py index 4e0a9b1..71181e5 100644 --- a/packages/modern-di/modern_di/registries/cache_registry.py +++ b/packages/modern-di/modern_di/registries/cache_registry.py @@ -1,5 +1,6 @@ import dataclasses import typing +import warnings from modern_di import types from modern_di.providers import CacheSettings, Factory @@ -11,6 +12,34 @@ class CacheItem: cache: typing.Any | None = None kwargs: dict[str, typing.Any] | None = None + def _clear(self) -> None: + if self.settings and self.settings.clear_cache: + self.cache = None + + self.kwargs = None + + async def close_async(self) -> None: + if self.cache and self.settings and self.settings.finalizer: + if self.settings.is_async_finalizer: + await self.settings.finalizer(self.cache) # type: ignore[misc] + else: + self.settings.finalizer(self.cache) + + self._clear() + + def close_sync(self) -> None: + if self.cache and self.settings and self.settings.finalizer: + if self.settings.is_async_finalizer: + warnings.warn( + f"Calling `close_sync` for async finalizer, type={type(self.cache)}", + RuntimeWarning, + stacklevel=2, + ) + return + self.settings.finalizer(self.cache) + + self._clear() + @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class CacheRegistry: @@ -18,3 +47,11 @@ class CacheRegistry: def fetch_cache_item(self, provider: Factory[types.T_co]) -> CacheItem: return self._items.setdefault(provider.provider_id, CacheItem(settings=provider.cache_settings)) + + async def close_async(self) -> None: + for cache_item in self._items.values(): + await cache_item.close_async() + + def close_sync(self) -> None: + for cache_item in self._items.values(): + cache_item.close_sync() diff --git a/packages/modern-di/tests_core/providers/test_singleton.py b/packages/modern-di/tests_core/providers/test_singleton.py index 134dc7d..a48779b 100644 --- a/packages/modern-di/tests_core/providers/test_singleton.py +++ b/packages/modern-di/tests_core/providers/test_singleton.py @@ -17,12 +17,22 @@ class DependentCreator: dep1: SimpleCreator +def sync_finalizer(_: SimpleCreator) -> None: + pass + + +async def async_finalizer(_: DependentCreator) -> None: + pass + + class MyGroup(Group): app_singleton = providers.Factory( - creator=SimpleCreator, kwargs={"dep1": "original"}, cache_settings=providers.CacheSettings() + creator=SimpleCreator, + kwargs={"dep1": "original"}, + cache_settings=providers.CacheSettings(clear_cache=False, finalizer=sync_finalizer), ) request_singleton = providers.Factory( - scope=Scope.REQUEST, creator=DependentCreator, cache_settings=providers.CacheSettings() + scope=Scope.REQUEST, creator=DependentCreator, cache_settings=providers.CacheSettings(finalizer=async_finalizer) ) @@ -31,9 +41,12 @@ def test_app_singleton() -> None: singleton1 = app_container.resolve_provider(MyGroup.app_singleton) singleton2 = app_container.resolve_provider(MyGroup.app_singleton) assert singleton1 is singleton2 + app_container.close_sync() + cache_item = app_container.cache_registry.fetch_cache_item(MyGroup.app_singleton) + assert cache_item.cache -def test_request_singleton() -> None: +async def test_request_singleton() -> None: app_container = Container(groups=[MyGroup]) request_container = app_container.build_child_container(scope=Scope.REQUEST) instance1 = request_container.resolve_provider(MyGroup.request_singleton) @@ -46,6 +59,16 @@ def test_request_singleton() -> None: assert instance3 is instance4 assert instance1 is not instance3 + cache_item = request_container.cache_registry.fetch_cache_item(MyGroup.request_singleton) + + with pytest.warns(RuntimeWarning, match="Calling `close_sync` for async finalizer"): + request_container.close_sync() + + assert cache_item.cache + await request_container.close_async() + + assert cache_item.cache is None + def test_app_singleton_in_request_scope() -> None: app_container = Container(groups=[MyGroup])