File: EndpointRoutingMiddleware.cs
Web Access
Project: src\src\Http\Routing\src\Microsoft.AspNetCore.Routing.csproj (Microsoft.AspNetCore.Routing)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Antiforgery;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Cors.Infrastructure;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Metadata;
using Microsoft.AspNetCore.Routing.Matching;
using Microsoft.AspNetCore.Routing.ShortCircuit;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
 
namespace Microsoft.AspNetCore.Routing;
 
internal sealed partial class EndpointRoutingMiddleware
{
    private const string DiagnosticsEndpointMatchedKey = "Microsoft.AspNetCore.Routing.EndpointMatched";
 
    private readonly MatcherFactory _matcherFactory;
    private readonly ILogger _logger;
    private readonly EndpointDataSource _endpointDataSource;
    private readonly DiagnosticListener _diagnosticListener;
    private readonly RoutingMetrics _metrics;
    private readonly RequestDelegate _next;
    private readonly RouteOptions _routeOptions;
    private Task<Matcher>? _initializationTask;
 
    public EndpointRoutingMiddleware(
        MatcherFactory matcherFactory,
        ILogger<EndpointRoutingMiddleware> logger,
        IEndpointRouteBuilder endpointRouteBuilder,
        EndpointDataSource rootCompositeEndpointDataSource,
        DiagnosticListener diagnosticListener,
        IOptions<RouteOptions> routeOptions,
        RoutingMetrics metrics,
        RequestDelegate next)
    {
        ArgumentNullException.ThrowIfNull(endpointRouteBuilder);
 
        _matcherFactory = matcherFactory ?? throw new ArgumentNullException(nameof(matcherFactory));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _diagnosticListener = diagnosticListener ?? throw new ArgumentNullException(nameof(diagnosticListener));
        _metrics = metrics;
        _next = next ?? throw new ArgumentNullException(nameof(next));
        _routeOptions = routeOptions.Value;
 
        // rootCompositeEndpointDataSource is a constructor parameter only so it always gets disposed by DI. This ensures that any
        // disposable EndpointDataSources also get disposed. _endpointDataSource is a component of rootCompositeEndpointDataSource.
        _ = rootCompositeEndpointDataSource;
        _endpointDataSource = new CompositeEndpointDataSource(endpointRouteBuilder.DataSources);
    }
 
