File: Middleware\ConnectionLimitMiddleware.cs
Web Access
Project: src\src\Servers\Kestrel\Core\src\Microsoft.AspNetCore.Server.Kestrel.Core.csproj (Microsoft.AspNetCore.Server.Kestrel.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
 
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
 
/// <summary>
/// Connection middleware that limits how many connections are processed at once by taking one slot
/// from a <see cref="ResourceCounter"/> per connection. A connection that cannot obtain a slot is
/// reported as rejected (event source, trace and metrics) and disposed without invoking the next delegate.
/// </summary>
/// <typeparam name="T">The connection context type flowing through the connection pipeline.</typeparam>
internal sealed class ConnectionLimitMiddleware<T> where T : BaseConnectionContext
{
    private readonly Func<T, Task> _next;
    private readonly ResourceCounter _connectionCounter;
    private readonly KestrelTrace _trace;
    private readonly KestrelMetrics _metrics;

    /// <summary>
    /// Creates the middleware with a counter produced by <c>ResourceCounter.Quota(connectionLimit)</c>.
    /// </summary>
    /// <param name="next">The next connection delegate, invoked only for admitted connections.</param>
    /// <param name="connectionLimit">The value used to build the quota counter.</param>
    /// <param name="trace">Receives a rejection entry for each refused connection.</param>
    /// <param name="metrics">Records a rejection for each refused connection.</param>
    public ConnectionLimitMiddleware(Func<T, Task> next, long connectionLimit, KestrelTrace trace, KestrelMetrics metrics)
        : this(next, ResourceCounter.Quota(connectionLimit), trace, metrics)
    {
    }

    // Test hook: lets callers inject an arbitrary counter instead of a quota built from a limit.
    internal ConnectionLimitMiddleware(Func<T, Task> next, ResourceCounter concurrentConnectionCounter, KestrelTrace trace, KestrelMetrics metrics)
    {
        _next = next;
        _connectionCounter = concurrentConnectionCounter;
        _trace = trace;
        _metrics = metrics;
    }

    /// <summary>
    /// Admits <paramref name="connection"/> if a slot is available, otherwise rejects and disposes it.
    /// </summary>
    /// <param name="connection">The incoming connection.</param>
    public async Task OnConnectionAsync(T connection)
    {
        if (_connectionCounter.TryLockOne())
        {
            await ProcessAdmittedConnectionAsync(connection);
            return;
        }

        await RejectConnectionAsync(connection);
    }

    // Runs the rest of the pipeline while holding a slot. The slot is always returned afterwards,
    // even if publishing the feature or the next delegate throws.
    private async Task ProcessAdmittedConnectionAsync(T connection)
    {
        var slot = new ConnectionReleasor(_connectionCounter);

        try
        {
            // Published on the connection so downstream code may hand the slot back early.
            connection.Features.Set<IDecrementConcurrentConnectionCountFeature>(slot);
            await _next(connection);
        }
        finally
        {
            slot.ReleaseConnection();
        }
    }

    // Reports the refusal to every sink, then disposes the connection.
    private async Task RejectConnectionAsync(T connection)
    {
        var connectionId = connection.ConnectionId;
        KestrelEventSource.Log.ConnectionRejected(connectionId);
        _trace.ConnectionRejected(connectionId);

        var metricsContext = connection.Features.GetRequiredFeature<IConnectionMetricsContextFeature>().MetricsContext;
        _metrics.ConnectionRejected(metricsContext);

        await connection.DisposeAsync();
    }

    /// <summary>
    /// Returns one connection's slot to the counter, at most once no matter how many times
    /// <see cref="ReleaseConnection"/> is invoked.
    /// </summary>
    private sealed class ConnectionReleasor : IDecrementConcurrentConnectionCountFeature
    {
        private readonly ResourceCounter _counter;

        // Plain (unsynchronized) flag that prevents a second ReleaseOne call.
        private bool _released;

        public ConnectionReleasor(ResourceCounter normalConnectionCounter) => _counter = normalConnectionCounter;

        public void ReleaseConnection()
        {
            if (_released)
            {
                return;
            }

            _released = true;
            _counter.ReleaseOne();
        }
    }
}