Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions packages/modern-di-fastapi/modern_di_fastapi/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import contextlib
import dataclasses
import typing

import fastapi
from fastapi.routing import _merge_lifespan_context
from modern_di import Container, Scope, providers
from starlette.requests import HTTPConnection

Expand All @@ -17,13 +19,27 @@ def fetch_di_container(app_: fastapi.FastAPI) -> Container:
return typing.cast(Container, app_.state.di_container)


@contextlib.asynccontextmanager
async def _lifespan_manager(app_: fastapi.FastAPI) -> typing.AsyncIterator[None]:
container = fetch_di_container(app_)
try:
yield
finally:
await container.close_async()


def setup_di(app: fastapi.FastAPI, container: Container) -> Container:
app.state.di_container = container
container.providers_registry.add_providers(fastapi_request=fastapi_request, fastapi_websocket=fastapi_websocket)
old_lifespan_manager = app.router.lifespan_context
app.router.lifespan_context = _merge_lifespan_context(
old_lifespan_manager,
_lifespan_manager,
)
return container


async def build_di_container(connection: HTTPConnection) -> Container:
async def build_di_container(connection: HTTPConnection) -> typing.AsyncIterator[Container]:
context: dict[type[typing.Any], typing.Any] = {}
scope: Scope | None = None
if isinstance(connection, fastapi.Request):
Expand All @@ -32,7 +48,11 @@ async def build_di_container(connection: HTTPConnection) -> Container:
elif isinstance(connection, fastapi.WebSocket):
context[fastapi.WebSocket] = connection
scope = Scope.SESSION
return fetch_di_container(connection.app).build_child_container(context=context, scope=scope)
container = fetch_di_container(connection.app).build_child_container(context=context, scope=scope)
try:
yield container
finally:
await container.close_async()


@dataclasses.dataclass(slots=True, frozen=True)
Expand Down
13 changes: 8 additions & 5 deletions packages/modern-di-faststream/modern_di_faststream/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ async def consume_scope(
request_container = self.di_container.build_child_container(
scope=modern_di.Scope.REQUEST, context={faststream.StreamMessage: msg}
)
with self.faststream_context.scope("request_container", request_container):
return typing.cast(
typing.AsyncIterator[DecodedMessage],
await call_next(msg),
)
try:
with self.faststream_context.scope("request_container", request_container):
return typing.cast(
typing.AsyncIterator[DecodedMessage],
await call_next(msg),
)
finally:
await request_container.close_async()

if _OLD_MIDDLEWARES: # pragma: no cover

Expand Down
23 changes: 19 additions & 4 deletions packages/modern-di-litestar/modern_di_litestar/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import dataclasses
import typing

Expand All @@ -22,6 +23,15 @@ def fetch_di_container(app_: litestar.Litestar) -> Container:
return typing.cast(Container, app_.state.di_container)


@contextlib.asynccontextmanager
async def _lifespan_manager(app_: litestar.Litestar) -> typing.AsyncIterator[None]:
container = fetch_di_container(app_)
try:
yield
finally:
await container.close_async()


class ModernDIPlugin(InitPlugin):
__slots__ = ("container",)

Expand All @@ -33,13 +43,14 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig:
litestar_request=litestar_request, litestar_websocket=litestar_websocket
)
app_config.state.di_container = self.container
app_config.dependencies["di_container"] = Provide(build_di_container, sync_to_thread=False)
app_config.dependencies["di_container"] = Provide(build_di_container)
app_config.lifespan.append(_lifespan_manager)
return app_config


