File: RequestDelegateGenerator.cs
Web Access
Project: src\src\Http\Http.Extensions\gen\Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj (Microsoft.AspNetCore.Http.RequestDelegateGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Immutable;
using System.Globalization;
using System.IO;
using System.Linq;
using Microsoft.AspNetCore.Analyzers.Infrastructure;
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel;
using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel.Emitters;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
 
namespace Microsoft.AspNetCore.Http.RequestDelegateGenerator;
 
[Generator]
public sealed class RequestDelegateGenerator : IIncrementalGenerator
{
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        var endpointsWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider(
            predicate: static (node, _) => node.TryGetMapMethodName(out var method) && InvocationOperationExtensions.KnownMethods.Contains(method),
            transform: static (context, token) =>
            {
                var operation = context.SemanticModel.GetOperation(context.Node, token);
                var wellKnownTypes = WellKnownTypes.GetOrCreate(context.SemanticModel.Compilation);
                if (operation.IsValidOperation(wellKnownTypes, out var invocationOperation))
                {
                    return new Endpoint(invocationOperation, wellKnownTypes, context.SemanticModel);
                }
                return null;
            })
            .Where(static endpoint => endpoint != null)
            .Select((endpoint, _) =>
            {
                AnalyzerDebug.Assert(endpoint != null, "Invalid endpoints should not be processed.");
                return endpoint;
            })
            .WithTrackingName(GeneratorSteps.EndpointModelStep);
 
        context.RegisterSourceOutput(endpointsWithDiagnostics, (context, endpoint) =>
        {
            foreach (var diagnostic in endpoint.Diagnostics)
            {
                context.ReportDiagnostic(diagnostic);
            }
        });
 
        var endpoints = endpointsWithDiagnostics
            .Where(endpoint => endpoint.Diagnostics.Count == 0)
            .WithTrackingName(GeneratorSteps.EndpointsWithoutDiagnosicsStep);
 
        var interceptorDefinitions = endpoints
            .GroupWith((endpoint) => endpoint.InterceptableLocation, EndpointDelegateComparer.Instance)
            .Select((endpointWithLocations, _) =>
            {
                var endpoint = endpointWithLocations.Source;
                using var stringWriter = new StringWriter(CultureInfo.InvariantCulture);
                using var codeWriter = new CodeWriter(stringWriter, baseIndent: 2);
                foreach (var location in endpointWithLocations.Elements)
                {
#pragma warning disable RSEXPERIMENTAL002 // Experimental interceptable location API
                    codeWriter.WriteLine(location.GetInterceptsLocationAttributeSyntax());
#pragma warning restore RSEXPERIMENTAL002
                }
                codeWriter.WriteLine($"internal static RouteHandlerBuilder {endpoint.HttpMethod}{endpointWithLocations.Index}(");
                codeWriter.Indent++;
                codeWriter.WriteLine("this IEndpointRouteBuilder endpoints,");
                // MapFallback overloads that only take a delegate do not need a pattern argument
                if (endpoint.HttpMethod != "MapFallback" || endpoint.Operation.Arguments.Length != 2)
                {
                    codeWriter.WriteLine(@"[StringSyntax(""Route"")] string pattern,");
                }
                // MapMethods overloads define an additional `httpMethods` parameter
                if (endpoint.HttpMethod == "MapMethods")
                {
                    codeWriter.WriteLine("IEnumerable<string> httpMethods,");
                }
                codeWriter.WriteLine("Delegate handler)");
                codeWriter.Indent--;
                codeWriter.StartBlock();
                codeWriter.WriteLine("MetadataPopulator populateMetadata = (methodInfo, options) =>");
                codeWriter.StartBlock();
                codeWriter.WriteLine(@"Debug.Assert(options != null, ""RequestDelegateFactoryOptions not found."");");
                codeWriter.WriteLine(@"Debug.Assert(options.EndpointBuilder != null, ""EndpointBuilder not found."");");
                codeWriter.WriteLine($"options.EndpointBuilder.Metadata.Add(new {RequestDelegateGeneratorSources.GeneratedCodeConstructor});");
                endpoint.EmitEndpointMetadataPopulation(codeWriter);
                codeWriter.WriteLine("return new RequestDelegateMetadataResult { EndpointMetadata = options.EndpointBuilder.Metadata.AsReadOnly() };");
                codeWriter.EndBlockWithSemicolon();
                codeWriter.WriteLine("RequestDelegateFactoryFunc createRequestDelegate = (del, options, inferredMetadataResult) =>");
                codeWriter.StartBlock();
                codeWriter.WriteLine(@"Debug.Assert(options != null, ""RequestDelegateFactoryOptions not found."");");
                codeWriter.WriteLine(@"Debug.Assert(options.EndpointBuilder != null, ""EndpointBuilder not found."");");
                codeWriter.WriteLine(@"Debug.Assert(options.EndpointBuilder.ApplicationServices != null, ""ApplicationServices not found."");");
                codeWriter.WriteLine(@"Debug.Assert(options.EndpointBuilder.FilterFactories != null, ""FilterFactories not found."");");
                codeWriter.WriteLine($"var handler = Cast(del, {endpoint.EmitHandlerDelegateType()} => throw null!);");
                codeWriter.WriteLine("EndpointFilterDelegate? filteredInvocation = null;");
                codeWriter.WriteLine("var serviceProvider = options.ServiceProvider ?? options.EndpointBuilder.ApplicationServices;");
                endpoint.EmitLoggingPreamble(codeWriter);
                endpoint.EmitJsonPreparation(codeWriter);
                endpoint.EmitRouteOrQueryResolver(codeWriter);
                endpoint.EmitJsonBodyOrServiceResolver(codeWriter);
                if (endpoint.NeedsParameterArray)
                {
                    codeWriter.WriteLine("var parameters = del.Method.GetParameters();");
                }
                codeWriter.WriteLineNoTabs(string.Empty);
                codeWriter.WriteLine("if (options.EndpointBuilder.FilterFactories.Count > 0)");
                codeWriter.StartBlock();
                codeWriter.WriteLine(endpoint.Response?.IsAwaitable == true
                    ? "filteredInvocation = GeneratedRouteBuilderExtensionsCore.BuildFilterDelegate(async ic =>"
                    : "filteredInvocation = GeneratedRouteBuilderExtensionsCore.BuildFilterDelegate(ic =>");
                codeWriter.StartBlock();
                codeWriter.WriteLine("if (ic.HttpContext.Response.StatusCode == 400)");
                codeWriter.StartBlock();
                codeWriter.WriteLine(endpoint.Response?.IsAwaitable == true
                    ? "return (object?)Results.Empty;"
                    : "return ValueTask.FromResult<object?>(Results.Empty);");
                codeWriter.EndBlock();
                endpoint.EmitFilteredInvocation(codeWriter);
                codeWriter.EndBlockWithComma();
                codeWriter.WriteLine("options.EndpointBuilder,");
                codeWriter.WriteLine("handler.Method);");
                codeWriter.EndBlock();
                codeWriter.WriteLineNoTabs(string.Empty);
                endpoint.EmitRequestHandler(codeWriter);
                codeWriter.WriteLineNoTabs(string.Empty);
                endpoint.EmitFilteredRequestHandler(codeWriter);
                codeWriter.WriteLineNoTabs(string.Empty);
                codeWriter.WriteLine("RequestDelegate targetDelegate = filteredInvocation is null ? RequestHandler : RequestHandlerFiltered;");
                codeWriter.WriteLine("var metadata = inferredMetadataResult?.EndpointMetadata ?? ReadOnlyCollection<object>.Empty;");
                codeWriter.WriteLine("return new RequestDelegateResult(targetDelegate, metadata);");
                codeWriter.EndBlockWithSemicolon();
                codeWriter.WriteLine($"var castHandler = Cast(handler, {endpoint.EmitHandlerDelegateType()} => throw null!);");
                codeWriter.WriteLine("return MapCore(");
                codeWriter.Indent++;
                codeWriter.WriteLine("endpoints,");
                // For `MapFallback` overloads that only take a delegate, provide the assumed default
                // Otherwise, pass the pattern provided from the MapX invocation
                if (endpoint.HttpMethod != "MapFallback" && endpoint.Operation.Arguments.Length != 2)
                {
                    codeWriter.WriteLine("pattern,");
                }
                else
                {
                    codeWriter.WriteLine($"{SymbolDisplay.FormatLiteral("{*path:nonfile}", true)},");
                }
                codeWriter.WriteLine("handler,");
                codeWriter.WriteLine($"{endpoint.EmitVerb()},");
                codeWriter.WriteLine("populateMetadata,");
                codeWriter.WriteLine("createRequestDelegate,");
                codeWriter.WriteLine("castHandler.Method);");
                codeWriter.Indent--;
                codeWriter.EndBlock();
                return stringWriter.ToString();
            });
 
        var httpVerbs = endpoints
            .Collect()
            .Select((endpoints, _) =>
            {
                return endpoints
                    .Distinct(EndpointHttpMethodComparer.Instance)
                    .Select(endpoint => endpoint.EmitterContext.HttpMethod!)
                    .Where(verb => verb is not null)
                    .ToImmutableHashSet();
            });
 
        var endpointHelpers = endpoints
            .Collect()
            .Select((endpoints, _) =>
            {
                var hasJsonBodyOrService = endpoints.Any(endpoint => endpoint.EmitterContext.HasJsonBodyOrService);
                var hasJsonBodyOrQuery = endpoints.Any(endpoint => endpoint.EmitterContext.HasJsonBodyOrQuery);
                var hasJsonBody = endpoints.Any(endpoint => endpoint.EmitterContext.HasJsonBody);
                var hasFormBody = endpoints.Any(endpoint => endpoint.EmitterContext.HasFormBody);
                var hasRouteOrQuery = endpoints.Any(endpoint => endpoint.EmitterContext.HasRouteOrQuery);
                var hasBindAsync = endpoints.Any(endpoint => endpoint.EmitterContext.HasBindAsync);
                var hasParsable = endpoints.Any(endpoint => endpoint.EmitterContext.HasParsable);
                var hasEndpointMetadataProvider = endpoints.Any(endpoint => endpoint.EmitterContext.HasEndpointMetadataProvider);
                var hasEndpointParameterMetadataProvider = endpoints.Any(endpoint => endpoint.EmitterContext.HasEndpointParameterMetadataProvider);
                var hasIResult = endpoints.Any(endpoint => endpoint.Response?.IsIResult == true);
 
                using var stringWriter = new StringWriter(CultureInfo.InvariantCulture);
                using var codeWriter = new CodeWriter(stringWriter, baseIndent: 0);
 
                if (hasRouteOrQuery)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.ResolveFromRouteOrQueryMethod);
                }
 
                if (hasJsonBody || hasJsonBodyOrService || hasJsonBodyOrQuery)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.TryResolveBodyAsyncMethod);
                }
 
                if (hasFormBody)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.TryResolveFormAsyncMethod);
                }
 
                if (hasBindAsync)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.BindAsyncMethod);
                }
 
                if (hasJsonBodyOrService)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.ResolveJsonBodyOrServiceMethod);
                }
 
                if (hasParsable)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.TryParseExplicitMethod);
                }
 
                if (hasIResult)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.ExecuteAsyncExplicitMethod);
                }
 
                if (hasEndpointMetadataProvider)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.PopulateEndpointMetadataMethod);
                }
 
                if (hasEndpointParameterMetadataProvider)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.PopulateEndpointParameterMetadataMethod);
                }
 
                return stringWriter.ToString();
            });
 
        var helperTypes = endpoints
            .Collect()
            .Select((endpoints, _) =>
            {
                var hasFormBody = endpoints.Any(endpoint => endpoint.EmitterContext.HasFormBody);
                var hasJsonBody = endpoints.Any(endpoint => endpoint.EmitterContext.HasJsonBody || endpoint.EmitterContext.HasJsonBodyOrService || endpoint.EmitterContext.HasJsonBodyOrQuery);
                var hasResponseMetadata = endpoints.Any(endpoint => endpoint.EmitterContext.HasResponseMetadata);
                var requiresPropertyAsParameterInfo = endpoints.Any(endpoint => endpoint.EmitterContext.RequiresPropertyAsParameterInfo);
                var requiresParameterBindingMetadataClass = endpoints.Any(endpoint => endpoint.EmitterContext.RequiresParameterBindingMetadataClass);
 
                using var stringWriter = new StringWriter(CultureInfo.InvariantCulture);
                using var codeWriter = new CodeWriter(stringWriter, baseIndent: 0);
 
                if (hasFormBody)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.AntiforgeryMetadataType);
                }
 
                if (hasFormBody || hasJsonBody || hasResponseMetadata)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.ContentTypeConstantsType);
                }
 
                if (requiresPropertyAsParameterInfo)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.PropertyAsParameterInfoClass);
                }
 
                if (requiresParameterBindingMetadataClass)
                {
                    codeWriter.WriteLine(RequestDelegateGeneratorSources.ParameterBindingMetadataClass);
                }
 
                return stringWriter.ToString();
            });
 
        var endpointsAndHelpers = interceptorDefinitions.Collect().Combine(endpointHelpers).Combine(httpVerbs).Combine(helperTypes);
 
        context.RegisterSourceOutput(endpointsAndHelpers, (context, sources) =>
        {
            var (((endpointsCode, helperMethods), httpVerbs), helperTypes) = sources;
            if (endpointsCode.IsDefaultOrEmpty)
            {
                return;
            }
            using var stringWriter = new StringWriter(CultureInfo.InvariantCulture);
            using var codeWriter = new CodeWriter(stringWriter, baseIndent: 0);
            foreach (var endpoint in endpointsCode)
            {
                codeWriter.WriteLine(endpoint);
            }
            var code = RequestDelegateGeneratorSources.GetGeneratedRouteBuilderExtensionsSource(
                endpoints: stringWriter.ToString(),
                helperMethods: helperMethods ?? string.Empty,
                helperTypes: helperTypes ?? string.Empty,
                verbs: httpVerbs);
 
            context.AddSource("GeneratedRouteBuilderExtensions.g.cs", code);
        });
    }
}