File: RateLimiterOptions.cs
Web Access
Project: src\src\Middleware\RateLimiting\src\Microsoft.AspNetCore.RateLimiting.csproj (Microsoft.AspNetCore.RateLimiting)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics.CodeAnalysis;
using System.Threading.RateLimiting;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
 
namespace Microsoft.AspNetCore.RateLimiting;
 
/// <summary>
/// Specifies options for the rate limiting middleware.
/// </summary>
public sealed class RateLimiterOptions
{
    internal Dictionary<string, DefaultRateLimiterPolicy> PolicyMap { get; }
        = new Dictionary<string, DefaultRateLimiterPolicy>(StringComparer.Ordinal);
 
    internal Dictionary<string, Func<IServiceProvider, DefaultRateLimiterPolicy>> UnactivatedPolicyMap { get; }
        = new Dictionary<string, Func<IServiceProvider, DefaultRateLimiterPolicy>>(StringComparer.Ordinal);
 
    /// <summary>
    /// Gets or sets the global <see cref="PartitionedRateLimiter{HttpContext}"/> that will be applied on all requests.
    /// The global limiter will be executed first, followed by the endpoint-specific limiter, if one exists.
    /// </summary>
    public PartitionedRateLimiter<HttpContext>? GlobalLimiter { get; set; }
 
    /// <summary>
    /// Gets or sets a <see cref="Func{OnRejectedContext, CancellationToken, ValueTask}"/> that handles requests rejected by this middleware.
    /// </summary>
    public Func<OnRejectedContext, CancellationToken, ValueTask>? OnRejected { get; set; }
 
    /// <summary>
    /// Gets or sets the default status code to set on the response when a request is rejected.
    /// Defaults to <see cref="StatusCodes.Status503ServiceUnavailable"/>.
    /// </summary>
    /// <remarks>
    /// This status code will be set before <see cref="OnRejected"/> is called, so any status code set by
    /// <see cref="OnRejected"/> will "win" over this default.
    /// </remarks>
    public int RejectionStatusCode { get; set; } = StatusCodes.Status503ServiceUnavailable;
 
    /// <summary>
    /// Adds a new rate limiting policy with the given <paramref name="policyName"/>
    /// </summary>
    /// <param name="policyName">The name to be associated with the given <see cref="RateLimiter"/>.</param>
    /// <param name="partitioner">Method called every time an Acquire or WaitAsync call is made to determine what rate limiter to apply to the request.</param>
    public RateLimiterOptions AddPolicy<TPartitionKey>(string policyName, Func<HttpContext, RateLimitPartition<TPartitionKey>> partitioner)
    {
        ArgumentNullException.ThrowIfNull(policyName);
        ArgumentNullException.ThrowIfNull(partitioner);
 
        if (PolicyMap.ContainsKey(policyName) || UnactivatedPolicyMap.ContainsKey(policyName))
        {
            throw new ArgumentException($"There already exists a policy with the name {policyName}.", nameof(policyName));
        }
 
        PolicyMap.Add(policyName, new DefaultRateLimiterPolicy(ConvertPartitioner<TPartitionKey>(policyName, partitioner), null));
 
        return this;
    }
 
    /// <summary>
    /// Adds a new rate limiting policy with the given policyName.
    /// </summary>
    /// <param name="policyName">The name to be associated with the given TPolicy.</param>
    public RateLimiterOptions AddPolicy<TPartitionKey, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TPolicy>(string policyName) where TPolicy : IRateLimiterPolicy<TPartitionKey>
    {
        ArgumentNullException.ThrowIfNull(policyName);
 
        if (PolicyMap.ContainsKey(policyName) || UnactivatedPolicyMap.ContainsKey(policyName))
        {
            throw new ArgumentException($"There already exists a policy with the name {policyName}.", nameof(policyName));
        }
 
        Func<IServiceProvider, DefaultRateLimiterPolicy> policyFunc = serviceProvider =>
        {
            var instance = (IRateLimiterPolicy<TPartitionKey>)ActivatorUtilities.CreateInstance(serviceProvider, typeof(TPolicy));
            return new DefaultRateLimiterPolicy(ConvertPartitioner<TPartitionKey>(policyName, instance.GetPartition), instance.OnRejected);
        };
 
        UnactivatedPolicyMap.Add(policyName, policyFunc);
 
        return this;
    }
 
    /// <summary>
    /// Adds a new rate limiting policy with the given policyName.
    /// </summary>
    /// <param name="policyName">The name to be associated with the given <see cref="IRateLimiterPolicy{TPartitionKey}"/>.</param>
    /// <param name="policy">The <see cref="IRateLimiterPolicy{TPartitionKey}"/> to be applied.</param>
    public RateLimiterOptions AddPolicy<TPartitionKey>(string policyName, IRateLimiterPolicy<TPartitionKey> policy)
    {
        ArgumentNullException.ThrowIfNull(policyName);
 
        if (PolicyMap.ContainsKey(policyName) || UnactivatedPolicyMap.ContainsKey(policyName))
        {
            throw new ArgumentException($"There already exists a policy with the name {policyName}.", nameof(policyName));
        }
 
        ArgumentNullException.ThrowIfNull(policy);
 
        PolicyMap.Add(policyName, new DefaultRateLimiterPolicy(ConvertPartitioner<TPartitionKey>(policyName, policy.GetPartition), policy.OnRejected));
 
        return this;
    }
 
    // Converts a Partition<TKey> to a Partition<DefaultKeyType<TKey>> to prevent accidental collisions with the keys we create in the the RateLimiterOptionsExtensions.
    internal static Func<HttpContext, RateLimitPartition<DefaultKeyType>> ConvertPartitioner<TPartitionKey>(string? policyName, Func<HttpContext, RateLimitPartition<TPartitionKey>> partitioner)
    {
        return context =>
        {
            RateLimitPartition<TPartitionKey> partition = partitioner(context);
            var partitionKey = new DefaultKeyType(policyName, partition.PartitionKey, partition.Factory);
            return new RateLimitPartition<DefaultKeyType>(partitionKey, static key => ((Func<TPartitionKey, RateLimiter>)key.Factory!)((TPartitionKey)key.Key!));
        };
    }
}