diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 49a8982..39eb8cf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,6 +33,9 @@ jobs: - "3.12" - "3.13" - "3.14" + faststream-version: + - "<0.6.0" + - ">=0.6.0" steps: - uses: actions/checkout@v5 - uses: extractions/setup-just@v3 @@ -41,7 +44,7 @@ jobs: cache-dependency-glob: "**/pyproject.toml" - run: uv python install ${{ matrix.python-version }} - run: just install - - run: just test . --cov=. --cov-report xml + - run: uv run --with "faststream${{ matrix.faststream-version }}" pytest . --cov=. --cov-report xml - uses: codecov/codecov-action@v5.4.3 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/pyproject.toml b/pyproject.toml index c68b78a..398e1a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ fastapi = [ "fastapi", ] faststream = [ - "faststream<0.6.0" + "faststream" ] [project.urls] diff --git a/tests/integrations/faststream/test_faststream_di_pass_message.py b/tests/integrations/faststream/test_faststream_di_pass_message.py index cc84c69..0ca62d1 100644 --- a/tests/integrations/faststream/test_faststream_di_pass_message.py +++ b/tests/integrations/faststream/test_faststream_di_pass_message.py @@ -1,11 +1,18 @@ import typing from faststream import BaseMiddleware, Context, Depends -from faststream.broker.message import StreamMessage from faststream.nats import NatsBroker, TestNatsBroker from faststream.nats.message import NatsMessage +from packaging.version import Version from that_depends import BaseContainer, container_context, fetch_context_item, providers +from that_depends.integrations.faststream import _FASTSTREAM_VERSION + + +if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover + from faststream.message import StreamMessage +else: # pragma: no cover + from faststream.broker.message import StreamMessage # type: ignore[import-not-found, no-redef] class ContextMiddleware(BaseMiddleware): @@ -18,7 +25,11 @@ async def consume_scope( return await call_next(msg) -broker = NatsBroker(middlewares=(ContextMiddleware,), validate=False) +if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover + broker = NatsBroker(middlewares=(ContextMiddleware,)) + +else: # pragma: no cover + broker = NatsBroker(middlewares=(ContextMiddleware,), validate=False) # type: ignore[call-arg] TEST_SUBJECT = "test" diff --git a/that_depends/integrations/faststream.py b/that_depends/integrations/faststream.py index e59ec3d..67d9fc6 100644 --- a/that_depends/integrations/faststream.py +++ b/that_depends/integrations/faststream.py @@ -1,58 +1,136 @@ import typing +from importlib.metadata import version from types import TracebackType -from typing import Any, Optional +from typing import Any, Final, Optional -from faststream import BaseMiddleware -from typing_extensions import override +from packaging.version import Version +from typing_extensions import deprecated, override from that_depends import container_context from that_depends.providers.context_resources import ContextScope, SupportsContext from that_depends.utils import UNSET, Unset, is_set -class DIContextMiddleware(BaseMiddleware): - """Initializes the container context for faststream brokers.""" - - def __init__( - self, - *context_items: SupportsContext[Any], - global_context: dict[str, Any] | Unset = UNSET, - scope: ContextScope | Unset = UNSET, - ) -> None: - """Initialize the container context middleware. - - Args: - *context_items (SupportsContext[Any]): Context items to initialize. - global_context (dict[str, Any] | Unset): Global context to initialize the container. - scope (ContextScope | Unset): Context scope to initialize the container. - - """ - super().__init__() - self._context: container_context | None = None - self._context_items = set(context_items) - self._global_context = global_context - self._scope = scope - - @override - async def on_receive(self) -> None: - self._context = container_context( - *self._context_items, - scope=self._scope if is_set(self._scope) else None, - global_context=self._global_context if is_set(self._global_context) else None, - ) - await self._context.__aenter__() - - @override - async def after_processed( - self, - exc_type: type[BaseException] | None = None, - exc_val: BaseException | None = None, - exc_tb: Optional["TracebackType"] = None, - ) -> bool | None: - if self._context is not None: - await self._context.__aexit__(exc_type, exc_val, exc_tb) - return None - - def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> "DIContextMiddleware": # noqa: ARG002, ANN401 - """Create an instance of DIContextMiddleware.""" - return DIContextMiddleware(*self._context_items, scope=self._scope, global_context=self._global_context) +_FASTSTREAM_MODULE_NAME: Final[str] = "faststream" +_FASTSTREAM_VERSION: Final[str] = version(_FASTSTREAM_MODULE_NAME) +if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover + from faststream import BaseMiddleware, ContextRepo + from faststream._internal.types import AnyMsg + + class DIContextMiddleware(BaseMiddleware): + """Initializes the container context for faststream brokers.""" + + def __init__( + self, + *context_items: SupportsContext[Any], + msg: AnyMsg | None = None, + context: Optional["ContextRepo"] = None, + global_context: dict[str, Any] | Unset = UNSET, + scope: ContextScope | Unset = UNSET, + ) -> None: + """Initialize the container context middleware. + + Args: + *context_items (SupportsContext[Any]): Context items to initialize. + msg (Any): Message object. + context (ContextRepo): Context repository. + global_context (dict[str, Any] | Unset): Global context to initialize the container. + scope (ContextScope | Unset): Context scope to initialize the container. + + """ + super().__init__(msg, context=context) # type: ignore[arg-type] + self._context: container_context | None = None + self._context_items = set(context_items) + self._global_context = global_context + self._scope = scope + + @override + async def on_receive(self) -> None: + self._context = container_context( + *self._context_items, + scope=self._scope if is_set(self._scope) else None, + global_context=self._global_context if is_set(self._global_context) else None, + ) + await self._context.__aenter__() + + @override + async def after_processed( + self, + exc_type: type[BaseException] | None = None, + exc_val: BaseException | None = None, + exc_tb: Optional["TracebackType"] = None, + ) -> bool | None: + if self._context is not None: + await self._context.__aexit__(exc_type, exc_val, exc_tb) + return None + + def __call__(self, msg: Any = None, **kwargs: Any) -> "DIContextMiddleware": # noqa: ANN401 + """Create an instance of DIContextMiddleware. + + Args: + msg (Any): Message object. + **kwargs: Additional keyword arguments. + + Returns: + DIContextMiddleware: A new instance of DIContextMiddleware. + + """ + context = kwargs.get("context") + + return DIContextMiddleware( + *self._context_items, + msg=msg, + context=context, + scope=self._scope, + global_context=self._global_context, + ) +else: # pragma: no cover + from faststream import BaseMiddleware + + @deprecated("Will be removed with faststream v1") + class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef] + """Initializes the container context for faststream brokers.""" + + def __init__( + self, + *context_items: SupportsContext[Any], + global_context: dict[str, Any] | Unset = UNSET, + scope: ContextScope | Unset = UNSET, + ) -> None: + """Initialize the container context middleware. + + Args: + *context_items (SupportsContext[Any]): Context items to initialize. + global_context (dict[str, Any] | Unset): Global context to initialize the container. + scope (ContextScope | Unset): Context scope to initialize the container. + + """ + super().__init__() # type: ignore[call-arg] + self._context: container_context | None = None + self._context_items = set(context_items) + self._global_context = global_context + self._scope = scope + + @override + async def on_receive(self) -> None: + self._context = container_context( + *self._context_items, + scope=self._scope if is_set(self._scope) else None, + global_context=self._global_context if is_set(self._global_context) else None, + ) + await self._context.__aenter__() + + @override + async def after_processed( + self, + exc_type: type[BaseException] | None = None, + exc_val: BaseException | None = None, + exc_tb: Optional["TracebackType"] = None, + ) -> bool | None: + if self._context is not None: + await self._context.__aexit__(exc_type, exc_val, exc_tb) + return None + + def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> "DIContextMiddleware": # noqa: ARG002, ANN401 + """Create an instance of DIContextMiddleware.""" + return DIContextMiddleware(*self._context_items, scope=self._scope, global_context=self._global_context)