diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 64ecaf22..abcd6615 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -25,7 +25,8 @@ import warnings from abc import ABC, abstractmethod from concurrent import futures -from inspect import iscoroutinefunction + +from . import _utils # Import all built-in retry strategies for easier usage. from .retry import retry_base # noqa @@ -88,6 +89,7 @@ if t.TYPE_CHECKING: import types + from . import asyncio as tasyncio from .retry import RetryBaseT from .stop import StopBaseT from .wait import WaitBaseT @@ -556,16 +558,16 @@ def retry(func: WrappedFn) -> WrappedFn: @t.overload def retry( - sleep: t.Callable[[t.Union[int, float]], t.Optional[t.Awaitable[None]]] = sleep, + sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep, stop: "StopBaseT" = stop_never, wait: "WaitBaseT" = wait_none(), - retry: "RetryBaseT" = retry_if_exception_type(), - before: t.Callable[["RetryCallState"], None] = before_nothing, - after: t.Callable[["RetryCallState"], None] = after_nothing, - before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None, + retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(), + before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, + after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, + before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, reraise: bool = False, retry_error_cls: t.Type["RetryError"] = RetryError, - retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None, + retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]] = None, ) -> t.Callable[[WrappedFn], WrappedFn]: ... @@ -588,7 +590,7 @@ def wrap(f: WrappedFn) -> WrappedFn: f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)" ) r: "BaseRetrying" - if iscoroutinefunction(f): + if _utils.is_coroutine_callable(f): r = AsyncRetrying(*dargs, **dkw) elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f): r = TornadoRetrying(*dargs, **dkw) @@ -600,7 +602,7 @@ def wrap(f: WrappedFn) -> WrappedFn: return wrap -from tenacity._asyncio import AsyncRetrying # noqa:E402,I100 +from tenacity.asyncio import AsyncRetrying # noqa:E402,I100 if tornado: from tenacity.tornadoweb import TornadoRetrying diff --git a/tenacity/_utils.py b/tenacity/_utils.py index f14ff320..f4c817dd 100644 --- a/tenacity/_utils.py +++ b/tenacity/_utils.py @@ -13,9 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect import sys import typing +import typing as t from datetime import timedelta @@ -74,3 +75,22 @@ def get_callback_name(cb: typing.Callable[..., typing.Any]) -> str: def to_seconds(time_unit: time_unit_type) -> float: return float(time_unit.total_seconds() if isinstance(time_unit, timedelta) else time_unit) + + +def is_coroutine_callable(call: t.Callable[..., t.Any]) -> bool: + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.iscoroutinefunction(dunder_call) + + +def wrap_to_async_func(call: t.Callable[..., t.Any]) -> t.Callable[..., t.Awaitable[t.Any]]: + if is_coroutine_callable(call): + return call + + async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: + return call(*args, **kwargs) + + return inner diff --git a/tenacity/_asyncio.py b/tenacity/asyncio/__init__.py similarity index 62% rename from tenacity/_asyncio.py rename to tenacity/asyncio/__init__.py index 9d418a8c..d9f388a4 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/asyncio/__init__.py @@ -15,37 +15,64 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import functools -import inspect import sys import typing as t -from asyncio import sleep +import tenacity from tenacity import AttemptManager from tenacity import BaseRetrying from tenacity import DoAttempt from tenacity import DoSleep from tenacity import RetryCallState +from tenacity import RetryError +from tenacity import _utils +from tenacity import after_nothing +from tenacity import before_nothing + +# Import all built-in retry strategies for easier usage. +from .retry import RetryBaseT +from .retry import retry_all # noqa +from .retry import retry_any # noqa +from .retry import retry_if_exception # noqa +from .retry import retry_if_result # noqa +from ..retry import RetryBaseT as SyncRetryBaseT + +if t.TYPE_CHECKING: + from tenacity.stop import StopBaseT + from tenacity.wait import WaitBaseT WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) -def is_coroutine_callable(call: t.Callable[..., t.Any]) -> bool: - if inspect.isroutine(call): - return inspect.iscoroutinefunction(call) - if inspect.isclass(call): - return False - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.iscoroutinefunction(dunder_call) - - class AsyncRetrying(BaseRetrying): - sleep: t.Callable[[float], t.Awaitable[t.Any]] - - def __init__(self, sleep: t.Callable[[float], t.Awaitable[t.Any]] = sleep, **kwargs: t.Any) -> None: - super().__init__(**kwargs) - self.sleep = sleep + def __init__( + self, + sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = asyncio.sleep, + stop: "StopBaseT" = tenacity.stop.stop_never, + wait: "WaitBaseT" = tenacity.wait.wait_none(), + retry: "t.Union[SyncRetryBaseT, RetryBaseT]" = tenacity.retry_if_exception_type(), + before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, + after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, + before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, + reraise: bool = False, + retry_error_cls: t.Type["RetryError"] = RetryError, + retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]] = None, + ) -> None: + super().__init__( + sleep=sleep, # type: ignore[arg-type] + stop=stop, + wait=wait, + retry=retry, # type: ignore[arg-type] + before=before, # type: ignore[arg-type] + after=after, # type: ignore[arg-type] + before_sleep=before_sleep, # type: ignore[arg-type] + reraise=reraise, + retry_error_cls=retry_error_cls, + retry_error_callback=retry_error_callback, + ) async def __call__( # type: ignore[override] self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any @@ -64,29 +91,19 @@ async def __call__( # type: ignore[override] retry_state.set_result(result) elif isinstance(do, DoSleep): retry_state.prepare_for_next_attempt() - await self.sleep(do) + await self.sleep(do) # type: ignore[misc] else: return do # type: ignore[no-any-return] - @classmethod - def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - if is_coroutine_callable(fn): - return fn - - async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: - return fn(*args, **kwargs) - - return inner - def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: - self.iter_state["actions"].append(self._wrap_action_func(fn)) + self.iter_state["actions"].append(_utils.wrap_to_async_func(fn)) async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override] - self.iter_state["retry_run_result"] = await self._wrap_action_func(self.retry)(retry_state) + self.iter_state["retry_run_result"] = await _utils.wrap_to_async_func(self.retry)(retry_state) async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override] if self.wait: - sleep = await self._wrap_action_func(self.wait)(retry_state) + sleep = await _utils.wrap_to_async_func(self.wait)(retry_state) else: sleep = 0.0 @@ -94,7 +111,7 @@ async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignor async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override] self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start - self.iter_state["stop_run_result"] = await self._wrap_action_func(self.stop)(retry_state) + self.iter_state["stop_run_result"] = await _utils.wrap_to_async_func(self.stop)(retry_state) async def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003 self._begin_iter(retry_state) @@ -120,7 +137,7 @@ async def __anext__(self) -> AttemptManager: return AttemptManager(retry_state=self._retry_state) elif isinstance(do, DoSleep): self._retry_state.prepare_for_next_attempt() - await self.sleep(do) + await self.sleep(do) # type: ignore[misc] else: raise StopAsyncIteration @@ -137,3 +154,13 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined] return async_wrapped # type: ignore[return-value] + + +__all__ = [ + "retry_all", + "retry_any", + "retry_if_exception", + "retry_if_result", + "WrappedFn", + "AsyncRetrying", +] diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py new file mode 100644 index 00000000..7e00e2d1 --- /dev/null +++ b/tenacity/asyncio/retry.py @@ -0,0 +1,95 @@ +# Copyright 2016–2021 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import inspect +import typing + +from tenacity import _utils +from tenacity import retry_base +from tenacity import retry_if_exception as _retry_if_exception +from tenacity import retry_if_result as _retry_if_result + +if typing.TYPE_CHECKING: + from tenacity import RetryCallState + + +class async_retry_base(retry_base): + """Abstract base class for async retry strategies.""" + + @abc.abstractmethod + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + pass + + def __and__(self, other: "typing.Union[retry_base, async_retry_base]") -> "retry_all": # type: ignore[override] + return retry_all(self, other) + + def __or__(self, other: "typing.Union[retry_base, async_retry_base]") -> "retry_any": # type: ignore[override] + return retry_any(self, other) + + +class async_predicate_mixin: + async def __call__(self, retry_state: "RetryCallState") -> bool: + result = super().__call__(retry_state) # type: ignore[misc] + if inspect.isawaitable(result): + result = await result + return typing.cast(bool, result) + + +RetryBaseT = typing.Union[async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] + + +class retry_if_exception(async_predicate_mixin, _retry_if_exception, async_retry_base): # type: ignore[misc] + """Retry strategy that retries if an exception verifies a predicate.""" + + def __init__(self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]) -> None: + super().__init__(predicate) # type: ignore[arg-type] + + +class retry_if_result(async_predicate_mixin, _retry_if_result, async_retry_base): # type: ignore[misc] + """Retries if the result verifies a predicate.""" + + def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None: + super().__init__(predicate) # type: ignore[arg-type] + + +class retry_any(async_retry_base): + """Retries if any of the retries condition is valid.""" + + def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None: + self.retries = retries + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + result = False + for r in self.retries: + result = result or await _utils.wrap_to_async_func(r)(retry_state) + if result: + break + return result + + +class retry_all(async_retry_base): + """Retries if all the retries condition are valid.""" + + def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None: + self.retries = retries + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + result = True + for r in self.retries: + result = result and await _utils.wrap_to_async_func(r)(retry_state) + if not result: + break + return result diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 542f540d..c425ef4b 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -22,7 +22,7 @@ import tenacity from tenacity import AsyncRetrying, RetryError -from tenacity import _asyncio as tasyncio +from tenacity import asyncio as tasyncio from tenacity import retry, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed @@ -191,6 +191,60 @@ def lt_3(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_or(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + class CustomException(Exception): + pass + + async def is_exc(e: BaseException) -> bool: + return isinstance(e, CustomException) + + retry_strategy = tasyncio.retry_if_result(lt_3) | tasyncio.retry_if_exception(is_exc) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + if 1 < attempts < 3: + raise CustomException() + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_and(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + def gt_0(x: float) -> bool: + return x > 0 + + retry_strategy = tasyncio.retry_if_result(lt_3) & retry_if_result(gt_0) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + @asynctest async def test_async_retying_iterator(self): thing = NoIOErrorAfterCount(5)