FIX: Harden DistributedMutex

Thread safety

Because the same Redis connection is shared by multiple threads, a transaction started in another thread can trample the connection state (the set of watched keys) that we rely on to acquire and release the lock correctly.

This is fixed by wrapping the lock and unlock logic in `redis.synchronize`, which keeps other threads off the connection while these steps run.
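The sketch below shows why holding the connection's lock across the whole sequence helps. It uses a hypothetical `SharedConnection` built on Ruby's `MonitorMixin`, not a real Redis client. It assumes, as in redis-rb 4.x, that the client also takes the same reentrant monitor around each individual command. Because of that, the outer `synchronize` must be reentrant or it would deadlock.

```ruby
require "monitor"

# Hypothetical stand-in for a shared Redis connection (not redis-rb).
# Each command takes the monitor on its own, and callers may also hold it
# across a multi-command sequence such as UNWATCH/WATCH/GET/MULTI/EXEC.
class SharedConnection
  include MonitorMixin

  attr_reader :log

  def initialize
    super() # initializes the monitor
    @log = []
  end

  def command(name)
    synchronize { @log << name }
  end
end

conn = SharedConnection.new

threads = 4.times.map do |i|
  Thread.new do
    # Holding the monitor across the whole sequence stops another thread's
    # commands from landing between our WATCH and EXEC. Monitor is
    # reentrant, so the nested synchronize inside #command does not deadlock.
    conn.synchronize do
      %w[UNWATCH WATCH GET MULTI EXEC].each { |c| conn.command("#{c}:#{i}") }
    end
  end
end
threads.each(&:join)

# Each thread's five commands appear as one contiguous run (order varies).
conn.log.each_slice(5) { |slice| p slice.map { |c| c.split(":").last }.uniq }
```

A plain `Mutex` would raise `ThreadError` on the nested lock attempt. That is why the reentrant monitor matters here.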

Off-by-one error

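A distributed mutex is now consistently considered expired only when the current time is strictly greater than the expire time. Previously, the acquisition check treated a lock whose expire time equaled the current time as expired, while the check after the block (`current_time <= expire_time`) treated the same instant as still held.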
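The boundary case shows up when the two comparisons are evaluated at the instant the lock expires. The helper names `held_old?` and `held_new?` are hypothetical. They mirror the comparison removed from, and the one added to, the acquisition check in the diff.

```ruby
# Old acquisition check: the lock counts as held only while expire_time > now,
# so at now == expire_time another process could already take it.
def held_old?(now, expire_time)
  expire_time > now
end

# New check, matching `current_time <= expire_time`: the lock is held up to
# and including expire_time, and expires only when now > expire_time.
def held_new?(now, expire_time)
  now <= expire_time
end

[99, 100, 101].each do |now|
  puts "now=#{now} old=#{held_old?(now, 100)} new=#{held_new?(now, 100)}"
end
# prints:
# now=99 old=true new=true
# now=100 old=false new=true
# now=101 old=false new=false
```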

Unwatch before transaction

Because the Redis connection is used throughout the codebase, it is difficult to guarantee that every caller clears the keys it watched. To defend against this stray connection state, an `unwatch` is now issued before locking and unlocking.
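The failure mode can be shown with a toy in-memory model of `WATCH`/`MULTI`/`EXEC`. `FakeRedis` and `try_lock` are hypothetical illustrations, not redis-rb and not the DistributedMutex code. In the model, a stale watch left by unrelated code makes our transaction abort even though our own key never changed. Calling `unwatch` first clears it.

```ruby
# Toy model of the Redis WATCH semantics the lock relies on (NOT redis-rb):
# EXEC aborts if any watched key was modified after it was watched.
class FakeRedis
  def initialize
    @data = {}
    @versions = Hash.new(0) # bumped on every write to a key
    @watched = {}           # key => version seen at WATCH time
  end

  def set(key, value)
    @data[key] = value
    @versions[key] += 1
  end

  def watch(key)
    @watched[key] = @versions[key]
  end

  def unwatch
    @watched.clear
  end

  # Runs the block as a transaction; returns nil (aborted) if any watched
  # key changed. Like EXEC, it clears all watches either way.
  def multi
    dirty = @watched.any? { |key, version| @versions[key] != version }
    @watched.clear
    return nil if dirty

    yield self
    ["OK"]
  end
end

# Hypothetical, simplified lock attempt (no expiry check).
def try_lock(redis, key, expire_time, unwatch_first:)
  redis.unwatch if unwatch_first
  redis.watch(key)
  result = redis.multi { |r| r.set(key, expire_time.to_s) }
  !result.nil?
end

redis = FakeRedis.new
redis.watch("some_other_key")          # unrelated code never unwatches...
redis.set("some_other_key", "changed") # ...and that key later changes
p try_lock(redis, "mutex_key", 100, unwatch_first: false) # => false

redis.watch("some_other_key")
redis.set("some_other_key", "changed again")
p try_lock(redis, "mutex_key", 100, unwatch_first: true)  # => true
```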

Logging

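The warning logged when unlocking fails before the lock has expired now states the likely cause: the Redis key appears to have been tampered with before expiration.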

diff --git a/lib/distributed_mutex.rb b/lib/distributed_mutex.rb
index d9a1879..f293a33 100644
--- a/lib/distributed_mutex.rb
+++ b/lib/distributed_mutex.rb
@@ -1,6 +1,8 @@
 # frozen_string_literal: true
 
 # Cross-process locking using Redis.
+#
+# Expiration happens when the current time is greater than the expire time
 class DistributedMutex
   DEFAULT_VALIDITY ||= 60
 
@@ -36,7 +38,7 @@ class DistributedMutex
         end
 
         if !unlock(expire_time) && current_time <= expire_time
-          warn("didn't unlock cleanly")
+          warn("the redis key appears to have been tampered with before expiration")
         end
       end
     end
@@ -79,40 +81,46 @@ class DistributedMutex
     now = redis.time[0]
     expire_time = now + validity
 
-    redis.watch key
+    redis.synchronize do
+      redis.unwatch
+      redis.watch key
 
-    current_expire_time = redis.get key
+      current_expire_time = redis.get key
 
-    if current_expire_time && current_expire_time.to_i > now
-      redis.unwatch
+      if current_expire_time && now <= current_expire_time.to_i
+        redis.unwatch
 
-      got_lock = false
-    else
-      result =
-        redis.multi do
-          redis.set key, expire_time.to_s
-          redis.expire key, validity
-        end
+        got_lock = false
+      else
+        result =
+          redis.multi do
+            redis.set key, expire_time.to_s
+            redis.expire key, validity
+          end
 
-      got_lock = !result.nil?
-    end
+        got_lock = !result.nil?
+      end
 
-    [got_lock, expire_time]
+      [got_lock, expire_time]
+    end
   end
 
   def unlock(expire_time)
-    redis.watch key
-    current_expire_time = redis.get key
-
-    if current_expire_time == expire_time.to_s
-      result =
-        redis.multi do
-          redis.del key
-        end
-      return !result.nil?
-    else
+    redis.synchronize do
       redis.unwatch
-      return false
+      redis.watch key
+      current_expire_time = redis.get key
+
+      if current_expire_time == expire_time.to_s
+        result =
+          redis.multi do
+            redis.del key
+          end
+        return !result.nil?
+      else
+        redis.unwatch
+        return false
+      end
     end
   end
 end

GitHub sha: 1fdba2c5
