File: MiddlewareTests.cs
Web Access
Project: src\src\Middleware\ConcurrencyLimiter\test\Microsoft.AspNetCore.ConcurrencyLimiter.Tests.csproj (Microsoft.AspNetCore.ConcurrencyLimiter.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Threading.Tasks.Sources;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.InternalTesting;
 
namespace Microsoft.AspNetCore.ConcurrencyLimiter.Tests;
 
/// <summary>
/// Tests for the concurrency-limiter middleware produced by <c>TestUtils.CreateTestMiddleware</c>,
/// driven by test <see cref="IQueuePolicy"/> implementations.
/// </summary>
public class MiddlewareTests
{
    /// <summary>
    /// A queue policy that admits the request causes the middleware to invoke <c>next</c>.
    /// </summary>
    [Fact]
    public async Task RequestsCallNextIfQueueReturnsTrue()
    {
        var flag = false;

        var middleware = TestUtils.CreateTestMiddleware(
            queue: TestQueue.AlwaysTrue,
            next: httpContext =>
            {
                flag = true;
                return Task.CompletedTask;
            });

        // Bounded like the other awaiting tests so a hang fails this test instead of stalling the run.
        await middleware.Invoke(new DefaultHttpContext()).DefaultTimeout();
        Assert.True(flag);
    }

    /// <summary>
    /// A queue policy that rejects the request causes <c>onRejected</c> to run and the response
    /// status code to be 503 Service Unavailable.
    /// </summary>
    [Fact]
    public async Task RequestRejectsIfQueueReturnsFalse()
    {
        bool onRejectedInvoked = false;

        var middleware = TestUtils.CreateTestMiddleware(
            queue: TestQueue.AlwaysFalse,
            onRejected: httpContext =>
            {
                onRejectedInvoked = true;
                return Task.CompletedTask;
            });

        var context = new DefaultHttpContext();
        await middleware.Invoke(context).DefaultTimeout();
        Assert.True(onRejectedInvoked);
        Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode);
    }

    /// <summary>
    /// A rejected request never reaches <c>next</c>: if it did, the thrown
    /// <see cref="DivideByZeroException"/> would fault the awaited task and fail the test.
    /// </summary>
    [Fact]
    public async Task RequestsDoesNotEnterIfQueueFull()
    {
        var middleware = TestUtils.CreateTestMiddleware(
            queue: TestQueue.AlwaysFalse,
            next: httpContext =>
            {
                // throttle should bounce the request; it should never get here
                throw new DivideByZeroException();
            });

        await middleware.Invoke(new DefaultHttpContext()).DefaultTimeout();
    }

    /// <summary>
    /// Each request held by a blocking queue policy increments <c>QueuedRequests</c> and its
    /// task stays incomplete. The blocked tasks are intentionally never released.
    /// </summary>
    [Fact]
    public void IncomingRequestsFillUpQueue()
    {
        var testQueue = TestQueue.AlwaysBlock;
        var middleware = TestUtils.CreateTestMiddleware(testQueue);

        Assert.Equal(0, testQueue.QueuedRequests);

        var task1 = middleware.Invoke(new DefaultHttpContext());
        Assert.Equal(1, testQueue.QueuedRequests);
        Assert.False(task1.IsCompleted);

        var task2 = middleware.Invoke(new DefaultHttpContext());
        Assert.Equal(2, testQueue.QueuedRequests);
        Assert.False(task2.IsCompleted);
    }

    /// <summary>
    /// <c>QueuedRequests</c> rises while a request waits on the queue policy and drops back to zero
    /// once the policy completes. Despite the name, this asserts on <c>TestQueue.QueuedRequests</c>;
    /// no event-counter listener is attached here.
    /// </summary>
    [Fact]
    public void EventCountersTrackQueuedRequests()
    {
        // Deliberately created WITHOUT TaskCreationOptions.RunContinuationsAsynchronously: the
        // assertion after SetResult has no await in between, so it relies on the waiting request's
        // continuation running inline inside SetResult.
        var blocker = new TaskCompletionSource<bool>();

        var testQueue = new TestQueue(
            onTryEnter: async (_) =>
            {
                return await blocker.Task;
            });
        var middleware = TestUtils.CreateTestMiddleware(testQueue);

        Assert.Equal(0, testQueue.QueuedRequests);

        var task1 = middleware.Invoke(new DefaultHttpContext());
        Assert.False(task1.IsCompleted);
        Assert.Equal(1, testQueue.QueuedRequests);

        blocker.SetResult(true);

        Assert.Equal(0, testQueue.QueuedRequests);
    }

    /// <summary>
    /// The queue policy's <c>OnExit</c> still runs when <c>next</c> throws, and the exception
    /// propagates out of the middleware.
    /// </summary>
    [Fact]
    public async Task QueueOnExitCalledEvenIfNextErrors()
    {
        var flag = false;

        var testQueue = new TestQueue(
                onTryEnter: (_) => true,
                onExit: () => { flag = true; });

        var middleware = TestUtils.CreateTestMiddleware(
            queue: testQueue,
            next: httpContext =>
            {
                throw new DivideByZeroException();
            });

        Assert.Equal(0, testQueue.QueuedRequests);
        await Assert.ThrowsAsync<DivideByZeroException>(() => middleware.Invoke(new DefaultHttpContext())).DefaultTimeout();

        Assert.Equal(0, testQueue.QueuedRequests);
        Assert.True(flag);
    }

    /// <summary>
    /// An exception thrown from <c>onRejected</c> propagates to the caller, the response is
    /// still 503, and the queue's concurrency bookkeeping is left intact for later requests.
    /// </summary>
    [Fact]
    public async Task ExceptionThrownDuringOnRejected()
    {
        // No RunContinuationsAsynchronously on purpose: after SetResult below, the test checks
        // firstRequest.IsCompletedSuccessfully without awaiting, so the continuation must run inline.
        TaskCompletionSource tcs = new TaskCompletionSource();

        // Admits at most one request at a time; OnExit releases the slot.
        var concurrent = 0;
        var testQueue = new TestQueue(
            onTryEnter: (_) =>
            {
                if (concurrent > 0)
                {
                    return false;
                }
                else
                {
                    concurrent++;
                    return true;
                }
            },
            onExit: () => { concurrent--; });

        var middleware = TestUtils.CreateTestMiddleware(
            queue: testQueue,
            onRejected: httpContext =>
            {
                throw new DivideByZeroException();
            },
            next: httpContext =>
            {
                return tcs.Task;
            });

        // the first request enters the server, and is blocked by the tcs
        var firstRequest = middleware.Invoke(new DefaultHttpContext());
        Assert.Equal(1, concurrent);
        Assert.Equal(0, testQueue.QueuedRequests);

        // the second request is rejected with a 503 error. During the rejection, an error occurs
        var context = new DefaultHttpContext();
        await Assert.ThrowsAsync<DivideByZeroException>(() => middleware.Invoke(context)).DefaultTimeout();
        Assert.Equal(StatusCodes.Status503ServiceUnavailable, context.Response.StatusCode);
        Assert.Equal(1, concurrent);
        Assert.Equal(0, testQueue.QueuedRequests);

        // the first request is unblocked, and the queue continues functioning as expected
        tcs.SetResult();
        Assert.True(firstRequest.IsCompletedSuccessfully);
        Assert.Equal(0, concurrent);
        Assert.Equal(0, testQueue.QueuedRequests);

        // tcs is already completed, so the third request enters and finishes synchronously
        var thirdRequest = middleware.Invoke(new DefaultHttpContext());
        Assert.True(thirdRequest.IsCompletedSuccessfully);
        Assert.Equal(0, concurrent);
        Assert.Equal(0, testQueue.QueuedRequests);
    }

    /// <summary>
    /// The middleware consumes the <see cref="ValueTask{TResult}"/> returned by
    /// <c>TryEnterAsync</c> exactly once; <see cref="TestValueResult.GetResult"/> asserts
    /// on a second call.
    /// </summary>
    [Fact]
    public async Task MiddlewareOnlyCallsGetResultOnce()
    {
        var flag = false;

        var queue = new TestQueueForValueTask();
        var middleware = TestUtils.CreateTestMiddleware(
            queue,
            next: async context =>
            {
                await Task.CompletedTask;
                flag = true;
            });

        // Bounded like the other awaiting tests so a hang fails this test instead of stalling the run.
        await middleware.Invoke(new DefaultHttpContext()).DefaultTimeout();

        Assert.True(flag);
    }

    /// <summary>
    /// Queue policy whose <c>TryEnterAsync</c> returns a <see cref="ValueTask{TResult}"/> backed by
    /// a single shared <see cref="TestValueResult"/>. The source is used to detect repeated
    /// <c>GetResult</c> calls.
    /// </summary>
    private class TestQueueForValueTask : IQueuePolicy
    {
        // Assigned once in the constructor and never replaced.
        public readonly TestValueResult Source;

        public TestQueueForValueTask()
        {
            Source = new TestValueResult();
        }

        public ValueTask<bool> TryEnterAsync()
        {
            // Token 0; TestValueResult ignores the token.
            return new ValueTask<bool>(Source, 0);
        }

        public void OnExit() { }
    }

    /// <summary>
    /// <see cref="IValueTaskSource{TResult}"/> that is always already succeeded with <c>true</c> and
    /// fails the test if <see cref="GetResult"/> is called more than once.
    /// </summary>
    private class TestValueResult : IValueTaskSource<bool>
    {
        private bool _getResultCalled;

        public bool GetResult(short token)
        {
            Assert.False(_getResultCalled);
            _getResultCalled = true;
            return true;
        }

        public ValueTaskSourceStatus GetStatus(short token)
        {
            return ValueTaskSourceStatus.Succeeded;
        }

        // GetStatus never reports Pending, so an awaiter should never need to register a
        // continuation; throwing makes any unexpected registration visible.
        public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
        {
            throw new NotImplementedException();
        }
    }
}