File: TestUtils.cs
Web Access
Project: src\src\Middleware\ConcurrencyLimiter\test\Microsoft.AspNetCore.ConcurrencyLimiter.Tests.csproj (Microsoft.AspNetCore.ConcurrencyLimiter.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
 
namespace Microsoft.AspNetCore.ConcurrencyLimiter.Tests;
 
/// <summary>
/// Factory helpers that build <see cref="ConcurrencyLimiterMiddleware"/> instances and queue policies for tests.
/// </summary>
public static class TestUtils
{
#pragma warning disable CS0618 // Type or member is obsolete
    /// <summary>
    /// Builds a middleware instance wired to <see cref="NullLoggerFactory.Instance"/>.
    /// </summary>
    /// <param name="queue">Queue policy to use; when null, a <see cref="QueuePolicy"/> created with (1, 0) is used.</param>
    /// <param name="onRejected">Rejection handler; when null, a delegate returning a completed task is used.</param>
    /// <param name="next">Next delegate in the pipeline; when null, a delegate returning a completed task is used.</param>
    /// <returns>The constructed middleware.</returns>
    public static ConcurrencyLimiterMiddleware CreateTestMiddleware(IQueuePolicy queue = null, RequestDelegate onRejected = null, RequestDelegate next = null)
    {
        var limiterOptions = new ConcurrencyLimiterOptions
        {
            OnRejected = onRejected ?? (_ => Task.CompletedTask),
        };
        var wrappedOptions = Options.Create(limiterOptions);

        RequestDelegate pipeline = next ?? (_ => Task.CompletedTask);
        IQueuePolicy policy = queue ?? CreateQueuePolicy(1, 0);

        return new ConcurrencyLimiterMiddleware(
            next: pipeline,
            loggerFactory: NullLoggerFactory.Instance,
            queue: policy,
            options: wrappedOptions);
    }

    /// <summary>
    /// Builds a middleware instance backed by a <see cref="QueuePolicy"/> with the given limits.
    /// </summary>
    public static ConcurrencyLimiterMiddleware CreateTestMiddleware_QueuePolicy(int maxConcurrentRequests, int requestQueueLimit, RequestDelegate onRejected = null, RequestDelegate next = null)
        => CreateTestMiddleware(CreateQueuePolicy(maxConcurrentRequests, requestQueueLimit), onRejected, next);

    /// <summary>
    /// Builds a middleware instance backed by a <see cref="StackPolicy"/> with the given limits.
    /// </summary>
    public static ConcurrencyLimiterMiddleware CreateTestMiddleware_StackPolicy(int maxConcurrentRequests, int requestQueueLimit, RequestDelegate onRejected = null, RequestDelegate next = null)
        => CreateTestMiddleware(CreateStackPolicy(maxConcurrentRequests, requestQueueLimit), onRejected, next);

    /// <summary>
    /// Creates a <see cref="StackPolicy"/> configured with the given concurrency and queue limits.
    /// </summary>
    internal static StackPolicy CreateStackPolicy(int maxConcurrentRequests, int requestsQueuelimit = 100)
        => new StackPolicy(Options.Create(BuildPolicyOptions(maxConcurrentRequests, requestsQueuelimit)));

    /// <summary>
    /// Creates a <see cref="QueuePolicy"/> configured with the given concurrency and queue limits.
    /// </summary>
    internal static QueuePolicy CreateQueuePolicy(int maxConcurrentRequests, int requestQueueLimit = 100)
        => new QueuePolicy(Options.Create(BuildPolicyOptions(maxConcurrentRequests, requestQueueLimit)));

    // Shared construction of the options object consumed by both policy factories.
    private static QueuePolicyOptions BuildPolicyOptions(int concurrentLimit, int queueLimit)
    {
        return new QueuePolicyOptions
        {
            MaxConcurrentRequests = concurrentLimit,
            RequestQueueLimit = queueLimit
        };
    }
#pragma warning restore CS0618 // Type or member is obsolete
}
 
/// <summary>
/// An <see cref="IQueuePolicy"/> test double whose enter and exit behavior is supplied by delegates.
/// It also tracks how many calls are currently inside <see cref="TryEnterAsync"/>.
/// </summary>
internal class TestQueue : IQueuePolicy
{
    private readonly Func<TestQueue, Task<bool>> _onTryEnter;
    private readonly Action _onExit;

    private int _queuedRequests;

    /// <summary>
    /// Number of <see cref="TryEnterAsync"/> calls that have started but not yet finished.
    /// </summary>
    // The counter is written with Interlocked from whichever thread runs TryEnterAsync,
    // so read it with matching visibility semantics.
    public int QueuedRequests { get => Volatile.Read(ref _queuedRequests); }

    /// <summary>
    /// Creates a queue whose enter decision is produced asynchronously.
    /// </summary>
    /// <param name="onTryEnter">Invoked with this instance on every <see cref="TryEnterAsync"/> call; its result is returned to the caller.</param>
    /// <param name="onExit">Invoked on every <see cref="OnExit"/> call; defaults to a no-op when null.</param>
    public TestQueue(Func<TestQueue, Task<bool>> onTryEnter, Action onExit = null)
    {
        _onTryEnter = onTryEnter;
        _onExit = onExit ?? (() => { });
    }

    /// <summary>
    /// Creates a queue whose enter decision is produced synchronously.
    /// </summary>
    /// <param name="onTryEnter">Invoked with this instance on every <see cref="TryEnterAsync"/> call; its result is wrapped in a completed task.</param>
    /// <param name="onExit">Invoked on every <see cref="OnExit"/> call; defaults to a no-op when null.</param>
    public TestQueue(Func<TestQueue, bool> onTryEnter, Action onExit = null) :
        this(state => Task.FromResult(onTryEnter(state))
        , onExit)
    { }

    /// <summary>
    /// Runs the configured enter callback while counting this call in <see cref="QueuedRequests"/>.
    /// </summary>
    /// <returns>The value produced by the enter callback.</returns>
    public async ValueTask<bool> TryEnterAsync()
    {
        Interlocked.Increment(ref _queuedRequests);
        try
        {
            return await _onTryEnter(this);
        }
        finally
        {
            // Always undo the increment, even when the callback throws or its task faults;
            // otherwise QueuedRequests would stay inflated for the rest of the test.
            Interlocked.Decrement(ref _queuedRequests);
        }
    }

    /// <summary>
    /// Invokes the configured exit callback.
    /// </summary>
    public void OnExit()
    {
        _onExit();
    }

    /// <summary>A queue whose <see cref="TryEnterAsync"/> always yields false.</summary>
    public static TestQueue AlwaysFalse =
        new TestQueue((_) => false);

    /// <summary>A queue whose <see cref="TryEnterAsync"/> always yields true.</summary>
    public static TestQueue AlwaysTrue =
        new TestQueue((_) => true);

    /// <summary>
    /// A queue whose <see cref="TryEnterAsync"/> awaits a semaphore that starts at zero and is
    /// never released (no reference to it is kept), so the call does not complete.
    /// </summary>
    public static TestQueue AlwaysBlock =
        new TestQueue(async (_) =>
        {
            await new SemaphoreSlim(0).WaitAsync();
            return false;
        });
}