    public Task Invoke(HttpContext httpContext)
    {
        // There's already an endpoint, skip matching completely
        var endpoint = httpContext.GetEndpoint();
        if (endpoint != null)
        {
            Log.MatchSkipped(_logger, endpoint);
            return _next(httpContext);
        }
 
        // There's an inherent race condition between waiting for init and accessing the matcher
        // this is OK because once `_matcher` is initialized, it will not be set to null again.
        var matcherTask = InitializeAsync();
        if (!matcherTask.IsCompletedSuccessfully)
        {
            return AwaitMatcher(this, httpContext, matcherTask);
        }
 
        var matchTask = matcherTask.Result.MatchAsync(httpContext);
        if (!matchTask.IsCompletedSuccessfully)
        {
            return AwaitMatch(this, httpContext, matchTask);
        }
 
        return SetRoutingAndContinue(httpContext);
 
        // Awaited fallbacks for when the Tasks do not synchronously complete
        static async Task AwaitMatcher(EndpointRoutingMiddleware middleware, HttpContext httpContext, Task<Matcher> matcherTask)
        {
            var matcher = await matcherTask;
            await matcher.MatchAsync(httpContext);
            await middleware.SetRoutingAndContinue(httpContext);
        }
 
        static async Task AwaitMatch(EndpointRoutingMiddleware middleware, HttpContext httpContext, Task matchTask)
        {
            await matchTask;
            await middleware.SetRoutingAndContinue(httpContext);
        }
    }
 
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private Task SetRoutingAndContinue(HttpContext httpContext)
    {
        // If there was no mutation of the endpoint then log failure
        var endpoint = httpContext.GetEndpoint();
        if (endpoint == null)
        {
            Log.MatchFailure(_logger);
            _metrics.MatchFailure();
        }
        else
        {
            // Raise an event if the route matched
            if (_diagnosticListener.IsEnabled() && _diagnosticListener.IsEnabled(DiagnosticsEndpointMatchedKey))
            {
                Write(_diagnosticListener, httpContext);
            }
 
            if (_logger.IsEnabled(LogLevel.Debug) || _metrics.MatchSuccessCounterEnabled)
            {
                var isFallback = endpoint.Metadata.GetMetadata<FallbackMetadata>() is not null;
 
                Log.MatchSuccess(_logger, endpoint);
 
                if (isFallback)
                {
                    Log.FallbackMatch(_logger, endpoint);
                }
 
                // It shouldn't be possible for a route to be matched via the route matcher and not have a route.
                // Just in case, add a special (missing) value as the route tag to metrics.
                var route = endpoint.Metadata.GetMetadata<IRouteDiagnosticsMetadata>()?.Route ?? "(missing)";
                _metrics.MatchSuccess(route, isFallback);
            }
 
            // Map RequestSizeLimitMetadata to IHttpMaxRequestBodySizeFeature if present on the endpoint.
            // We do this during endpoint routing to ensure that successive middlewares in the pipeline
            // can access the feature with the correct value.
            SetMaxRequestBodySize(httpContext);
 
            var shortCircuitMetadata = endpoint.Metadata.GetMetadata<ShortCircuitMetadata>();
            if (shortCircuitMetadata is not null)
            {
                return ExecuteShortCircuit(shortCircuitMetadata, endpoint, httpContext);
            }
        }
 
        return _next(httpContext);
 
        [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:UnrecognizedReflectionPattern",
            Justification = "The values being passed into Write are being consumed by the application already.")]
        static void Write(DiagnosticListener diagnosticListener, HttpContext httpContext)
        {
            // We're just going to send the HttpContext since it has all of the relevant information
            diagnosticListener.Write(DiagnosticsEndpointMatchedKey, httpContext);
        }
    }
 
    private Task ExecuteShortCircuit(ShortCircuitMetadata shortCircuitMetadata, Endpoint endpoint, HttpContext httpContext)
    {
        // This check should be kept in sync with the one in EndpointMiddleware
        if (!_routeOptions.SuppressCheckForUnhandledSecurityMetadata)
        {
            if (endpoint.Metadata.GetMetadata<IAuthorizeData>() is not null)
            {
                ThrowCannotShortCircuitAnAuthRouteException(endpoint);
            }
 
            if (endpoint.Metadata.GetMetadata<ICorsMetadata>() is not null)
            {
                ThrowCannotShortCircuitACorsRouteException(endpoint);
            }
 
            if (endpoint.Metadata.GetMetadata<IAntiforgeryMetadata>() is { RequiresValidation: true } &&
                httpContext.Request.Method is {} method &&
                HttpExtensions.IsValidHttpMethodForForm(method))
            {
                ThrowCannotShortCircuitAnAntiforgeryRouteException(endpoint);
            }
        }
 
        if (shortCircuitMetadata.StatusCode.HasValue)
        {
            httpContext.Response.StatusCode = shortCircuitMetadata.StatusCode.Value;
        }
 
        if (endpoint.RequestDelegate is not null)
        {
            if (!_logger.IsEnabled(LogLevel.Information))
            {
                // Avoid the AwaitRequestTask state machine allocation if logging is disabled.
                return endpoint.RequestDelegate(httpContext);
            }
 
            Log.ExecutingEndpoint(_logger, endpoint);
 
            try
            {
                var requestTask = endpoint.RequestDelegate(httpContext);
                if (!requestTask.IsCompletedSuccessfully)
                {
                    return AwaitRequestTask(endpoint, requestTask, _logger);
                }
            }
            catch
            {
                Log.ExecutedEndpoint(_logger, endpoint);
                throw;
            }
 
            Log.ExecutedEndpoint(_logger, endpoint);
 
            return Task.CompletedTask;
 
            static async Task AwaitRequestTask(Endpoint endpoint, Task requestTask, ILogger logger)
            {
                try
                {
                    await requestTask;
                }
                finally
                {
                    Log.ExecutedEndpoint(logger, endpoint);
                }
            }
 
        }
        else
        {
            Log.ShortCircuitedEndpoint(_logger, endpoint);
        }
        return Task.CompletedTask;
    }
 
    // Initialization is async to avoid blocking threads while reflection and things
    // of that nature take place.
    //
    // We've seen cases where startup is very slow if we  allow multiple threads to race
    // while initializing the set of endpoints/routes. Doing CPU intensive work is a
    // blocking operation if you have a low core count and enough work to do.
    private Task<Matcher> InitializeAsync()
    {
        var initializationTask = _initializationTask;
        if (initializationTask != null)
        {
            return initializationTask;
        }
 
        return InitializeCoreAsync();
    }
 
    private Task<Matcher> InitializeCoreAsync()
    {
        var initialization = new TaskCompletionSource<Matcher>(TaskCreationOptions.RunContinuationsAsynchronously);
        var initializationTask = Interlocked.CompareExchange(ref _initializationTask, initialization.Task, null);
        if (initializationTask != null)
        {
            // This thread lost the race, join the existing task.
            return initializationTask;
        }
 
        // This thread won the race, do the initialization.
        try
        {
            var matcher = _matcherFactory.CreateMatcher(_endpointDataSource);
 
            _initializationTask = Task.FromResult(matcher);
 
            // Complete the task, this will unblock any requests that came in while initializing.
            initialization.SetResult(matcher);
            return initialization.Task;
        }
        catch (Exception ex)
        {
            // Allow initialization to occur again. Since DataSources can change, it's possible
            // for the developer to correct the data causing the failure.
            _initializationTask = null;
 
            // Complete the task, this will throw for any requests that came in while initializing.
            initialization.SetException(ex);
            return initialization.Task;
        }
    }
 
    private static void ThrowCannotShortCircuitAnAuthRouteException(Endpoint endpoint)
    {
        throw new InvalidOperationException($"Endpoint {endpoint.DisplayName} contains authorization metadata, " +
            "but this endpoint is marked with short circuit and it will execute on Routing Middleware.");
    }
 
    private static void ThrowCannotShortCircuitACorsRouteException(Endpoint endpoint)
    {
        throw new InvalidOperationException($"Endpoint {endpoint.DisplayName} contains CORS metadata, " +
            "but this endpoint is marked with short circuit and it will execute on Routing Middleware.");
    }
 
    private static void ThrowCannotShortCircuitAnAntiforgeryRouteException(Endpoint endpoint)
    {
        throw new InvalidOperationException($"Endpoint {endpoint.DisplayName} contains anti-forgery metadata, " +
            "but this endpoint is marked with short circuit and it will execute on Routing Middleware.");
    }
 
    private void SetMaxRequestBodySize(HttpContext context)
    {
        var sizeLimitMetadata = context.GetEndpoint()?.Metadata?.GetMetadata<IRequestSizeLimitMetadata>();
        if (sizeLimitMetadata == null)
        {
            Log.RequestSizeLimitMetadataNotFound(_logger);
            return;
        }
 
        var maxRequestBodySizeFeature = context.Features.Get<IHttpMaxRequestBodySizeFeature>();
        if (maxRequestBodySizeFeature == null)
        {
            Log.RequestSizeFeatureNotFound(_logger);
        }
        else if (maxRequestBodySizeFeature.IsReadOnly)
        {
            Log.RequestSizeFeatureIsReadOnly(_logger);
        }
        else
        {
            var maxRequestBodySize = sizeLimitMetadata.MaxRequestBodySize;
            maxRequestBodySizeFeature.MaxRequestBodySize = maxRequestBodySize;
 
            if (maxRequestBodySize.HasValue)
            {
                Log.MaxRequestBodySizeSet(_logger,
                    maxRequestBodySize.Value.ToString(CultureInfo.InvariantCulture));
            }
            else
            {
                Log.MaxRequestBodySizeDisabled(_logger);
            }
        }
    }
 
    private static partial class Log
    {
        public static void MatchSuccess(ILogger logger, Endpoint endpoint)
            => MatchSuccess(logger, endpoint.DisplayName);
 
        [LoggerMessage(1, LogLevel.Debug, "Request matched endpoint '{EndpointName}'", EventName = "MatchSuccess")]
        private static partial void MatchSuccess(ILogger logger, string? endpointName);
 
        [LoggerMessage(2, LogLevel.Debug, "Request did not match any endpoints", EventName = "MatchFailure")]
        public static partial void MatchFailure(ILogger logger);
 
        public static void MatchSkipped(ILogger logger, Endpoint endpoint)
            => MatchingSkipped(logger, endpoint.DisplayName);
 
        [LoggerMessage(3, LogLevel.Debug, "Endpoint '{EndpointName}' already set, skipping route matching.", EventName = "MatchingSkipped")]
        private static partial void MatchingSkipped(ILogger logger, string? endpointName);
 
        [LoggerMessage(4, LogLevel.Information, "The endpoint '{EndpointName}' is being executed without running additional middleware.", EventName = "ExecutingEndpoint")]
        public static partial void ExecutingEndpoint(ILogger logger, Endpoint endpointName);
 
        [LoggerMessage(5, LogLevel.Information, "The endpoint '{EndpointName}' has been executed without running additional middleware.", EventName = "ExecutedEndpoint")]
        public static partial void ExecutedEndpoint(ILogger logger, Endpoint endpointName);
 
        [LoggerMessage(6, LogLevel.Information, "The endpoint '{EndpointName}' is being short circuited without running additional middleware or producing a response.", EventName = "ShortCircuitedEndpoint")]
        public static partial void ShortCircuitedEndpoint(ILogger logger, Endpoint endpointName);
 
        [LoggerMessage(7, LogLevel.Debug, "Matched endpoint '{EndpointName}' is a fallback endpoint.", EventName = "FallbackMatch")]
        public static partial void FallbackMatch(ILogger logger, Endpoint endpointName);
 
        [LoggerMessage(8, LogLevel.Trace, $"The endpoint does not specify the {nameof(IRequestSizeLimitMetadata)}.", EventName = "RequestSizeLimitMetadataNotFound")]
        public static partial void RequestSizeLimitMetadataNotFound(ILogger logger);
 
        [LoggerMessage(9, LogLevel.Warning, $"A request body size limit could not be applied. This server does not support the {nameof(IHttpMaxRequestBodySizeFeature)}.", EventName = "RequestSizeFeatureNotFound")]
        public static partial void RequestSizeFeatureNotFound(ILogger logger);
 
        [LoggerMessage(10, LogLevel.Warning, $"A request body size limit could not be applied. The {nameof(IHttpMaxRequestBodySizeFeature)} for the server is read-only.", EventName = "RequestSizeFeatureIsReadOnly")]
        public static partial void RequestSizeFeatureIsReadOnly(ILogger logger);
 
        [LoggerMessage(11, LogLevel.Debug, "The maximum request body size has been set to {RequestSize}.", EventName = "MaxRequestBodySizeSet")]
        public static partial void MaxRequestBodySizeSet(ILogger logger, string requestSize);
 
        [LoggerMessage(12, LogLevel.Debug, "The maximum request body size has been disabled.", EventName = "MaxRequestBodySizeDisabled")]
        public static partial void MaxRequestBodySizeDisabled(ILogger logger);
    }
}