diff --git a/src/redis_lock/__init__.py b/src/redis_lock/__init__.py index d187802..9468f60 100644 --- a/src/redis_lock/__init__.py +++ b/src/redis_lock/__init__.py @@ -3,6 +3,7 @@ from base64 import b64encode from logging import getLogger from os import urandom +from time import sleep from typing import Union from redis import StrictRedis @@ -104,7 +105,7 @@ class Lock(object): _lock_renewal_interval: float _lock_renewal_thread: Union[threading.Thread, None] - def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000, blocking=True): + def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, max_renewal_count=3, strict=True, signal_expire=1000, blocking=True): """ :param redis_client: An instance of :class:`~StrictRedis`. @@ -128,6 +129,8 @@ def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, an interval of ``expire*2/3``. If wishing to use a different renewal time, subclass Lock, call ``super().__init__()`` then set ``self._lock_renewal_interval`` to your desired interval. + :param max_renewal_count: + To avoid locking indefinitely use this param to set the maximum number of times to renew the lock. :param strict: If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``. :param signal_expire: @@ -167,6 +170,7 @@ def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, self._signal = 'lock-signal:' + name self._lock_renewal_interval = float(expire) * 2 / 3 if auto_renewal else None self._lock_renewal_thread = None + self._max_renewal_count = max_renewal_count self.blocking = blocking @@ -270,18 +274,24 @@ def extend(self, expire=None): raise RuntimeError(f"Unsupported error code {error} from EXTEND script") @staticmethod - def _lock_renewer(name, lockref, interval, stop): + def _lock_renewer(name, lockref, interval, max_renewal_count): """ Renew the lock key in redis every `interval` seconds for as long as `self._lock_renewal_thread.should_exit` is False. """ - while not stop.wait(timeout=interval): + renewal_count = 0 + while renewal_count < max_renewal_count: + if renewal_count >= max_renewal_count: + logger_for_refresh_thread.debug("Stopping loop because Lock(%r) was renewed max number of times.", name) + break + sleep(interval) logger_for_refresh_thread.debug("Refreshing Lock(%r).", name) lock: "Lock" = lockref() if lock is None: logger_for_refresh_thread.debug("Stopping loop because Lock(%r) was garbage collected.", name) break lock.extend(expire=lock._expire) + renewal_count += 1 del lock logger_for_refresh_thread.debug("Exiting renewal thread for Lock(%r).", name) @@ -303,7 +313,7 @@ def _start_lock_renewer(self): 'name': self._name, 'lockref': weakref.ref(self), 'interval': self._lock_renewal_interval, - 'stop': self._lock_renewal_stop, + 'max_renewal_count': self._max_renewal_count }, ) self._lock_renewal_thread.daemon = True diff --git a/tests/test_redis_lock.py b/tests/test_redis_lock.py index ac9e3ef..be189e9 100644 --- a/tests/test_redis_lock.py +++ b/tests/test_redis_lock.py @@ -542,6 +542,23 @@ def test_auto_renewal(conn): assert lock._lock_renewal_thread is None +def test_auto_renewal_with_max_renew_count(conn): + lock = Lock(conn, 'lock_renewal', expire=3, auto_renewal=True, max_renewal_count=2) + lock.acquire() + + assert isinstance(lock._lock_renewal_thread, threading.Thread) + assert not lock._lock_renewal_stop.is_set() + assert isinstance(lock._lock_renewal_interval, float) + assert lock._lock_renewal_interval == 2 + + time.sleep(8) + assert lock.locked() is False + + with pytest.raises(NotAcquired) as exc: + lock.release() + assert "is not acquired or it already expired" in str(exc) + + @pytest.mark.parametrize('signal_expire', [1000, 1500]) @pytest.mark.parametrize('method', ['release', 'reset_all']) def test_signal_expiration(conn, signal_expire, method):