Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/redis_lock/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from base64 import b64encode
from logging import getLogger
from os import urandom
from time import sleep
from typing import Union

from redis import StrictRedis
Expand Down Expand Up @@ -104,7 +105,7 @@ class Lock(object):
_lock_renewal_interval: float
_lock_renewal_thread: Union[threading.Thread, None]

def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000, blocking=True):
def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, max_renewal_count=3, strict=True, signal_expire=1000, blocking=True):
"""
:param redis_client:
An instance of :class:`~StrictRedis`.
Expand All @@ -128,6 +129,8 @@ def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False,
an interval of ``expire*2/3``. If wishing to use a different renewal
time, subclass Lock, call ``super().__init__()`` then set
``self._lock_renewal_interval`` to your desired interval.
:param max_renewal_count:
To avoid locking indefinitely use this param to set the maximum number of times to renew the lock.
:param strict:
If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
:param signal_expire:
Expand Down Expand Up @@ -167,6 +170,7 @@ def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False,
self._signal = 'lock-signal:' + name
self._lock_renewal_interval = float(expire) * 2 / 3 if auto_renewal else None
self._lock_renewal_thread = None
self._max_renewal_count = max_renewal_count

self.blocking = blocking

Expand Down Expand Up @@ -270,18 +274,24 @@ def extend(self, expire=None):
raise RuntimeError(f"Unsupported error code {error} from EXTEND script")

@staticmethod
def _lock_renewer(name, lockref, interval, stop):
def _lock_renewer(name, lockref, interval, max_renewal_count):
"""
Renew the lock key in redis every `interval` seconds for as long
as `self._lock_renewal_thread.should_exit` is False.
"""
while not stop.wait(timeout=interval):
renewal_count = 0
while renewal_count < max_renewal_count:
if renewal_count >= max_renewal_count:
logger_for_refresh_thread.debug("Stopping loop because Lock(%r) was renewed max number of times.", name)
break
sleep(interval)
logger_for_refresh_thread.debug("Refreshing Lock(%r).", name)
lock: "Lock" = lockref()
if lock is None:
logger_for_refresh_thread.debug("Stopping loop because Lock(%r) was garbage collected.", name)
break
lock.extend(expire=lock._expire)
renewal_count += 1
del lock
logger_for_refresh_thread.debug("Exiting renewal thread for Lock(%r).", name)

Expand All @@ -303,7 +313,7 @@ def _start_lock_renewer(self):
'name': self._name,
'lockref': weakref.ref(self),
'interval': self._lock_renewal_interval,
'stop': self._lock_renewal_stop,
'max_renewal_count': self._max_renewal_count
},
)
self._lock_renewal_thread.daemon = True
Expand Down
17 changes: 17 additions & 0 deletions tests/test_redis_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,23 @@ def test_auto_renewal(conn):
assert lock._lock_renewal_thread is None


def test_auto_renewal_with_max_renew_count(conn):
lock = Lock(conn, 'lock_renewal', expire=3, auto_renewal=True, max_renewal_count=2)
lock.acquire()

assert isinstance(lock._lock_renewal_thread, threading.Thread)
assert not lock._lock_renewal_stop.is_set()
assert isinstance(lock._lock_renewal_interval, float)
assert lock._lock_renewal_interval == 2

time.sleep(8)
assert lock.locked() is False

with pytest.raises(NotAcquired) as exc:
lock.release()
assert "is not acquired or it already expired" in str(exc)


@pytest.mark.parametrize('signal_expire', [1000, 1500])
@pytest.mark.parametrize('method', ['release', 'reset_all'])
def test_signal_expiration(conn, signal_expire, method):
Expand Down