def build_di_container(
async def build_di_container(
request: litestar.Request[typing.Any, typing.Any, typing.Any],
) -> Container:
) -> typing.AsyncIterator[Container]:
context: dict[type[typing.Any], typing.Any] = {}
scope: DIScope | None
if isinstance(request, litestar.WebSocket):
Expand All @@ -48,7 +59,11 @@ def build_di_container(
else:
context[litestar.Request] = request
scope = DIScope.REQUEST
return fetch_di_container(request.app).build_child_container(context=context, scope=scope)
container = fetch_di_container(request.app).build_child_container(context=context, scope=scope)
try:
yield container
finally:
await container.close_async()


@dataclasses.dataclass(slots=True, frozen=True)
Expand Down
6 changes: 6 additions & 0 deletions packages/modern-di/modern_di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ def resolve(self, dependency_type: type[T_co] | None = None, *, dependency_name:

return typing.cast(T_co, provider.resolve(self))

async def close_async(self) -> None:
await self.cache_registry.close_async()

def close_sync(self) -> None:
self.cache_registry.close_sync()

def resolve_provider(self, provider: "AbstractProvider[T_co]") -> T_co:
return typing.cast(T_co, provider.resolve(self.find_container(provider.scope)))

Expand Down
9 changes: 7 additions & 2 deletions packages/modern-di/modern_di/providers/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import inspect
import typing

from modern_di import types
Expand All @@ -11,10 +12,14 @@
from modern_di import Container


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
@dataclasses.dataclass(kw_only=True, slots=True)
class CacheSettings(typing.Generic[types.T_co]):
clear_cache: bool = True
finalizer: typing.Callable[[types.T_co], None | typing.Coroutine[None, None, None]] | None = None
finalizer: typing.Callable[[types.T_co], None | typing.Awaitable[None]] | None = None
is_async_finalizer: bool = dataclasses.field(init=False)

def __post_init__(self) -> None:
self.is_async_finalizer = bool(self.finalizer) and inspect.iscoroutinefunction(self.finalizer)


class Factory(AbstractProvider[types.T_co]):
Expand Down
37 changes: 37 additions & 0 deletions packages/modern-di/modern_di/registries/cache_registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import typing
import warnings

from modern_di import types
from modern_di.providers import CacheSettings, Factory
Expand All @@ -11,10 +12,46 @@ class CacheItem:
cache: typing.Any | None = None
kwargs: dict[str, typing.Any] | None = None

def _clear(self) -> None:
if self.settings and self.settings.clear_cache:
self.cache = None

self.kwargs = None

async def close_async(self) -> None:
if self.cache and self.settings and self.settings.finalizer:
if self.settings.is_async_finalizer:
await self.settings.finalizer(self.cache) # type: ignore[misc]
else:
self.settings.finalizer(self.cache)

self._clear()

def close_sync(self) -> None:
if self.cache and self.settings and self.settings.finalizer:
if self.settings.is_async_finalizer:
warnings.warn(
f"Calling `close_sync` for async finalizer, type={type(self.cache)}",
RuntimeWarning,
stacklevel=2,
)
return
self.settings.finalizer(self.cache)

self._clear()


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class CacheRegistry:
_items: dict[str, CacheItem] = dataclasses.field(init=False, default_factory=dict)

def fetch_cache_item(self, provider: Factory[types.T_co]) -> CacheItem:
return self._items.setdefault(provider.provider_id, CacheItem(settings=provider.cache_settings))

async def close_async(self) -> None:
for cache_item in self._items.values():
await cache_item.close_async()

def close_sync(self) -> None:
for cache_item in self._items.values():
cache_item.close_sync()
29 changes: 26 additions & 3 deletions packages/modern-di/tests_core/providers/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,22 @@ class DependentCreator:
dep1: SimpleCreator


def sync_finalizer(_: SimpleCreator) -> None:
pass


async def async_finalizer(_: DependentCreator) -> None:
pass


class MyGroup(Group):
app_singleton = providers.Factory(
creator=SimpleCreator, kwargs={"dep1": "original"}, cache_settings=providers.CacheSettings()
creator=SimpleCreator,
kwargs={"dep1": "original"},
cache_settings=providers.CacheSettings(clear_cache=False, finalizer=sync_finalizer),
)
request_singleton = providers.Factory(
scope=Scope.REQUEST, creator=DependentCreator, cache_settings=providers.CacheSettings()
scope=Scope.REQUEST, creator=DependentCreator, cache_settings=providers.CacheSettings(finalizer=async_finalizer)
)


Expand All @@ -31,9 +41,12 @@ def test_app_singleton() -> None:
singleton1 = app_container.resolve_provider(MyGroup.app_singleton)
singleton2 = app_container.resolve_provider(MyGroup.app_singleton)
assert singleton1 is singleton2
app_container.close_sync()
cache_item = app_container.cache_registry.fetch_cache_item(MyGroup.app_singleton)
assert cache_item.cache


def test_request_singleton() -> None:
async def test_request_singleton() -> None:
app_container = Container(groups=[MyGroup])
request_container = app_container.build_child_container(scope=Scope.REQUEST)
instance1 = request_container.resolve_provider(MyGroup.request_singleton)
Expand All @@ -46,6 +59,16 @@ def test_request_singleton() -> None:
assert instance3 is instance4
assert instance1 is not instance3

cache_item = request_container.cache_registry.fetch_cache_item(MyGroup.request_singleton)

with pytest.warns(RuntimeWarning, match="Calling `close_sync` for async finalizer"):
request_container.close_sync()

assert cache_item.cache
await request_container.close_async()

assert cache_item.cache is None


def test_app_singleton_in_request_scope() -> None:
app_container = Container(groups=[MyGroup])
Expand Down