|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.ObjectPool;
namespace System.Runtime
{
internal class AsyncLock : IAsyncDisposable
{
private static readonly ObjectPool<SemaphoreSlim> s_semaphorePool = new DefaultObjectPool<SemaphoreSlim>(new SemaphoreSlimPooledObjectPolicy(), 100);
private AsyncLocal<SemaphoreSlim> _currentSemaphore;
private SemaphoreSlim _topLevelSemaphore;
private bool _isDisposed;
public AsyncLock()
{
_topLevelSemaphore = s_semaphorePool.Get();
_currentSemaphore = new AsyncLocal<SemaphoreSlim>();
}
public Task<IAsyncDisposable> TakeLockAsync()
{
if (_isDisposed)
throw new ObjectDisposedException(nameof(AsyncLock));
_currentSemaphore.Value ??= _topLevelSemaphore;
SemaphoreSlim currentSem = _currentSemaphore.Value;
var nextSem = s_semaphorePool.Get();
_currentSemaphore.Value = nextSem;
var safeRelease = new SafeSemaphoreRelease(currentSem, nextSem, this);
return TakeLockCoreAsync(currentSem, safeRelease);
}
private async Task<IAsyncDisposable> TakeLockCoreAsync(SemaphoreSlim currentSemaphore, SafeSemaphoreRelease safeSemaphoreRelease)
{
await currentSemaphore.WaitAsync();
return safeSemaphoreRelease;
}
public IDisposable TakeLock()
{
if (_isDisposed)
throw new ObjectDisposedException(nameof(AsyncLock));
_currentSemaphore.Value ??= _topLevelSemaphore;
SemaphoreSlim currentSem = _currentSemaphore.Value;
currentSem.Wait();
var nextSem = s_semaphorePool.Get();
_currentSemaphore.Value = nextSem;
return new SafeSemaphoreRelease(currentSem, nextSem, this);
}
public async ValueTask DisposeAsync()
{
if (_isDisposed)
return;
_isDisposed = true;
// Ensure the lock isn't held. If it is, wait for it to be released
// before completing the dispose.
await _topLevelSemaphore.WaitAsync();
_topLevelSemaphore.Release();
s_semaphorePool.Return(_topLevelSemaphore);
_topLevelSemaphore = null;
}
private struct SafeSemaphoreRelease : IAsyncDisposable, IDisposable
{
private SemaphoreSlim _currentSemaphore;
private SemaphoreSlim _nextSemaphore;
private AsyncLock _asyncLock;
public SafeSemaphoreRelease(SemaphoreSlim currentSemaphore, SemaphoreSlim nextSemaphore, AsyncLock asyncLock)
{
_currentSemaphore = currentSemaphore;
_nextSemaphore = nextSemaphore;
_asyncLock = asyncLock;
}
public ValueTask DisposeAsync()
{
Fx.Assert(_nextSemaphore == _asyncLock._currentSemaphore.Value, "_nextSemaphore was expected to by the current semaphore");
// Update _asyncLock._currentSemaphore in the calling ExecutionContext
// and defer any awaits to DisposeCoreAsync(). If this isn't done, the
// update will happen in a copy of the ExecutionContext and the caller
// won't see the changes.
if (_currentSemaphore == _asyncLock._topLevelSemaphore)
{
_asyncLock._currentSemaphore.Value = null;
}
else
{
_asyncLock._currentSemaphore.Value = _currentSemaphore;
}
return DisposeCoreAsync();
}
private async ValueTask DisposeCoreAsync()
{
await _nextSemaphore.WaitAsync();
_currentSemaphore.Release();
_nextSemaphore.Release();
s_semaphorePool.Return(_nextSemaphore);
}
public void Dispose()
{
Fx.Assert(_nextSemaphore == _asyncLock._currentSemaphore.Value, "_nextSemaphore was expected to by the current semaphore");
if (_currentSemaphore == _asyncLock._topLevelSemaphore)
{
_asyncLock._currentSemaphore.Value = null;
}
else
{
_asyncLock._currentSemaphore.Value = _currentSemaphore;
}
_nextSemaphore.Wait();
_currentSemaphore.Release();
_nextSemaphore.Release();
s_semaphorePool.Return(_nextSemaphore);
}
}
private class SemaphoreSlimPooledObjectPolicy : PooledObjectPolicy<SemaphoreSlim>
{
public override SemaphoreSlim Create()
{
return new SemaphoreSlim(1);
}
public override bool Return(SemaphoreSlim obj)
{
if (obj.CurrentCount != 1)
{
Fx.Assert("Shouldn't be returning semaphore with a count != 1");
return false;
}
return true;
}
}
}
}
|