From 84339be9a0f7d6b0ed2ee7f7091024d86c940dd9 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 7 Jul 2025 12:49:34 +0330 Subject: [PATCH 01/12] Add CompositeDistributedSynchronizationHandle for managing multiple distributed synchronization handles --- ...mpositeDistributedSynchronizationHandle.cs | 489 ++++++++++++++++++ 1 file changed, 489 insertions(+) create mode 100644 src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs diff --git a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs new file mode 100644 index 00000000..723d2e16 --- /dev/null +++ b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs @@ -0,0 +1,489 @@ +using Medallion.Threading.Internal; + +namespace Medallion.Threading; + +internal sealed class CompositeDistributedSynchronizationHandle : IDistributedSynchronizationHandle +{ + private readonly IDistributedSynchronizationHandle[] _handles; + private readonly CancellationTokenSource? _linkedLostCts; + private bool _disposed; + + public CompositeDistributedSynchronizationHandle(IReadOnlyList handles) + { + ValidateHandles(handles); + this._handles = handles.ToArray(); + this._linkedLostCts = this.CreateLinkedCancellationTokenSource(); + } + + public CancellationToken HandleLostToken => this._linkedLostCts?.Token ?? CancellationToken.None; + + public void Dispose() + { + if (this._disposed) + { + return; + } + + this._disposed = true; + var errors = this.DisposeHandles(h => h.Dispose()); + this._linkedLostCts?.Dispose(); + ThrowAggregateExceptionIfNeeded(errors, "disposing"); + } + + public async ValueTask DisposeAsync() + { + if (this._disposed) + { + return; + } + + this._disposed = true; + var errors = await this.DisposeHandlesAsync(h => h.DisposeAsync()).ConfigureAwait(false); + this._linkedLostCts?.Dispose(); + ThrowAggregateExceptionIfNeeded(errors, "asynchronously disposing"); + } + + public static async ValueTask TryAcquireAllAsync( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names, + TimeSpan timeout = default, + CancellationToken cancellationToken = default) + { + ValidateAcquireParameters(provider, acquireFunc, names); + + var timeoutTracker = new TimeoutTracker(timeout); + var handles = new List(names.Count); + + try + { + foreach (var name in names) + { + var handle = await acquireFunc(provider, name, timeoutTracker.Remaining, cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + return null; + } + + handles.Add(handle); + + if (timeoutTracker.IsExpired) + { + return null; + } + } + + var result = new CompositeDistributedSynchronizationHandle(handles); + handles.Clear(); + return result; + } + finally + { + await DisposeHandlesAsync(handles).ConfigureAwait(false); + } + } + + + public static async ValueTask AcquireAllAsync( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) + { + var effectiveTimeout = timeout ?? Timeout.InfiniteTimeSpan; + var handle = await TryAcquireAllAsync( + provider, + WrapAcquireFunc(acquireFunc), + names, + effectiveTimeout, + cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + throw new TimeoutException($"Timed out after {effectiveTimeout} while acquiring all locks."); + } + + return handle; + } + + public static IDistributedSynchronizationHandle? TryAcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + TimeSpan timeout = default, + CancellationToken cancellationToken = default) => + SyncViaAsync.Run( + state => TryAcquireAllAsync( + state.provider, + WrapSyncAcquireFunc(state.acquireFunc), + state.names, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, timeout, cancellationToken) + ); + + public static IDistributedSynchronizationHandle AcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) => + SyncViaAsync.Run( + state => AcquireAllAsync( + state.provider, + WrapSyncAcquireFuncForRequired(state.acquireFunc), + state.names, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, timeout, cancellationToken) + ); + + public static async ValueTask TryAcquireAllAsync( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan timeout = default, + CancellationToken cancellationToken = default) + { + ValidateAcquireParameters(provider, acquireFunc, names); + + var timeoutTracker = new TimeoutTracker(timeout); + var handles = new List(names.Count); + + try + { + foreach (var name in names) + { + var handle = await acquireFunc(provider, name, maxCount, timeoutTracker.Remaining, cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + return null; + } + + handles.Add(handle); + + if (timeoutTracker.IsExpired) + { + return null; + } + } + + var result = new CompositeDistributedSynchronizationHandle(handles); + handles.Clear(); + return result; + } + finally + { + await DisposeHandlesAsync(handles).ConfigureAwait(false); + } + } + + + public static async ValueTask AcquireAllAsync( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) + { + var effectiveTimeout = timeout ?? Timeout.InfiniteTimeSpan; + var handle = await TryAcquireAllAsync( + provider, + WrapAcquireFunc(acquireFunc), + names, + maxCount, + effectiveTimeout, + cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + throw new TimeoutException($"Timed out after {effectiveTimeout} while acquiring all locks."); + } + + return handle; + } + + public static IDistributedSynchronizationHandle? TryAcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan timeout = default, + CancellationToken cancellationToken = default) => + SyncViaAsync.Run( + state => TryAcquireAllAsync( + state.provider, + WrapSyncAcquireFunc(state.acquireFunc), + state.names, + state.maxCount, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, maxCount, timeout, cancellationToken) + ); + + public static IDistributedSynchronizationHandle AcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) => + SyncViaAsync.Run( + state => AcquireAllAsync( + state.provider, + WrapSyncAcquireFuncForRequired(state.acquireFunc), + state.names, + state.maxCount, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, maxCount, timeout, cancellationToken) + ); + + private static void ValidateHandles(IReadOnlyList handles) + { + if (handles is null) + { + throw new ArgumentNullException(nameof(handles)); + } + + if (handles.Count == 0) + { + throw new ArgumentException("At least one handle is required", nameof(handles)); + } + + for (var i = 0; i < handles.Count; ++i) + { + if (handles[i] is null) + { + throw new ArgumentException( + $"Handles must not contain null elements; found null at index {i}", + nameof(handles) + ); + } + } + } + + private CancellationTokenSource? CreateLinkedCancellationTokenSource() + { + var cancellableTokens = this._handles + .Select(h => h.HandleLostToken) + .Where(t => t.CanBeCanceled) + .ToArray(); + + return cancellableTokens.Length > 0 + ? CancellationTokenSource.CreateLinkedTokenSource(cancellableTokens) + : null; + } + + private List? DisposeHandles(Action disposeAction) + { + List? errors = null; + + foreach (var handle in this._handles) + { + try + { + disposeAction(handle); + } + catch (Exception ex) + { + (errors ??= []).Add(ex); + } + } + + return errors; + } + + private async ValueTask?> DisposeHandlesAsync( + Func disposeAction) + { + List? errors = null; + + foreach (var handle in this._handles) + { + try + { + await disposeAction(handle).ConfigureAwait(false); + } + catch (Exception ex) + { + (errors ??= []).Add(ex); + } + } + + return errors; + } + + private static void ThrowAggregateExceptionIfNeeded(List? errors, string operation) + { + if (errors is not null && errors.Count > 0) + { + throw new AggregateException( + $"One or more errors occurred while {operation} a composite distributed handle.", errors); + } + } + + private static void ValidateAcquireParameters( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names) + { + if (provider is null) + { + throw new ArgumentNullException(nameof(provider)); + } + + if (acquireFunc is null) + { + throw new ArgumentNullException(nameof(acquireFunc)); + } + + if (names is null) + { + throw new ArgumentNullException(nameof(names)); + } + + if (names.Count == 0) + { + throw new ArgumentException("At least one lock name is required.", nameof(names)); + } + + for (var i = 0; i < names.Count; ++i) + { + if (names[i] is null) + { + throw new ArgumentException( + $"Names must not contain null elements; found null at index {i}", + nameof(names) + ); + } + } + } + + private static void ValidateAcquireParameters( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names) + { + if (provider is null) + { + throw new ArgumentNullException(nameof(provider)); + } + + if (acquireFunc is null) + { + throw new ArgumentNullException(nameof(acquireFunc)); + } + + if (names is null) + { + throw new ArgumentNullException(nameof(names)); + } + + if (names.Count == 0) + { + throw new ArgumentException("At least one lock name is required.", nameof(names)); + } + + for (var i = 0; i < names.Count; ++i) + { + if (names[i] is null) + { + throw new ArgumentException( + $"Names must not contain null elements; found null at index {i}", + nameof(names) + ); + } + } + } + + private static async ValueTask DisposeHandlesAsync(List handles) + { + foreach (var handle in handles) + { + try + { + await handle.DisposeAsync().ConfigureAwait(false); + } + catch + { + // Suppress exceptions during cleanup + } + } + } + + private static Func> + WrapAcquireFunc( + Func> + acquireFunc) => + async (p, n, t, c) => await acquireFunc(p, n, t, c).ConfigureAwait(false); + + private static Func> + WrapSyncAcquireFunc( + Func acquireFunc) => + (p, n, t, c) => new ValueTask(acquireFunc(p, n, t, c)); + + private static Func> + WrapSyncAcquireFuncForRequired( + Func acquireFunc) => + (p, n, t, c) => + { + var handle = acquireFunc(p, n, t, c); + return handle is not null + ? new ValueTask(handle) + : throw new TimeoutException($"Failed to acquire lock for '{n}'"); + }; + + + private static Func> + WrapAcquireFunc( + Func> + acquireFunc) => + async (p, n, mc, t, c) => await acquireFunc(p, n, mc, t, c).ConfigureAwait(false); + + private static Func> + WrapSyncAcquireFunc( + Func + acquireFunc) => + (p, n, mc, t, c) => new ValueTask(acquireFunc(p, n, mc, t, c)); + + private static Func> + WrapSyncAcquireFuncForRequired( + Func + acquireFunc) => + (p, n, mc, t, c) => + { + var handle = acquireFunc(p, n, mc, t, c); + return handle is not null + ? new ValueTask(handle) + : throw new TimeoutException($"Failed to acquire lock for '{n}'"); + }; + + private sealed class TimeoutTracker(TimeSpan timeout) + { + private readonly System.Diagnostics.Stopwatch? _stopwatch = timeout == Timeout.InfiniteTimeSpan + ? null + : System.Diagnostics.Stopwatch.StartNew(); + + public TimeSpan Remaining => this._stopwatch is null + ? Timeout.InfiniteTimeSpan + : timeout - this._stopwatch.Elapsed; + + public bool IsExpired => this._stopwatch is not null && this._stopwatch.Elapsed >= timeout; + } +} \ No newline at end of file From 184ad1ae4133673186d0965ee924f3be5317a84d Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 7 Jul 2025 12:50:08 +0330 Subject: [PATCH 02/12] Change GenerateProviders to generate composite lock methods --- .../DistributedLockProviderExtensions.cs | 48 ++++++ ...butedReaderWriterLockProviderExtensions.cs | 88 ++++++++++ .../DistributedSemaphoreProviderExtensions.cs | 48 ++++++ ...eableReaderWriterLockProviderExtensions.cs | 10 ++ .../GenerateProviders.cs | 162 ++++++++++++++---- 5 files changed, 318 insertions(+), 38 deletions(-) diff --git a/src/DistributedLock.Core/DistributedLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedLockProviderExtensions.cs index ea4ad6f5..bc98c6bb 100644 --- a/src/DistributedLock.Core/DistributedLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,50 @@ public static IDistributedSynchronizationHandle AcquireLock(this IDistributedLoc /// public static ValueTask AcquireLockAsync(this IDistributedLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateLock(name).AcquireAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllLocks(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllLocks(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllLocksAsync(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllLocksAsync(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireLockAsync(n, t, c), + names, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs index 5ac66fac..da5f407c 100644 --- a/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedReaderWriterLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -62,4 +64,90 @@ public static IDistributedSynchronizationHandle AcquireWriteLock(this IDistribut /// public static ValueTask AcquireWriteLockAsync(this IDistributedReaderWriterLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateReaderWriterLock(name).AcquireWriteLockAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllReadLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireReadLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllReadLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireReadLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllReadLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireReadLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllReadLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireReadLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllWriteLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireWriteLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllWriteLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireWriteLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllWriteLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireWriteLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllWriteLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireWriteLockAsync(n, t, c), + names, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs b/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs index b808b004..3ed29ebe 100644 --- a/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedSemaphoreProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,50 @@ public static IDistributedSynchronizationHandle AcquireSemaphore(this IDistribut /// public static ValueTask AcquireSemaphoreAsync(this IDistributedSemaphoreProvider provider, string name, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateSemaphore(name, maxCount).AcquireAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllSemaphores(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, mc, t, c) => p.TryAcquireSemaphore(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllSemaphores(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, mc, t, c) => p.AcquireSemaphore(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllSemaphoresAsync(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, mc, t, c) => p.TryAcquireSemaphoreAsync(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllSemaphoresAsync(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, mc, t, c) => p.AcquireSemaphoreAsync(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs index 4b8c51f9..2bc1c429 100644 --- a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedUpgradeableReaderWriterLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,12 @@ public static IDistributedLockUpgradeableHandle AcquireUpgradeableReadLock(this /// public static ValueTask AcquireUpgradeableReadLockAsync(this IDistributedUpgradeableReaderWriterLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateUpgradeableReaderWriterLock(name).AcquireUpgradeableReadLockAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + +// Composite methods are not supported for IDistributedUpgradeableReaderWriterLock + + # endregion } \ No newline at end of file diff --git a/src/DistributedLockCodeGen/GenerateProviders.cs b/src/DistributedLockCodeGen/GenerateProviders.cs index ab657b14..80c80747 100644 --- a/src/DistributedLockCodeGen/GenerateProviders.cs +++ b/src/DistributedLockCodeGen/GenerateProviders.cs @@ -10,64 +10,112 @@ namespace DistributedLockCodeGen; [Category("CI")] public class GenerateProviders { - public static readonly IReadOnlyList Interfaces = new[] - { + public static readonly IReadOnlyList Interfaces = + [ "IDistributedLock", "IDistributedReaderWriterLock", "IDistributedUpgradeableReaderWriterLock", "IDistributedSemaphore" - }; + ]; + + private static readonly IReadOnlyList ExcludedInterfacesForCompositeMethods = + [ + "IDistributedUpgradeableReaderWriterLock" + ]; [TestCaseSource(nameof(Interfaces))] public void GenerateProviderInterfaceAndExtensions(string interfaceName) { - var interfaceFile = Directory.GetFiles(CodeGenHelpers.SolutionDirectory, interfaceName + ".cs", SearchOption.AllDirectories) + var interfaceFile = Directory + .GetFiles(CodeGenHelpers.SolutionDirectory, interfaceName + ".cs", SearchOption.AllDirectories) .Single(); var providerInterfaceName = interfaceName + "Provider"; var createMethodName = $"Create{interfaceName.Replace("IDistributed", string.Empty)}"; - var providerInterfaceCode = $@"// AUTO-GENERATED -namespace Medallion.Threading; - -/// -/// Acts as a factory for instances of a certain type. This interface may be -/// easier to use than in dependency injection scenarios. -/// -public interface {providerInterfaceName}{(interfaceName == "IDistributedUpgradeableReaderWriterLock" ? ": IDistributedReaderWriterLockProvider" : string.Empty)} -{{ - /// - /// Constructs an instance with the given . - /// - {interfaceName} {createMethodName}(string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}); -}}"; + var providerInterfaceCode = $$""" + // AUTO-GENERATED + namespace Medallion.Threading; + + /// + /// Acts as a factory for instances of a certain type. This interface may be + /// easier to use than in dependency injection scenarios. + /// + public interface {{providerInterfaceName}}{{(interfaceName == "IDistributedUpgradeableReaderWriterLock" ? ": IDistributedReaderWriterLockProvider" : string.Empty)}} + { + /// + /// Constructs an instance with the given . + /// + {{interfaceName}} {{createMethodName}}(string name{{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}}); + } + """; var interfaceMethods = Regex.Matches( File.ReadAllText(interfaceFile), @"(?\S+) (?\S+)\((?((?\S*) (?\w+)[^,)]*(\, )?)*)\);", RegexOptions.ExplicitCapture ); - var extensionMethodBodies = interfaceMethods.Cast() + + var extensionSingleMethodBodies = interfaceMethods .Select(m => -$@" /// - /// Equivalent to calling and then - /// ().Select(c => c.Value))})"" />. - /// - public static {m.Groups["returnType"].Value} {GetExtensionMethodName(m.Groups["name"].Value)}(this {providerInterfaceName} provider, string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => - (provider ?? throw new ArgumentNullException(nameof(provider))).{createMethodName}(name{(interfaceName.Contains("Semaphore") ? ", maxCount" : string.Empty)}).{m.Groups["name"].Value}({string.Join(", ", m.Groups["parameterName"].Captures.Cast().Select(c => c.Value))});" + $""" + /// + /// Equivalent to calling and then + /// c.Value))})" />. + /// + public static {m.Groups["returnType"].Value} {GetExtensionMethodName(m.Groups["name"].Value)}(this {providerInterfaceName} provider, string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => + (provider ?? throw new ArgumentNullException(nameof(provider))).{createMethodName}(name{(interfaceName.Contains("Semaphore") ? ", maxCount" : string.Empty)}).{m.Groups["name"].Value}({string.Join(", ", m.Groups["parameterName"].Captures.Select(c => c.Value))}); + """ ); + var extensionCompositeMethodBodies = ExcludedInterfacesForCompositeMethods.Contains(interfaceName) + ? + [ + $"// Composite methods are not supported for {interfaceName}" + ] + : interfaceMethods + .Select(m => + { + var (extensionMethodName, innerCallName) = GetAllExtensionMethodName(m.Groups["name"].Value); + var isSemaphore = interfaceName.Contains("Semaphore"); + + return $""" + /// + /// Equivalent to calling for each name in and then + /// c.Value))})" /> on each created instance, combining the results into a composite handle. + /// + public static {m.Groups["returnType"].Value} {extensionMethodName}(this {providerInterfaceName} provider, IReadOnlyList names{(isSemaphore ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => + CompositeDistributedSynchronizationHandle.{innerCallName}( + provider, + static (p, n{(isSemaphore ? ", mc" : string.Empty)}, t, c) => p.{GetExtensionMethodName(m.Groups["name"].Value)}(n{(isSemaphore ? ", mc" : string.Empty)}, t, c), + names,{(isSemaphore ? " maxCount," : string.Empty)} timeout, cancellationToken); + """; + } + ); + var providerExtensionsName = providerInterfaceName.TrimStart('I') + "Extensions"; - var providerExtensionsCode = $@"// AUTO-GENERATED + var providerExtensionsCode = $$""" + // AUTO-GENERATED + + namespace Medallion.Threading; + + /// + /// Productivity helper methods for + /// + public static class {{providerExtensionsName}} + { + # region Single Lock Methods + + {{string.Join(Environment.NewLine + Environment.NewLine, extensionSingleMethodBodies)}} + + # endregion + + # region Composite Lock Methods -namespace Medallion.Threading; + {{string.Join(Environment.NewLine + Environment.NewLine, extensionCompositeMethodBodies)}} -/// -/// Productivity helper methods for -/// -public static class {providerExtensionsName} -{{ -{string.Join(Environment.NewLine + Environment.NewLine, extensionMethodBodies)} -}}"; + # endregion + } + """; var changes = new[] { @@ -76,7 +124,8 @@ public static class {providerExtensionsName} } .Select(t => (file: Path.Combine(Path.GetDirectoryName(interfaceFile)!, t.name + ".cs"), t.code)) .Select(t => (t.file, t.code, originalCode: File.Exists(t.file) ? File.ReadAllText(t.file) : string.Empty)) - .Where(t => CodeGenHelpers.NormalizeCodeWhitespace(t.code) != CodeGenHelpers.NormalizeCodeWhitespace(t.originalCode)) + .Where(t => CodeGenHelpers.NormalizeCodeWhitespace(t.code) != + CodeGenHelpers.NormalizeCodeWhitespace(t.originalCode)) .ToList(); changes.ForEach(t => File.WriteAllText(t.file, t.code)); Assert.That(changes.Select(t => t.file), Is.Empty); @@ -85,8 +134,45 @@ string GetExtensionMethodName(string interfaceMethodName) => Regex.IsMatch(interfaceMethodName, "^(Try)?Acquire(Async)?$") // make it more specific to differentiate when one concrete provider implements multiple provider interfaces ? interfaceMethodName.Replace("Async", string.Empty) - + interfaceName.Replace("IDistributed", string.Empty) - + (interfaceMethodName.EndsWith("Async") ? "Async" : string.Empty) + + interfaceName.Replace("IDistributed", string.Empty) + + (interfaceMethodName.EndsWith("Async") ? "Async" : string.Empty) : interfaceMethodName; + + (string extensionMethodName, string innerCallName) GetAllExtensionMethodName(string interfaceMethodName) + { + var isExactAcquire = Regex.IsMatch(interfaceMethodName, "^(Try)?Acquire(Async)?$"); + var isAsync = interfaceMethodName.EndsWith("Async", StringComparison.Ordinal); + var isTryVariant = interfaceMethodName.StartsWith("Try", StringComparison.Ordinal); + + string extensionMethodName; + + if (!isExactAcquire) + { + // e.g. TryAcquireReadLock -> TryAcquireAllReadLocks + // TryAcquireSemaphore -> TryAcquireAllSemaphores + // TryAcquireUpgradeableReadLockAsync -> TryAcquireUpgradeableAllReadLockAsync + extensionMethodName = interfaceMethodName + .Replace("Acquire", "AcquireAll") // Acquire -> AcquireAll + .Replace("Async", string.Empty) // strip Async (add back later) + + "s" // pluralise + + (isAsync ? "Async" : string.Empty); // restore Async if needed + } + else + { + // e.g. TryAcquire -> TryAcquireAllLocks + // AcquireAsync -> AcquireAllLocksAsync + extensionMethodName = interfaceMethodName.Replace("Async", string.Empty) + + "All" + + interfaceName.Replace("IDistributed", string.Empty) + "s" + + (isAsync ? "Async" : string.Empty); + } + + // - “Try…” methods -> TryAcquireAll[Async] + // - plain Acquire… -> AcquireAll[Async] + var innerCallName = (isTryVariant ? "TryAcquireAll" : "AcquireAll") + + (isAsync ? "Async" : string.Empty); + + return (extensionMethodName, innerCallName); + } } -} +} \ No newline at end of file From db7554b3752450573e92cd998eca656a722bc9d0 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 14:46:40 +0330 Subject: [PATCH 03/12] Refactor composite lock acquisition --- ...mpositeDistributedSynchronizationHandle.cs | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs index 723d2e16..35a30855 100644 --- a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs +++ b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs @@ -54,6 +54,7 @@ public async ValueTask DisposeAsync() var timeoutTracker = new TimeoutTracker(timeout); var handles = new List(names.Count); + IDistributedSynchronizationHandle? result = null; try { @@ -64,25 +65,28 @@ public async ValueTask DisposeAsync() if (handle is null) { - return null; + break; } handles.Add(handle); if (timeoutTracker.IsExpired) { - return null; + break; } } - var result = new CompositeDistributedSynchronizationHandle(handles); - handles.Clear(); - return result; + result = new CompositeDistributedSynchronizationHandle(handles); } finally { - await DisposeHandlesAsync(handles).ConfigureAwait(false); + if (result is null) + { + await DisposeHandlesAsync(handles).ConfigureAwait(false); + } } + + return result; } @@ -155,6 +159,7 @@ public static IDistributedSynchronizationHandle AcquireAll( var timeoutTracker = new TimeoutTracker(timeout); var handles = new List(names.Count); + IDistributedSynchronizationHandle? result = null; try { @@ -165,25 +170,28 @@ public static IDistributedSynchronizationHandle AcquireAll( if (handle is null) { - return null; + break; } handles.Add(handle); if (timeoutTracker.IsExpired) { - return null; + break; } } - var result = new CompositeDistributedSynchronizationHandle(handles); - handles.Clear(); - return result; + result = new CompositeDistributedSynchronizationHandle(handles); } finally { - await DisposeHandlesAsync(handles).ConfigureAwait(false); + if (result is null) + { + await DisposeHandlesAsync(handles).ConfigureAwait(false); + } } + + return result; } From 1eea667934d1dce687b41961af1be2e04c3289db Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 14:50:10 +0330 Subject: [PATCH 04/12] Remove default parameter values --- ...mpositeDistributedSynchronizationHandle.cs | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs index 35a30855..4cf561cb 100644 --- a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs +++ b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs @@ -47,8 +47,8 @@ public async ValueTask DisposeAsync() TProvider provider, Func> acquireFunc, IReadOnlyList names, - TimeSpan timeout = default, - CancellationToken cancellationToken = default) + TimeSpan timeout, + CancellationToken cancellationToken) { ValidateAcquireParameters(provider, acquireFunc, names); @@ -89,13 +89,12 @@ public async ValueTask DisposeAsync() return result; } - public static async ValueTask AcquireAllAsync( TProvider provider, Func> acquireFunc, IReadOnlyList names, - TimeSpan? timeout = null, - CancellationToken cancellationToken = default) + TimeSpan? timeout, + CancellationToken cancellationToken) { var effectiveTimeout = timeout ?? Timeout.InfiniteTimeSpan; var handle = await TryAcquireAllAsync( @@ -118,8 +117,8 @@ public static async ValueTask AcquireAllAsync TProvider provider, Func acquireFunc, IReadOnlyList names, - TimeSpan timeout = default, - CancellationToken cancellationToken = default) => + TimeSpan timeout, + CancellationToken cancellationToken) => SyncViaAsync.Run( state => TryAcquireAllAsync( state.provider, @@ -134,8 +133,8 @@ public static IDistributedSynchronizationHandle AcquireAll( TProvider provider, Func acquireFunc, IReadOnlyList names, - TimeSpan? timeout = null, - CancellationToken cancellationToken = default) => + TimeSpan? timeout, + CancellationToken cancellationToken) => SyncViaAsync.Run( state => AcquireAllAsync( state.provider, @@ -152,8 +151,8 @@ public static IDistributedSynchronizationHandle AcquireAll( acquireFunc, IReadOnlyList names, int maxCount, - TimeSpan timeout = default, - CancellationToken cancellationToken = default) + TimeSpan timeout, + CancellationToken cancellationToken) { ValidateAcquireParameters(provider, acquireFunc, names); @@ -201,8 +200,8 @@ public static async ValueTask AcquireAllAsync acquireFunc, IReadOnlyList names, int maxCount, - TimeSpan? timeout = null, - CancellationToken cancellationToken = default) + TimeSpan? timeout, + CancellationToken cancellationToken) { var effectiveTimeout = timeout ?? Timeout.InfiniteTimeSpan; var handle = await TryAcquireAllAsync( @@ -227,8 +226,8 @@ public static async ValueTask AcquireAllAsync Func acquireFunc, IReadOnlyList names, int maxCount, - TimeSpan timeout = default, - CancellationToken cancellationToken = default) => + TimeSpan timeout, + CancellationToken cancellationToken) => SyncViaAsync.Run( state => TryAcquireAllAsync( state.provider, @@ -245,8 +244,8 @@ public static IDistributedSynchronizationHandle AcquireAll( Func acquireFunc, IReadOnlyList names, int maxCount, - TimeSpan? timeout = null, - CancellationToken cancellationToken = default) => + TimeSpan? timeout, + CancellationToken cancellationToken) => SyncViaAsync.Run( state => AcquireAllAsync( state.provider, From ac6d8a9a5e859cf97a0e1961e99ef24c0bb89b18 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 14:53:54 +0330 Subject: [PATCH 05/12] Refactor TimeoutTracker to use TimeoutValue and make it struct --- .../CompositeDistributedSynchronizationHandle.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs index 4cf561cb..749f865c 100644 --- a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs +++ b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs @@ -52,7 +52,7 @@ public async ValueTask DisposeAsync() { ValidateAcquireParameters(provider, acquireFunc, names); - var timeoutTracker = new TimeoutTracker(timeout); + var timeoutTracker = new TimeoutTracker(new TimeoutValue(timeout)); var handles = new List(names.Count); IDistributedSynchronizationHandle? result = null; @@ -156,7 +156,7 @@ public static IDistributedSynchronizationHandle AcquireAll( { ValidateAcquireParameters(provider, acquireFunc, names); - var timeoutTracker = new TimeoutTracker(timeout); + var timeoutTracker = new TimeoutTracker(new TimeoutValue(timeout)); var handles = new List(names.Count); IDistributedSynchronizationHandle? result = null; @@ -481,16 +481,16 @@ private static Func this._stopwatch is null ? Timeout.InfiniteTimeSpan - : timeout - this._stopwatch.Elapsed; + : timeout.TimeSpan - this._stopwatch.Elapsed; - public bool IsExpired => this._stopwatch is not null && this._stopwatch.Elapsed >= timeout; + public bool IsExpired => this._stopwatch is not null && this._stopwatch.Elapsed >= timeout.TimeSpan; } } \ No newline at end of file From 33967ffd1cee4bd0c0a0bb44995f80f9689fc4c8 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 15:01:03 +0330 Subject: [PATCH 06/12] Fix TimeoutTracker.Remaining to return zero instead of negative TimeSpan --- .../CompositeDistributedSynchronizationHandle.cs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs index 749f865c..6ce502b4 100644 --- a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs +++ b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs @@ -487,9 +487,12 @@ private readonly struct TimeoutTracker(TimeoutValue timeout) ? null : System.Diagnostics.Stopwatch.StartNew(); - public TimeSpan Remaining => this._stopwatch is null - ? Timeout.InfiniteTimeSpan - : timeout.TimeSpan - this._stopwatch.Elapsed; + public TimeSpan Remaining => + this._stopwatch is { Elapsed: var elapsed } + ? elapsed >= timeout.TimeSpan + ? TimeSpan.Zero + : timeout.TimeSpan - elapsed + : Timeout.InfiniteTimeSpan; public bool IsExpired => this._stopwatch is not null && this._stopwatch.Elapsed >= timeout.TimeSpan; } From e1424b60195fff25d44ec944cad20dd433d10926 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 15:50:00 +0330 Subject: [PATCH 07/12] Check `handles.Count` before assigning the result --- .../CompositeDistributedSynchronizationHandle.cs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs index 6ce502b4..7abc1a9f 100644 --- a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs +++ b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs @@ -76,7 +76,10 @@ public async ValueTask DisposeAsync() } } - result = new CompositeDistributedSynchronizationHandle(handles); + if (handles.Count == names.Count) + { + result = new CompositeDistributedSynchronizationHandle(handles); + } } finally { @@ -180,7 +183,10 @@ public static IDistributedSynchronizationHandle AcquireAll( } } - result = new CompositeDistributedSynchronizationHandle(handles); + if (handles.Count == names.Count) + { + result = new CompositeDistributedSynchronizationHandle(handles); + } } finally { From c99758e376707d6040938196666554cf39a55a91 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 15:50:15 +0330 Subject: [PATCH 08/12] Add composite lock tests --- .../DistributedLockProviderExtensionsTest.cs | 284 +++++++++++++++++- 1 file changed, 283 insertions(+), 1 deletion(-) diff --git a/src/DistributedLock.Tests/Tests/Core/DistributedLockProviderExtensionsTest.cs b/src/DistributedLock.Tests/Tests/Core/DistributedLockProviderExtensionsTest.cs index 074e82d5..9ca1a407 100644 --- a/src/DistributedLock.Tests/Tests/Core/DistributedLockProviderExtensionsTest.cs +++ b/src/DistributedLock.Tests/Tests/Core/DistributedLockProviderExtensionsTest.cs @@ -49,7 +49,7 @@ public void TestCallThrough([Values] bool isTry, [Values] bool isAsync) } void Test( - Expression> providerFunction, + Expression> providerFunction, Expression> lockFunction) { providerFunction.Compile()(mockProvider.Object); @@ -58,4 +58,286 @@ void Test( mockLock.Verify(lockFunction, Times.Once()); } } + + [Test, Combinatorial] + public void TestCompositeArgumentValidation([Values] bool isAsync) + { + var mockProvider = new Mock(); + var validNames = new List { "lock1", "lock2" }; + + if (isAsync) + { + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.TryAcquireAllLocksAsync(null!, validNames).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.TryAcquireAllLocksAsync(mockProvider.Object, null!).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.TryAcquireAllLocksAsync(mockProvider.Object, new List()).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.TryAcquireAllLocksAsync(mockProvider.Object, new List { "lock1", null! }).AsTask()); + + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.AcquireAllLocksAsync(null!, validNames).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.AcquireAllLocksAsync(mockProvider.Object, null!).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.AcquireAllLocksAsync(mockProvider.Object, new List()).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.AcquireAllLocksAsync(mockProvider.Object, new List { "lock1", null! }).AsTask()); + } + else + { + Assert.Throws(() => DistributedLockProviderExtensions.TryAcquireAllLocks(null!, validNames)); + Assert.Throws(() => DistributedLockProviderExtensions.TryAcquireAllLocks(mockProvider.Object, null!)); + Assert.Throws(() => DistributedLockProviderExtensions.TryAcquireAllLocks(mockProvider.Object, new List())); + Assert.Throws(() => DistributedLockProviderExtensions.TryAcquireAllLocks(mockProvider.Object, new List { "lock1", null! })); + + Assert.Throws(() => DistributedLockProviderExtensions.AcquireAllLocks(null!, validNames)); + Assert.Throws(() => DistributedLockProviderExtensions.AcquireAllLocks(mockProvider.Object, null!)); + Assert.Throws(() => DistributedLockProviderExtensions.AcquireAllLocks(mockProvider.Object, new List())); + Assert.Throws(() => DistributedLockProviderExtensions.AcquireAllLocks(mockProvider.Object, new List { "lock1", null! })); + } + } + + [Test, Combinatorial] + public async Task TestCompositePartialAcquisitionFailure([Values] bool isAsync) + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + var mockLockB = new Mock(); + var mockHandleA = new Mock(); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + mockProvider.Setup(p => p.CreateLock("B")).Returns(mockLockB.Object); + + if (isAsync) + { + mockLockA.Setup(l => l.TryAcquireAsync(TimeSpan.Zero, default)) + .ReturnsAsync(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync((IDistributedSynchronizationHandle?)null); + } + else + { + mockLockA.Setup(l => l.TryAcquire(TimeSpan.Zero, default)) + .Returns(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquire(It.IsAny(), default)) + .Returns((IDistributedSynchronizationHandle?)null); + } + + var names = new List { "A", "B" }; + IDistributedSynchronizationHandle? result; + + if (isAsync) + { + result = await mockProvider.Object.TryAcquireAllLocksAsync(names, TimeSpan.Zero, default); + } + else + { + result = mockProvider.Object.TryAcquireAllLocks(names, TimeSpan.Zero, default); + } + + Assert.That(result, Is.Null); + + mockHandleA.Verify(h => h.DisposeAsync(), Times.Once); + } + + [Test, Combinatorial] + public async Task TestCompositeSuccessfulAcquisition([Values] bool isTry, [Values] bool isAsync) + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + var mockLockB = new Mock(); + var mockHandleA = new Mock(); + var mockHandleB = new Mock(); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + mockProvider.Setup(p => p.CreateLock("B")).Returns(mockLockB.Object); + + if (isAsync) + { + if (isTry) + { + mockLockA.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleB.Object); + } + else + { + mockLockA.Setup(l => l.AcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleA.Object); + mockLockB.Setup(l => l.AcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleB.Object); + } + } + else + { + if (isTry) + { + mockLockA.Setup(l => l.TryAcquire(It.IsAny(), default)) + .Returns(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquire(It.IsAny(), default)) + .Returns(mockHandleB.Object); + } + else + { + mockLockA.Setup(l => l.Acquire(It.IsAny(), default)) + .Returns(mockHandleA.Object); + mockLockB.Setup(l => l.Acquire(It.IsAny(), default)) + .Returns(mockHandleB.Object); + } + } + + var names = new List { "A", "B" }; + IDistributedSynchronizationHandle? result; + + if (isAsync) + { + if (isTry) + { + result = await mockProvider.Object.TryAcquireAllLocksAsync(names, TimeSpan.FromSeconds(10), default); + } + else + { + result = await mockProvider.Object.AcquireAllLocksAsync(names, TimeSpan.FromSeconds(10), default); + } + } + else + { + if (isTry) + { + result = mockProvider.Object.TryAcquireAllLocks(names, TimeSpan.FromSeconds(10), default); + } + else + { + result = mockProvider.Object.AcquireAllLocks(names, TimeSpan.FromSeconds(10), default); + } + } + + Assert.That(result, Is.Not.Null); + + mockProvider.Verify(p => p.CreateLock("A"), Times.Once); + mockProvider.Verify(p => p.CreateLock("B"), Times.Once); + + if (isAsync) + { + await result!.DisposeAsync(); + mockHandleA.Verify(h => h.DisposeAsync(), Times.Once); + mockHandleB.Verify(h => h.DisposeAsync(), Times.Once); + } + else + { + result!.Dispose(); + mockHandleA.Verify(h => h.Dispose(), Times.Once); + mockHandleB.Verify(h => h.Dispose(), Times.Once); + } + } + + [Test, Combinatorial] + public void TestCompositeAcquireThrowsOnTimeout([Values] bool isAsync) + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + + if (isAsync) + { + mockLockA.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync((IDistributedSynchronizationHandle?)null); + } + else + { + mockLockA.Setup(l => l.TryAcquire(It.IsAny(), default)) + .Returns((IDistributedSynchronizationHandle?)null); + } + + var names = new List { "A" }; + + if (isAsync) + { + Assert.ThrowsAsync(() => mockProvider.Object.AcquireAllLocksAsync(names, TimeSpan.Zero, default).AsTask()); + } + else + { + Assert.Throws(() => mockProvider.Object.AcquireAllLocks(names, TimeSpan.Zero, default)); + } + } + + [Test] + public async Task TestCompositeRemainingTimeDistribution() + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + var mockLockB = new Mock(); + var mockHandleA = new Mock(); + var mockHandleB = new Mock(); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + mockProvider.Setup(p => p.CreateLock("B")).Returns(mockLockB.Object); + + var capturedTimeouts = new List(); + + mockLockA.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync((TimeSpan timeout, CancellationToken _) => + { + capturedTimeouts.Add(timeout); + Task.Delay(TimeSpan.FromMilliseconds(100)).Wait(); + return mockHandleA.Object; + }); + + mockLockB.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync((TimeSpan timeout, CancellationToken _) => + { + capturedTimeouts.Add(timeout); + return mockHandleB.Object; + }); + + var names = new List { "A", "B" }; + var totalTimeout = TimeSpan.FromSeconds(5); + + var result = await mockProvider.Object.TryAcquireAllLocksAsync(names, totalTimeout, default); + + Assert.That(result, Is.Not.Null); + Assert.That(capturedTimeouts.Count, Is.EqualTo(2)); + + Assert.That(capturedTimeouts[0].TotalMilliseconds, Is.EqualTo(totalTimeout.TotalMilliseconds).Within(1.0)); + + Assert.That(capturedTimeouts[1], Is.LessThan(totalTimeout)); + Assert.That(capturedTimeouts[1], Is.GreaterThanOrEqualTo(TimeSpan.Zero)); + + await result!.DisposeAsync(); + } + + [Test] + public async Task TestCompositeHandleLostToken() + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + var mockLockB = new Mock(); + + var ctsA = new CancellationTokenSource(); + var ctsB = new CancellationTokenSource(); + + var mockHandleA = new Mock(); + var mockHandleB = new Mock(); + + mockHandleA.Setup(h => h.HandleLostToken).Returns(ctsA.Token); + mockHandleB.Setup(h => h.HandleLostToken).Returns(ctsB.Token); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + mockProvider.Setup(p => p.CreateLock("B")).Returns(mockLockB.Object); + + mockLockA.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleB.Object); + + var names = new List { "A", "B" }; + var result = await mockProvider.Object.TryAcquireAllLocksAsync(names, TimeSpan.FromSeconds(10), default); + + Assert.That(result, Is.Not.Null); + Assert.That(result!.HandleLostToken.CanBeCanceled, Is.True); + + ctsA.Cancel(); + + Assert.That(result.HandleLostToken.IsCancellationRequested, Is.True); + + await result.DisposeAsync(); + ctsA.Dispose(); + ctsB.Dispose(); + } } From 8f316e3e3edcb782258dfe42f5b06fce60d11231 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 16:14:11 +0330 Subject: [PATCH 09/12] Add new Public APIs to --- .../PublicAPI.Unshipped.txt | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/DistributedLock.Core/PublicAPI.Unshipped.txt b/src/DistributedLock.Core/PublicAPI.Unshipped.txt index 5f282702..4b44c194 100644 --- a/src/DistributedLock.Core/PublicAPI.Unshipped.txt +++ b/src/DistributedLock.Core/PublicAPI.Unshipped.txt @@ -1 +1,17 @@ - \ No newline at end of file +#nullable enable +static Medallion.Threading.DistributedLockProviderExtensions.AcquireAllLocks(this Medallion.Threading.IDistributedLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle! +static Medallion.Threading.DistributedLockProviderExtensions.AcquireAllLocksAsync(this Medallion.Threading.IDistributedLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedLockProviderExtensions.TryAcquireAllLocks(this Medallion.Threading.IDistributedLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle? +static Medallion.Threading.DistributedLockProviderExtensions.TryAcquireAllLocksAsync(this Medallion.Threading.IDistributedLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.AcquireAllReadLocks(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle! +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.AcquireAllReadLocksAsync(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.AcquireAllWriteLocks(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle! +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.AcquireAllWriteLocksAsync(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.TryAcquireAllReadLocks(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle? +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.TryAcquireAllReadLocksAsync(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.TryAcquireAllWriteLocks(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle? +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.TryAcquireAllWriteLocksAsync(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedSemaphoreProviderExtensions.AcquireAllSemaphores(this Medallion.Threading.IDistributedSemaphoreProvider! provider, System.Collections.Generic.IReadOnlyList! names, int maxCount, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle! +static Medallion.Threading.DistributedSemaphoreProviderExtensions.AcquireAllSemaphoresAsync(this Medallion.Threading.IDistributedSemaphoreProvider! provider, System.Collections.Generic.IReadOnlyList! names, int maxCount, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedSemaphoreProviderExtensions.TryAcquireAllSemaphores(this Medallion.Threading.IDistributedSemaphoreProvider! provider, System.Collections.Generic.IReadOnlyList! names, int maxCount, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle? +static Medallion.Threading.DistributedSemaphoreProviderExtensions.TryAcquireAllSemaphoresAsync(this Medallion.Threading.IDistributedSemaphoreProvider! provider, System.Collections.Generic.IReadOnlyList! names, int maxCount, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask From 246b325e5a13e3c4d58947a10690dd6e96542ff7 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 16:26:47 +0330 Subject: [PATCH 10/12] Comment why composite methods are not supported for `ExcludedInterfacesForCompositeMethods` --- ...stributedUpgradeableReaderWriterLockProviderExtensions.cs | 3 ++- src/DistributedLockCodeGen/GenerateProviders.cs | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs index 2bc1c429..22884c62 100644 --- a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs @@ -41,7 +41,8 @@ public static ValueTask AcquireUpgradeableRea # region Composite Lock Methods -// Composite methods are not supported for IDistributedUpgradeableReaderWriterLock + // Composite methods are not supported for IDistributedUpgradeableReaderWriterLock + // because a composite acquire operation must be able to roll back and upgrade does not support that. # endregion } \ No newline at end of file diff --git a/src/DistributedLockCodeGen/GenerateProviders.cs b/src/DistributedLockCodeGen/GenerateProviders.cs index 80c80747..3cbd5ff4 100644 --- a/src/DistributedLockCodeGen/GenerateProviders.cs +++ b/src/DistributedLockCodeGen/GenerateProviders.cs @@ -70,7 +70,10 @@ public interface {{providerInterfaceName}}{{(interfaceName == "IDistributedUpgra var extensionCompositeMethodBodies = ExcludedInterfacesForCompositeMethods.Contains(interfaceName) ? [ - $"// Composite methods are not supported for {interfaceName}" + $""" + // Composite methods are not supported for {interfaceName} + // because a composite acquire operation must be able to roll back and upgrade does not support that. + """ ] : interfaceMethods .Select(m => From 848f215477647eec0f22515b116c0b199b3999d5 Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 23:28:43 +0330 Subject: [PATCH 11/12] Define reason for each interface in --- src/DistributedLockCodeGen/GenerateProviders.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/DistributedLockCodeGen/GenerateProviders.cs b/src/DistributedLockCodeGen/GenerateProviders.cs index 3cbd5ff4..d00aa8c8 100644 --- a/src/DistributedLockCodeGen/GenerateProviders.cs +++ b/src/DistributedLockCodeGen/GenerateProviders.cs @@ -18,10 +18,10 @@ public class GenerateProviders "IDistributedSemaphore" ]; - private static readonly IReadOnlyList ExcludedInterfacesForCompositeMethods = - [ - "IDistributedUpgradeableReaderWriterLock" - ]; + private static readonly IReadOnlyDictionary ExcludedInterfacesForCompositeMethods = new Dictionary + { + ["IDistributedUpgradeableReaderWriterLock"] = "a composite acquire operation must be able to roll back and upgrade does not support that." + }; [TestCaseSource(nameof(Interfaces))] public void GenerateProviderInterfaceAndExtensions(string interfaceName) @@ -67,12 +67,12 @@ public interface {{providerInterfaceName}}{{(interfaceName == "IDistributedUpgra """ ); - var extensionCompositeMethodBodies = ExcludedInterfacesForCompositeMethods.Contains(interfaceName) + var extensionCompositeMethodBodies = ExcludedInterfacesForCompositeMethods.TryGetValue(interfaceName, out var exclusionReason) ? [ $""" // Composite methods are not supported for {interfaceName} - // because a composite acquire operation must be able to roll back and upgrade does not support that. + // because {exclusionReason} """ ] : interfaceMethods From 7afa7af141e6d9199774f5e693157476bc15f24e Mon Sep 17 00:00:00 2001 From: Moein Nemati Date: Mon, 1 Dec 2025 23:32:45 +0330 Subject: [PATCH 12/12] Remove extra space --- src/DistributedLock.Core/DistributedLockProviderExtensions.cs | 2 +- .../DistributedReaderWriterLockProviderExtensions.cs | 2 +- .../DistributedSemaphoreProviderExtensions.cs | 2 +- .../DistributedUpgradeableReaderWriterLockProviderExtensions.cs | 2 +- src/DistributedLockCodeGen/GenerateProviders.cs | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/DistributedLock.Core/DistributedLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedLockProviderExtensions.cs index bc98c6bb..42dd78f1 100644 --- a/src/DistributedLock.Core/DistributedLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedLockProviderExtensions.cs @@ -38,7 +38,7 @@ public static ValueTask AcquireLockAsync(this (provider ?? throw new ArgumentNullException(nameof(provider))).CreateLock(name).AcquireAsync(timeout, cancellationToken); # endregion - + # region Composite Lock Methods /// diff --git a/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs index da5f407c..5e2dadd2 100644 --- a/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs @@ -66,7 +66,7 @@ public static ValueTask AcquireWriteLockAsync (provider ?? throw new ArgumentNullException(nameof(provider))).CreateReaderWriterLock(name).AcquireWriteLockAsync(timeout, cancellationToken); # endregion - + # region Composite Lock Methods /// diff --git a/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs b/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs index 3ed29ebe..dc961498 100644 --- a/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs @@ -38,7 +38,7 @@ public static ValueTask AcquireSemaphoreAsync (provider ?? throw new ArgumentNullException(nameof(provider))).CreateSemaphore(name, maxCount).AcquireAsync(timeout, cancellationToken); # endregion - + # region Composite Lock Methods /// diff --git a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs index 22884c62..1f647384 100644 --- a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs @@ -38,7 +38,7 @@ public static ValueTask AcquireUpgradeableRea (provider ?? throw new ArgumentNullException(nameof(provider))).CreateUpgradeableReaderWriterLock(name).AcquireUpgradeableReadLockAsync(timeout, cancellationToken); # endregion - + # region Composite Lock Methods // Composite methods are not supported for IDistributedUpgradeableReaderWriterLock diff --git a/src/DistributedLockCodeGen/GenerateProviders.cs b/src/DistributedLockCodeGen/GenerateProviders.cs index d00aa8c8..b55e0aa8 100644 --- a/src/DistributedLockCodeGen/GenerateProviders.cs +++ b/src/DistributedLockCodeGen/GenerateProviders.cs @@ -111,7 +111,7 @@ public static class {{providerExtensionsName}} {{string.Join(Environment.NewLine + Environment.NewLine, extensionSingleMethodBodies)}} # endregion - + # region Composite Lock Methods {{string.Join(Environment.NewLine + Environment.NewLine, extensionCompositeMethodBodies)}}