Skip to content

[Feature Request] Setting expire time in milliseconds #99

@sangaline

Description

@sangaline

It would be convenient to be able to set the expire time in milliseconds instead of seconds. Switching over outright requires only minimal changes, but for backwards compatibility it probably makes more sense to add it as an opt-in option, such as `expire_in_milliseconds: bool = False`.

Here's a rough sketch of what it would look like to switch to milliseconds:

diff --git a/src/redis_lock/__init__.py b/src/redis_lock/__init__.py
index d4d0caa..6b47104 100644
--- a/src/redis_lock/__init__.py
+++ b/src/redis_lock/__init__.py
@@ -36,7 +36,7 @@ EXTEND_SCRIPT = b"""
     elseif redis.call("ttl", KEYS[1]) < 0 then
         return 2
     else
-        redis.call("expire", KEYS[1], ARGV[2])
+        redis.call("pexpire", KEYS[1], ARGV[2])
         return 0
     end
 """
@@ -110,7 +110,7 @@ class Lock(object):
         :param name:
             The name (redis key) the lock should have.
         :param expire:
-            The lock expiry time in seconds. If left at the default (None)
+            The lock expiry time in milliseconds. If left at the default (None)
             the lock will not expire.
         :param id:
             The ID (redis value) the lock should have. A random value is
@@ -223,7 +223,7 @@ class Lock(object):
         blpop_timeout = timeout or self._expire or 0
         timed_out = False
         while busy:
-            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
+            busy = not self._client.set(self._name, self._id, nx=True, px=self._expire)
             if busy:
                 if timed_out:
                     return False

And here's a sketch of supporting both with backwards compatibility:

diff --git a/src/redis_lock/__init__.py b/src/redis_lock/__init__.py
index d4d0caa..eb5ab6b 100644
--- a/src/redis_lock/__init__.py
+++ b/src/redis_lock/__init__.py
@@ -40,6 +40,16 @@ EXTEND_SCRIPT = b"""
         return 0
     end
 """
+PEXTEND_SCRIPT = b"""
+    if redis.call("get", KEYS[1]) ~= ARGV[1] then
+        return 1
+    elseif redis.call("ttl", KEYS[1]) < 0 then
+        return 2
+    else
+        redis.call("pexpire", KEYS[1], ARGV[2])
+        return 0
+    end
+"""
 
 RESET_SCRIPT = b"""
     redis.call('del', KEYS[2])
@@ -97,21 +107,24 @@ class Lock(object):
 
     unlock_script = None
     extend_script = None
+    pextend_script = None
     reset_script = None
     reset_all_script = None
 
     _lock_renewal_interval: float
     _lock_renewal_thread: Union[threading.Thread, None]
 
-    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000):
+    def __init__(self, redis_client, name, expire=None, expire_in_milliseconds=False, id=None, auto_renewal=False, strict=True, signal_expire=1000):
         """
         :param redis_client:
             An instance of :class:`~StrictRedis`.
         :param name:
             The name (redis key) the lock should have.
         :param expire:
-            The lock expiry time in seconds. If left at the default (None)
-            the lock will not expire.
+            The lock expiry time in seconds (or milliseconds if ``expire_in_milliseconds`` is set to ``True``).
+            If left at the default (None) the lock will not expire.
+        :param expire_in_milliseconds:
+            If set to ``True``, the ``expire`` parameter will be interpreted in milliseconds instead of seconds.
         :param id:
             The ID (redis value) the lock should have. A random value is
             generated when left at the default.
@@ -146,6 +159,7 @@ class Lock(object):
         else:
             expire = None
         self._expire = expire
+        self._expire_in_milliseconds = bool(expire_in_milliseconds)
 
         self._signal_expire = signal_expire
         if id is None:
@@ -172,6 +186,7 @@ class Lock(object):
         if reset_all_script is None:
             cls.unlock_script = redis_client.register_script(UNLOCK_SCRIPT)
             cls.extend_script = redis_client.register_script(EXTEND_SCRIPT)
+            cls.pextend_script = redis_client.register_script(PEXTEND_SCRIPT)
             cls.reset_script = redis_client.register_script(RESET_SCRIPT)
             cls.reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
             reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
@@ -223,7 +238,10 @@ class Lock(object):
         blpop_timeout = timeout or self._expire or 0
         timed_out = False
         while busy:
-            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
+            if self._expire_in_milliseconds:
+                busy = not self._client.set(self._name, self._id, nx=True, px=self._expire)
+            else:
+                busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
             if busy:
                 if timed_out:
                     return False
@@ -255,7 +273,11 @@ class Lock(object):
         else:
             raise TypeError("To extend a lock 'expire' must be provided as an argument to extend() method or at initialization time.")
 
-        error = self.extend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
+        if self._expire_in_milliseconds:
+            error = self.pextend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
+        else:
+            error = self.extend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
+
         if error == 1:
             raise NotAcquired(f"Lock {self._name} is not acquired or it already expired.")
         elif error == 2:

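Note that the backward-compatible sketch is not complete yet. The acquisition loop still computes `blpop_timeout = timeout or self._expire or 0`, so in millisecond mode the raw millisecond value would be used as the BLPOP timeout, which Redis interprets in seconds. That line, and any other code that treats `self._expire` as seconds, would need to convert (divide by 1000) when `expire_in_milliseconds` is set.

I can make a PR if you think the approach looks good.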

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions