File: RateLimitingMiddleware.cs
Web Access
Project: src\src\Middleware\RateLimiting\src\Microsoft.AspNetCore.RateLimiting.csproj (Microsoft.AspNetCore.RateLimiting)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Threading.RateLimiting;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
 
namespace Microsoft.AspNetCore.RateLimiting;
 
/// <summary>
/// Limits the rate of requests allowed in the application, based on limits set by a user-provided <see cref="PartitionedRateLimiter{TResource}"/>.
/// </summary>
internal sealed partial class RateLimitingMiddleware
{
    private readonly RequestDelegate _next;
    private readonly Func<OnRejectedContext, CancellationToken, ValueTask>? _defaultOnRejected;
    private readonly ILogger _logger;
    private readonly RateLimitingMetrics _metrics;
    private readonly PartitionedRateLimiter<HttpContext>? _globalLimiter;
    private readonly PartitionedRateLimiter<HttpContext> _endpointLimiter;
    private readonly int _rejectionStatusCode;
    private readonly Dictionary<string, DefaultRateLimiterPolicy> _policyMap;
    private readonly DefaultKeyType _defaultPolicyKey = new DefaultKeyType("__defaultPolicy", new PolicyNameKey { PolicyName = "__defaultPolicyKey" });
 
    /// <summary>
    /// Creates a new <see cref="RateLimitingMiddleware"/>.
    /// </summary>
    /// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
    /// <param name="logger">The <see cref="ILogger"/> used for logging.</param>
    /// <param name="options">The options for the middleware.</param>
    /// <param name="serviceProvider">The service provider.</param>
    /// <param name="metrics">The rate limiting metrics.</param>
    public RateLimitingMiddleware(RequestDelegate next, ILogger<RateLimitingMiddleware> logger, IOptions<RateLimiterOptions> options, IServiceProvider serviceProvider, RateLimitingMetrics metrics)
    {
        ArgumentNullException.ThrowIfNull(next);
        ArgumentNullException.ThrowIfNull(logger);
        ArgumentNullException.ThrowIfNull(serviceProvider);
        ArgumentNullException.ThrowIfNull(metrics);
 
        _next = next;
        _logger = logger;
        _metrics = metrics;
        _defaultOnRejected = options.Value.OnRejected;
        _rejectionStatusCode = options.Value.RejectionStatusCode;
        _policyMap = new Dictionary<string, DefaultRateLimiterPolicy>(options.Value.PolicyMap);
 
        // Activate policies passed to AddPolicy<TPartitionKey, TPolicy>
        foreach (var unactivatedPolicy in options.Value.UnactivatedPolicyMap)
        {
            _policyMap.Add(unactivatedPolicy.Key, unactivatedPolicy.Value(serviceProvider));
        }
 
        _globalLimiter = options.Value.GlobalLimiter;
        _endpointLimiter = CreateEndpointLimiter();
    }
 
    // TODO - EventSource?
    /// <summary>
    /// Invokes the logic of the middleware.
    /// </summary>
    /// <param name="context">The <see cref="HttpContext"/>.</param>
    /// <returns>A <see cref="Task"/> that completes when the request leaves.</returns>
    public Task Invoke(HttpContext context)
    {
        var endpoint = context.GetEndpoint();
        // If this endpoint has a DisableRateLimitingAttribute, don't apply any rate limits.
        if (endpoint?.Metadata.GetMetadata<DisableRateLimitingAttribute>() is not null)
        {
            return _next(context);
        }
        var enableRateLimitingAttribute = endpoint?.Metadata.GetMetadata<EnableRateLimitingAttribute>();
        // If this endpoint has no EnableRateLimitingAttribute & there's no global limiter, don't apply any rate limits.
        if (enableRateLimitingAttribute is null && _globalLimiter is null)
        {
            return _next(context);
        }
 
        return InvokeInternal(context, enableRateLimitingAttribute);
    }
 
