File: Emitters\ValidationsGenerator.Emitter.cs
Web Access
Project: src\src\Http\Http.Extensions\gen\Microsoft.AspNetCore.Http.ValidationsGenerator\Microsoft.AspNetCore.Http.ValidationsGenerator.csproj (Microsoft.AspNetCore.Http.ValidationsGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;
using System.Text;
using Microsoft.CodeAnalysis.CSharp;
using System.IO;
 
namespace Microsoft.AspNetCore.Http.ValidationsGenerator;
 
public sealed partial class ValidationsGenerator : IIncrementalGenerator
{
    public static string GeneratedCodeConstructor => $@"global::System.CodeDom.Compiler.GeneratedCodeAttribute(""{typeof(ValidationsGenerator).Assembly.FullName}"", ""{typeof(ValidationsGenerator).Assembly.GetName().Version}"")";
    public static string GeneratedCodeAttribute => $"[{GeneratedCodeConstructor}]";
 
    internal static void Emit(SourceProductionContext context, (InterceptableLocation? AddValidation, ImmutableArray<ValidatableType> ValidatableTypes) emitInputs)
    {
        if (emitInputs.AddValidation is null)
        {
            // Avoid generating code if no AddValidation call was found.
            return;
        }
        var source = Emit(emitInputs.AddValidation, emitInputs.ValidatableTypes);
        context.AddSource("ValidatableInfoResolver.g.cs", SourceText.From(source, Encoding.UTF8));
    }
 
    private static string Emit(InterceptableLocation addValidation, ImmutableArray<ValidatableType> validatableTypes) => $$"""
#nullable enable annotations
//------------------------------------------------------------------------------
// <auto-generated>
//     This code was generated by a tool.
//
//     Changes to this file may cause incorrect behavior and will be lost if
//     the code is regenerated.
// </auto-generated>
//------------------------------------------------------------------------------
#nullable enable
 
namespace System.Runtime.CompilerServices
{
    {{GeneratedCodeAttribute}}
    [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
    file sealed class InterceptsLocationAttribute : System.Attribute
    {
        public InterceptsLocationAttribute(int version, string data)
        {
        }
    }
}
 
namespace Microsoft.AspNetCore.Http.Validation.Generated
{
    {{GeneratedCodeAttribute}}
    file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo
    {
        public GeneratedValidatablePropertyInfo(
            global::System.Type containingType,
            global::System.Type propertyType,
            string name,
            string displayName) : base(containingType, propertyType, name, displayName)
        {
            ContainingType = containingType;
            Name = name;
        }
 
        internal global::System.Type ContainingType { get; }
        internal string Name { get; }
 
        protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes()
            => ValidationAttributeCache.GetValidationAttributes(ContainingType, Name);
    }
 
    {{GeneratedCodeAttribute}}
    file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo
    {
        public GeneratedValidatableTypeInfo(
            global::System.Type type,
            ValidatablePropertyInfo[] members) : base(type, members) { }
    }
 
    {{GeneratedCodeAttribute}}
    file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver
    {
        public bool TryGetValidatableTypeInfo(global::System.Type type, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo)
        {
            validatableInfo = null;
{{EmitTypeChecks(validatableTypes)}}
            return false;
        }
 
        // No-ops, rely on runtime code for ParameterInfo-based resolution
        public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, [global::System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo)
        {
            validatableInfo = null;
            return false;
        }
 
{{EmitCreateMethods(validatableTypes)}}
    }
 
    {{GeneratedCodeAttribute}}
    file static class GeneratedServiceCollectionExtensions
    {
        {{addValidation.GetInterceptsLocationAttributeSyntax()}}
        public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action<ValidationOptions>? configureOptions = null)
        {
            // Use non-extension method to avoid infinite recursion.
            return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options =>
            {
                options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver());
                if (configureOptions is not null)
                {
                    configureOptions(options);
                }
            });
        }
    }
 
    {{GeneratedCodeAttribute}}
    file static class ValidationAttributeCache
    {
        private sealed record CacheKey(global::System.Type ContainingType, string PropertyName);
        private static readonly global::System.Collections.Concurrent.ConcurrentDictionary<CacheKey, global::System.ComponentModel.DataAnnotations.ValidationAttribute[]> _cache = new();
 
        public static global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes(
            global::System.Type containingType,
            string propertyName)
        {
            var key = new CacheKey(containingType, propertyName);
            return _cache.GetOrAdd(key, static k =>
            {
                var property = k.ContainingType.GetProperty(k.PropertyName);
                if (property == null)
                {
                    return [];
                }
 
                return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(property, inherit: true)];
            });
        }
    }
}
""";
 
    private static string EmitTypeChecks(ImmutableArray<ValidatableType> validatableTypes)
    {
        var sw = new StringWriter();
        var cw = new CodeWriter(sw, baseIndent: 3);
        foreach (var validatableType in validatableTypes)
        {
            var typeName = validatableType.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
            cw.WriteLine($"if (type == typeof({typeName}))");
            cw.StartBlock();
            cw.WriteLine($"validatableInfo = Create{SanitizeTypeName(validatableType.Type.MetadataName)}();");
            cw.WriteLine("return true;");
            cw.EndBlock();
        }
        return sw.ToString();
    }
 
    private static string EmitCreateMethods(ImmutableArray<ValidatableType> validatableTypes)
    {
        var sw = new StringWriter();
        var cw = new CodeWriter(sw, baseIndent: 2);
        foreach (var validatableType in validatableTypes)
        {
            cw.WriteLine($@"private ValidatableTypeInfo Create{SanitizeTypeName(validatableType.Type.MetadataName)}()");
            cw.StartBlock();
            cw.WriteLine("return new GeneratedValidatableTypeInfo(");
            cw.Indent++;
            cw.WriteLine($"type: typeof({validatableType.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),");
            if (validatableType.Members.IsDefaultOrEmpty)
            {
                cw.WriteLine("members: []");
            }
            else
            {
                cw.WriteLine("members: [");
                cw.Indent++;
                foreach (var member in validatableType.Members)
                {
                    EmitValidatableMemberForCreate(member, cw);
                }
                cw.Indent--;
                cw.WriteLine("]");
            }
            cw.Indent--;
            cw.WriteLine(");");
            cw.EndBlock();
        }
        return sw.ToString();
    }
 
    private static void EmitValidatableMemberForCreate(ValidatableProperty member, CodeWriter cw)
    {
        cw.WriteLine("new GeneratedValidatablePropertyInfo(");
        cw.Indent++;
        cw.WriteLine($"containingType: typeof({member.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),");
        cw.WriteLine($"propertyType: typeof({member.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),");
        cw.WriteLine($"name: \"{member.Name}\",");
        cw.WriteLine($"displayName: \"{member.DisplayName}\"");
        cw.Indent--;
        cw.WriteLine("),");
    }
 
    private static string SanitizeTypeName(string typeName)
    {
        // Replace invalid characters with underscores and remove generic notation
        return typeName
            .Replace(".", "_")
            .Replace("<", "_")
            .Replace(">", "_")
            .Replace(",", "_")
            .Replace(" ", "_");
    }
}