// File: ValidationEndpointFilterFactory.cs
// Web Access
// Project: src\src\Http\Routing\src\Microsoft.AspNetCore.Routing.csproj (Microsoft.AspNetCore.Routing)
#pragma warning disable ASP0029 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Net.Mime;
using System.Reflection;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.AspNetCore.Http.Metadata;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
 
namespace Microsoft.AspNetCore.Http.Validation;
 
/// <summary>
/// Builds an endpoint filter that validates route handler arguments using the resolvers
/// registered in <see cref="ValidationOptions"/> before the handler runs.
/// </summary>
internal static class ValidationEndpointFilterFactory
{
    // A small struct to hold the validatable parameter details to avoid allocating arrays for parameters that don't need validation.
    private readonly record struct ValidatableParameterEntry(int Index, IValidatableInfo Parameter, string DisplayName);

    /// <summary>
    /// Creates the validation filter for the route handler described by <paramref name="context"/>.
    /// </summary>
    /// <param name="context">The filter factory context for the route handler being built.</param>
    /// <param name="next">The next delegate in the endpoint filter pipeline.</param>
    /// <returns>
    /// <paramref name="next"/> unchanged when no <see cref="ValidationOptions"/> resolvers are registered or
    /// no handler parameter is validatable. Otherwise it returns a delegate that validates the non-null arguments.
    /// When validation errors occur, that delegate short-circuits with a 400 problem-details response.
    /// When there are no errors, it invokes <paramref name="next"/>.
    /// </returns>
    public static EndpointFilterDelegate Create(EndpointFilterFactoryContext context, EndpointFilterDelegate next)
    {
        var options = context.ApplicationServices.GetService<IOptions<ValidationOptions>>()?.Value;
        if (options is null || options.Resolvers.Count == 0)
        {
            // Nothing can be validated; avoid adding any per-request overhead.
            return next;
        }

        var validatableParameters = GetValidatableParameters(
            context.MethodInfo.GetParameters(),
            options,
            context.ApplicationServices.GetService<IServiceProviderIsService>());

        if (validatableParameters.Length == 0)
        {
            return next;
        }

        return async invocationContext =>
        {
            var httpContext = invocationContext.HttpContext;
            var arguments = invocationContext.Arguments;
            ValidateContext? validateContext = null;

            foreach (var entry in validatableParameters)
            {
                // Entries are stored in ascending parameter order, so once one index is out of
                // range every later one is too.
                if (entry.Index >= arguments.Count)
                {
                    break;
                }

                var argument = arguments[entry.Index];
                if (argument is null)
                {
                    continue;
                }

                var validationContext = new ValidationContext(argument, entry.DisplayName, httpContext.RequestServices, items: null);

                // A single ValidateContext is shared across all arguments (created lazily on the first
                // non-null one); only its ValidationContext is swapped per argument.
                if (validateContext is null)
                {
                    validateContext = new ValidateContext
                    {
                        ValidationOptions = options,
                        ValidationContext = validationContext
                    };
                }
                else
                {
                    validateContext.ValidationContext = validationContext;
                }

                await entry.Parameter.ValidateAsync(argument, validateContext, httpContext.RequestAborted);
            }

            if (validateContext is { ValidationErrors.Count: > 0 })
            {
                httpContext.Response.StatusCode = StatusCodes.Status400BadRequest;

                var problemDetails = new HttpValidationProblemDetails(validateContext.ValidationErrors);

                var problemDetailsService = httpContext.RequestServices.GetService<IProblemDetailsService>();
                if (problemDetailsService is not null)
                {
                    if (await problemDetailsService.TryWriteAsync(new()
                    {
                        HttpContext = httpContext,
                        ProblemDetails = problemDetails
                    }))
                    {
                        // We need to prevent further execution, because the actual
                        // ProblemDetails response has already been written by ProblemDetailsService.
                        return EmptyHttpResult.Instance;
                    }
                }

                // Fallback to the default implementation.
                httpContext.Response.ContentType = MediaTypeNames.Application.ProblemJson;
                return problemDetails;
            }

            return await next(invocationContext);
        };
    }

    /// <summary>
    /// Collects the handler parameters that have validation info, skipping parameters resolved from DI.
    /// </summary>
    /// <param name="parameters">The handler's parameters, in declaration order.</param>
    /// <param name="options">The validation options used to look up per-parameter validation info.</param>
    /// <param name="serviceProviderIsService">Optional service lookup used to detect DI-resolved parameter types.</param>
    /// <returns>The validatable parameters in ascending index order; empty when there are none.</returns>
    private static ValidatableParameterEntry[] GetValidatableParameters(
        ParameterInfo[] parameters,
        ValidationOptions options,
        IServiceProviderIsService? serviceProviderIsService)
    {
        List<ValidatableParameterEntry>? entries = null;

        for (var i = 0; i < parameters.Length; i++)
        {
            var parameter = parameters[i];

            // Ignore parameters that are resolved from the DI container.
            if (IsServiceParameter(parameter, serviceProviderIsService))
            {
                continue;
            }

            if (options.TryGetValidatableParameterInfo(parameter, out var validatableParameter))
            {
                entries ??= [];
                entries.Add(new ValidatableParameterEntry(i, validatableParameter, GetDisplayName(parameter)));
            }
        }

        // An array keeps the per-request closure state compact and read-only.
        return entries?.ToArray() ?? Array.Empty<ValidatableParameterEntry>();
    }

    // A parameter is treated as a DI service when it is explicitly marked with IFromServiceMetadata
    // (e.g. [FromServices]) or when its type is registered with the container.
    // NOTE(review): [FromKeyedServices] parameters are presumably not covered by either check
    // (the attribute does not appear to implement IFromServiceMetadata) — confirm and handle if needed.
    private static bool IsServiceParameter(ParameterInfo parameterInfo, IServiceProviderIsService? isService)
        => HasFromServicesAttribute(parameterInfo) ||
           (isService?.IsService(parameterInfo.ParameterType) == true);

    // Attribute *instances* are required here: ParameterInfo.CustomAttributes yields CustomAttributeData,
    // which never implements IFromServiceMetadata, so filtering it would always return false.
    private static bool HasFromServicesAttribute(ParameterInfo parameterInfo)
        => parameterInfo.GetCustomAttributes().OfType<IFromServiceMetadata>().Any();

    // Prefers [Display(Name = ...)] on the parameter and falls back to the parameter's declared name.
    private static string GetDisplayName(ParameterInfo parameterInfo)
        => parameterInfo.GetCustomAttribute<DisplayAttribute>()?.Name ?? parameterInfo.Name!;
}