    private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribute? enableRateLimitingAttribute)
    {
        var policyName = enableRateLimitingAttribute?.PolicyName;
 
        // Cache the up/down counter enabled state at the start of the middleware.
        // This ensures that the state is consistent for the entire request.
        // For example, if a meter listener starts after a request is queued, when the request exits the queue
        // the requests queued counter won't go into a negative value.
        var metricsContext = _metrics.CreateContext(policyName);
 
        using var leaseContext = await TryAcquireAsync(context, metricsContext);
 
        if (leaseContext.Lease?.IsAcquired == true)
        {
            var startTimestamp = Stopwatch.GetTimestamp();
            var currentLeaseStart = metricsContext.CurrentLeasedRequestsCounterEnabled;
            try
            {
 
                _metrics.LeaseStart(metricsContext);
                await _next(context);
            }
            finally
            {
                _metrics.LeaseEnd(metricsContext, startTimestamp, Stopwatch.GetTimestamp());
            }
        }
        else
        {
            _metrics.LeaseFailed(metricsContext, leaseContext.RequestRejectionReason!.Value);
 
            // If the request was canceled, do not call OnRejected, just return.
            if (leaseContext.RequestRejectionReason == RequestRejectionReason.RequestCanceled)
            {
                return;
            }
            var thisRequestOnRejected = _defaultOnRejected;
            RateLimiterLog.RequestRejectedLimitsExceeded(_logger);
            // OnRejected "wins" over DefaultRejectionStatusCode - we set DefaultRejectionStatusCode first,
            // then call OnRejected in case it wants to do any further modification of the status code.
            context.Response.StatusCode = _rejectionStatusCode;
 
            // If this request was rejected by the endpoint limiter, use its OnRejected if available.
            if (leaseContext.RequestRejectionReason == RequestRejectionReason.EndpointLimiter)
            {
                DefaultRateLimiterPolicy? policy;
                // Use custom policy OnRejected if available, else use OnRejected from the Options if available.
                policy = enableRateLimitingAttribute?.Policy;
                if (policy is not null)
                {
                    thisRequestOnRejected = policy.OnRejected;
                }
                else
                {
                    if (policyName is not null && _policyMap.TryGetValue(policyName, out policy) && policy.OnRejected is not null)
                    {
                        thisRequestOnRejected = policy.OnRejected;
                    }
                }
            }
            if (thisRequestOnRejected is not null)
            {
                // leaseContext.Lease will only be null when the request was canceled.
                await thisRequestOnRejected(new OnRejectedContext() { HttpContext = context, Lease = leaseContext.Lease! }, context.RequestAborted);
            }
        }
    }
 
    private async ValueTask<LeaseContext> TryAcquireAsync(HttpContext context, MetricsContext metricsContext)
    {
        var leaseContext = CombinedAcquire(context);
        if (leaseContext.Lease?.IsAcquired == true)
        {
            return leaseContext;
        }
 
        var waitTask = CombinedWaitAsync(context, context.RequestAborted);
        // If the task returns immediately then the request wasn't queued.
        if (waitTask.IsCompleted)
        {
            return await waitTask;
        }
 
        var startTimestamp = Stopwatch.GetTimestamp();
        try
        {
            _metrics.QueueStart(metricsContext);
            leaseContext = await waitTask;
            return leaseContext;
        }
        finally
        {
            _metrics.QueueEnd(metricsContext, leaseContext.RequestRejectionReason, startTimestamp, Stopwatch.GetTimestamp());
        }
    }
 
    private LeaseContext CombinedAcquire(HttpContext context)
    {
        RateLimitLease? globalLease = null;
        RateLimitLease? endpointLease = null;
 
        try
        {
            if (_globalLimiter is not null)
            {
                globalLease = _globalLimiter.AttemptAcquire(context);
                if (!globalLease.IsAcquired)
                {
                    return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.GlobalLimiter, Lease = globalLease };
                }
            }
            endpointLease = _endpointLimiter.AttemptAcquire(context);
            if (!endpointLease.IsAcquired)
            {
                globalLease?.Dispose();
                return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.EndpointLimiter, Lease = endpointLease };
            }
        }
        catch (Exception)
        {
            endpointLease?.Dispose();
            globalLease?.Dispose();
            throw;
        }
        return globalLease is null ? new LeaseContext() { Lease = endpointLease } : new LeaseContext() { Lease = new DefaultCombinedLease(globalLease, endpointLease) };
    }
 
    private async ValueTask<LeaseContext> CombinedWaitAsync(HttpContext context, CancellationToken cancellationToken)
    {
        RateLimitLease? globalLease = null;
        RateLimitLease? endpointLease = null;
 
        try
        {
            if (_globalLimiter is not null)
            {
                globalLease = await _globalLimiter.AcquireAsync(context, cancellationToken: cancellationToken);
                if (!globalLease.IsAcquired)
                {
                    return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.GlobalLimiter, Lease = globalLease };
                }
            }
            endpointLease = await _endpointLimiter.AcquireAsync(context, cancellationToken: cancellationToken);
            if (!endpointLease.IsAcquired)
            {
                globalLease?.Dispose();
                return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.EndpointLimiter, Lease = endpointLease };
            }
        }
        catch (Exception ex)
        {
            endpointLease?.Dispose();
            globalLease?.Dispose();
            // Don't throw if the request was canceled - instead log. 
            if (ex is OperationCanceledException && context.RequestAborted.IsCancellationRequested)
            {
                RateLimiterLog.RequestCanceled(_logger);
                return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.RequestCanceled };
            }
            else
            {
                throw;
            }
        }
 
        return globalLease is null ? new LeaseContext() { Lease = endpointLease } : new LeaseContext() { Lease = new DefaultCombinedLease(globalLease, endpointLease) };
    }
 
    // Create the endpoint-specific PartitionedRateLimiter
    private PartitionedRateLimiter<HttpContext> CreateEndpointLimiter()
    {
        // If we have a policy for this endpoint, use its partitioner. Else use a NoLimiter.
        return PartitionedRateLimiter.Create<HttpContext, DefaultKeyType>(context =>
        {
            DefaultRateLimiterPolicy? policy;
            var enableRateLimitingAttribute = context.GetEndpoint()?.Metadata.GetMetadata<EnableRateLimitingAttribute>();
            if (enableRateLimitingAttribute is null)
            {
                return RateLimitPartition.GetNoLimiter<DefaultKeyType>(_defaultPolicyKey);
            }
            policy = enableRateLimitingAttribute.Policy;
            if (policy is not null)
            {
                return policy.GetPartition(context);
            }
            var name = enableRateLimitingAttribute.PolicyName;
            if (name is not null)
            {
                if (_policyMap.TryGetValue(name, out policy))
                {
                    return policy.GetPartition(context);
                }
                else
                {
                    throw new InvalidOperationException($"This endpoint requires a rate limiting policy with name {name}, but no such policy exists.");
                }
            }
            // Should be impossible for both name & policy to be null, but throw in that scenario just in case.
            else
            {
                throw new InvalidOperationException("This endpoint requested a rate limiting policy with a null name.");
            }
        }, new DefaultKeyTypeEqualityComparer());
    }
 
    private static partial class RateLimiterLog
    {
        [LoggerMessage(1, LogLevel.Debug, "Rate limits exceeded, rejecting this request.", EventName = "RequestRejectedLimitsExceeded")]
        internal static partial void RequestRejectedLimitsExceeded(ILogger logger);
 
        [LoggerMessage(2, LogLevel.Debug, "This endpoint requires a rate limiting policy with name {PolicyName}, but no such policy exists.", EventName = "WarnMissingPolicy")]
        internal static partial void WarnMissingPolicy(ILogger logger, string policyName);
 
        [LoggerMessage(3, LogLevel.Debug, "The request was canceled.", EventName = "RequestCanceled")]
        internal static partial void RequestCanceled(ILogger logger);
    }
}