// File: XmlCommentGenerator.Emitter.cs
// Project: src\src\OpenApi\gen\Microsoft.AspNetCore.OpenApi.SourceGenerators.csproj (Microsoft.AspNetCore.OpenApi.SourceGenerators)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.AspNetCore.OpenApi.SourceGenerators.Xml;
using System.Threading;
using System.Linq;
 
namespace Microsoft.AspNetCore.OpenApi.SourceGenerators;
 
public sealed partial class XmlCommentGenerator : IIncrementalGenerator
{
    /// <summary>
    /// Gets the source text of a <c>System.CodeDom.Compiler.GeneratedCodeAttribute(...)</c> constructor call
    /// whose two string arguments are this generator assembly's full name and its version.
    /// </summary>
    public static string GeneratedCodeConstructor
    {
        get
        {
            var generatorAssembly = typeof(XmlCommentGenerator).Assembly;
            var toolName = generatorAssembly.FullName;
            var toolVersion = generatorAssembly.GetName().Version;
            return $"System.CodeDom.Compiler.GeneratedCodeAttribute(\"{toolName}\", \"{toolVersion}\")";
        }
    }
    /// <summary>
    /// Gets <see cref="GeneratedCodeConstructor"/> wrapped in square brackets so it can be emitted
    /// directly as an attribute on a generated declaration.
    /// </summary>
    public static string GeneratedCodeAttribute
    {
        get { return "[" + GeneratedCodeConstructor + "]"; }
    }
 
    /// <summary>
    /// Builds the full text of the generated C# file that wires XML documentation comments into OpenAPI.
    /// </summary>
    /// <remarks>
    /// The template is a <c>$$"""</c> raw string, so only double-braced holes are interpolated and every
    /// single brace is emitted literally. The emitted file declares, in order: a file-local
    /// <c>InterceptsLocationAttribute</c>; the <c>XmlComment</c>, <c>XmlParameterComment</c> and
    /// <c>XmlResponseComment</c> records; the <c>MemberKey</c> record and <c>MemberType</c> enum used as the
    /// cache key; <c>XmlCommentCache</c>; the operation and schema transformers; a <c>JsonNodeExtensions</c>
    /// helper; and <c>GeneratedServiceCollectionExtensions</c>.
    /// </remarks>
    /// <param name="commentsFromXmlFile">Source text spliced verbatim into the body of the generated <c>GenerateCacheEntries</c> method, before <paramref name="commentsFromCompilation"/>.</param>
    /// <param name="commentsFromCompilation">Further source text spliced verbatim on the following line of <c>GenerateCacheEntries</c>; may be <see langword="null"/>.</param>
    /// <param name="groupedAddOpenApiInvocations">Passed to <see cref="GenerateAddOpenApiInterceptions"/>, whose result becomes the body of <c>GeneratedServiceCollectionExtensions</c>.</param>
    /// <returns>The complete generated source text.</returns>
    internal static string GenerateXmlCommentSupportSource(string commentsFromXmlFile, string? commentsFromCompilation, ImmutableArray<(AddOpenApiInvocation Source, int Index, ImmutableArray<InterceptableLocation?> Elements)> groupedAddOpenApiInvocations) => $$"""
//------------------------------------------------------------------------------
// <auto-generated>
//     This code was generated by a tool.
//
//     Changes to this file may cause incorrect behavior and will be lost if
//     the code is regenerated.
// </auto-generated>
//------------------------------------------------------------------------------
#nullable enable
 
namespace System.Runtime.CompilerServices
{
    {{GeneratedCodeAttribute}}
    [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
    file sealed class InterceptsLocationAttribute : System.Attribute
    {
        public InterceptsLocationAttribute(int version, string data)
        {
        }
    }
}
 
namespace Microsoft.AspNetCore.OpenApi.Generated
{
    using System;
    using System.Collections.Generic;
    using System.Diagnostics.CodeAnalysis;
    using System.Linq;
    using System.Reflection;
    using System.Text.Json;
    using System.Text.Json.Nodes;
    using System.Threading;
    using System.Threading.Tasks;
    using Microsoft.AspNetCore.OpenApi;
    using Microsoft.AspNetCore.Mvc.Controllers;
    using Microsoft.Extensions.DependencyInjection;
    using Microsoft.OpenApi.Models;
    using Microsoft.OpenApi.Models.References;
    using Microsoft.OpenApi.Any;
 
    {{GeneratedCodeAttribute}}
    file record XmlComment(
        string? Summary,
        string? Description,
        string? Remarks,
        string? Returns,
        string? Value,
        bool Deprecated,
        List<string>? Examples,
        List<XmlParameterComment>? Parameters,
        List<XmlResponseComment>? Responses);
 
    {{GeneratedCodeAttribute}}
    file record XmlParameterComment(string? Name, string? Description, string? Example, bool Deprecated);
 
    {{GeneratedCodeAttribute}}
    file record XmlResponseComment(string Code, string? Description, string? Example);
 
    {{GeneratedCodeAttribute}}
    file sealed record MemberKey(
        Type? DeclaringType,
        MemberType MemberKind,
        string? Name,
        Type? ReturnType,
        Type[]? Parameters) : IEquatable<MemberKey>
    {
        public bool Equals(MemberKey? other)
        {
            if (other is null) return false;
 
            // Check member kind
            if (MemberKind != other.MemberKind) return false;
 
            // Check declaring type, handling generic types
            if (!TypesEqual(DeclaringType, other.DeclaringType)) return false;
 
            // Check name
            if (Name != other.Name) return false;
 
            // For methods, check return type and parameters
            if (MemberKind == MemberType.Method)
            {
                if (!TypesEqual(ReturnType, other.ReturnType)) return false;
                if (Parameters is null || other.Parameters is null) return false;
                if (Parameters.Length != other.Parameters.Length) return false;
 
                for (int i = 0; i < Parameters.Length; i++)
                {
                    if (!TypesEqual(Parameters[i], other.Parameters[i])) return false;
                }
            }
 
            return true;
        }
 
        private static bool TypesEqual(Type? type1, Type? type2)
        {
            if (type1 == type2) return true;
            if (type1 == null || type2 == null) return false;
 
            if (type1.IsGenericType && type2.IsGenericType)
            {
                return type1.GetGenericTypeDefinition() == type2.GetGenericTypeDefinition();
            }
 
            return type1 == type2;
        }
 
        public override int GetHashCode()
        {
            var hash = new HashCode();
            hash.Add(GetTypeHashCode(DeclaringType));
            hash.Add(MemberKind);
            hash.Add(Name);
 
            if (MemberKind == MemberType.Method)
            {
                hash.Add(GetTypeHashCode(ReturnType));
                if (Parameters is not null)
                {
                    foreach (var param in Parameters)
                    {
                        hash.Add(GetTypeHashCode(param));
                    }
                }
            }
 
            return hash.ToHashCode();
        }
 
        private static int GetTypeHashCode(Type? type)
        {
            if (type == null) return 0;
            return type.IsGenericType ? type.GetGenericTypeDefinition().GetHashCode() : type.GetHashCode();
        }
 
        public static MemberKey FromMethodInfo(MethodInfo method)
        {
            return new MemberKey(
                method.DeclaringType,
                MemberType.Method,
                method.Name,
                method.ReturnType.IsGenericParameter ? typeof(object) : method.ReturnType,
                method.GetParameters().Select(p => p.ParameterType.IsGenericParameter ? typeof(object) : p.ParameterType).ToArray());
        }
 
        public static MemberKey FromPropertyInfo(PropertyInfo property)
        {
            return new MemberKey(
                property.DeclaringType,
                MemberType.Property,
                property.Name,
                null,
                null);
        }
 
        public static MemberKey FromTypeInfo(Type type)
        {
            return new MemberKey(
                type,
                MemberType.Type,
                null,
                null,
                null);
        }
    }
 
    file enum MemberType
    {
        Type,
        Property,
        Method
    }
 
    {{GeneratedCodeAttribute}}
    file static class XmlCommentCache
    {
        private static Dictionary<MemberKey, XmlComment>? _cache;
        public static Dictionary<MemberKey, XmlComment> Cache => _cache ??= GenerateCacheEntries();
 
        private static Dictionary<MemberKey, XmlComment> GenerateCacheEntries()
        {
            var _cache = new Dictionary<MemberKey, XmlComment>();
{{commentsFromXmlFile}}
{{commentsFromCompilation}}
            return _cache;
        }
 
        internal static bool TryGetXmlComment(Type? type, MethodInfo? methodInfo, [NotNullWhen(true)] out XmlComment? xmlComment)
        {
            if (methodInfo is null)
            {
                return Cache.TryGetValue(new MemberKey(type, MemberType.Property, null, null, null), out xmlComment);
            }
 
            return Cache.TryGetValue(MemberKey.FromMethodInfo(methodInfo), out xmlComment);
        }
 
        internal static bool TryGetXmlComment(Type? type, string? memberName, [NotNullWhen(true)] out XmlComment? xmlComment)
        {
            return Cache.TryGetValue(new MemberKey(type, memberName is null ? MemberType.Type : MemberType.Property, memberName, null, null), out xmlComment);
        }
    }
 
    {{GeneratedCodeAttribute}}
    file class XmlCommentOperationTransformer : IOpenApiOperationTransformer
    {
        public Task TransformAsync(OpenApiOperation operation, OpenApiOperationTransformerContext context, CancellationToken cancellationToken)
        {
            var methodInfo = context.Description.ActionDescriptor is ControllerActionDescriptor controllerActionDescriptor
                ? controllerActionDescriptor.MethodInfo
                : context.Description.ActionDescriptor.EndpointMetadata.OfType<MethodInfo>().SingleOrDefault();
 
            if (methodInfo is null)
            {
                return Task.CompletedTask;
            }
            if (XmlCommentCache.TryGetXmlComment(methodInfo.DeclaringType, methodInfo, out var methodComment))
            {
                if (methodComment.Summary is { } summary)
                {
                    operation.Summary = summary;
                }
                if (methodComment.Description is { } description)
                {
                    operation.Description = description;
                }
                if (methodComment.Remarks is { } remarks)
                {
                    operation.Description = remarks;
                }
                if (methodComment.Parameters is { Count: > 0})
                {
                    foreach (var parameterComment in methodComment.Parameters)
                    {
                        var parameterInfo = methodInfo.GetParameters().SingleOrDefault(info => info.Name == parameterComment.Name);
                        var operationParameter = operation.Parameters?.SingleOrDefault(parameter => parameter.Name == parameterComment.Name);
                        if (operationParameter is not null)
                        {
                            var targetOperationParameter = operationParameter is OpenApiParameterReference reference
                                ? reference.Target
                                : (OpenApiParameter)operationParameter;
                            targetOperationParameter.Description = parameterComment.Description;
                            if (parameterComment.Example is { } jsonString)
                            {
                                targetOperationParameter.Example = jsonString.Parse();
                            }
                            targetOperationParameter.Deprecated = parameterComment.Deprecated;
                        }
                        else
                        {
                            var requestBody = operation.RequestBody;
                            if (requestBody is not null)
                            {
                                requestBody.Description = parameterComment.Description;
                                if (parameterComment.Example is { } jsonString)
                                {
                                    foreach (var mediaType in requestBody.Content.Values)
                                    {
                                        mediaType.Example = jsonString.Parse();
                                    }
                                }
                            }
                        }
                    }
                }
                if (methodComment.Responses is { Count: > 0} && operation.Responses is { Count: > 0 })
                {
                    foreach (var response in operation.Responses)
                    {
                        var responseComment = methodComment.Responses.SingleOrDefault(xmlResponse => xmlResponse.Code == response.Key);
                        if (responseComment is not null)
                        {
                            response.Value.Description = responseComment.Description;
                        }
                    }
                }
            }
 
            return Task.CompletedTask;
        }
    }
 
    {{GeneratedCodeAttribute}}
    file class XmlCommentSchemaTransformer : IOpenApiSchemaTransformer
    {
        public Task TransformAsync(OpenApiSchema schema, OpenApiSchemaTransformerContext context, CancellationToken cancellationToken)
        {
            if (context.JsonPropertyInfo is { AttributeProvider: PropertyInfo propertyInfo })
            {
                if (XmlCommentCache.TryGetXmlComment(propertyInfo.DeclaringType, propertyInfo.Name, out var propertyComment))
                {
                    schema.Description = propertyComment.Value ?? propertyComment.Returns ?? propertyComment.Summary;
                    if (propertyComment.Examples?.FirstOrDefault() is { } jsonString)
                    {
                        schema.Example = jsonString.Parse();
                    }
                }
            }
            if (XmlCommentCache.TryGetXmlComment(context.JsonTypeInfo.Type, (string?)null, out var typeComment))
            {
                schema.Description = typeComment.Summary;
                if (typeComment.Examples?.FirstOrDefault() is { } jsonString)
                {
                    schema.Example = jsonString.Parse();
                }
            }
            return Task.CompletedTask;
        }
    }
 
    file static class JsonNodeExtensions
    {
        public static JsonNode? Parse(this string? json)
        {
            if (json is null)
            {
                return null;
            }
 
            try
            {
                return JsonNode.Parse(json);
            }
            catch (JsonException)
            {
                try
                {
                    // If parsing fails, try wrapping in quotes to make it a valid JSON string
                    return JsonNode.Parse($"\"{json.Replace("\"", "\\\"")}\"");
                }
                catch (JsonException)
                {
                    return null;
                }
            }
        }
    }
 
    {{GeneratedCodeAttribute}}
    file static class GeneratedServiceCollectionExtensions
    {
{{GenerateAddOpenApiInterceptions(groupedAddOpenApiInvocations)}}
    }
}
""";
 
    /// <summary>
    /// Returns the source text of the generated <c>AddOpenApi</c> interceptor method for the given overload.
    /// </summary>
    /// <remarks>
    /// Every emitted method registers <c>XmlCommentSchemaTransformer</c> and
    /// <c>XmlCommentOperationTransformer</c> and then forwards to the framework's
    /// <c>OpenApiServiceCollectionExtensions.AddOpenApi(services, documentName, configureOptions)</c>.
    /// The forwarding call is written as an explicit static call in all variants: an extension-style
    /// <c>services.AddOpenApi(name, options => ...)</c> would bind to the generated four-argument
    /// interceptor whenever that one is emitted into the same class (types declared directly in the
    /// enclosing namespace are searched first), which would register the transformers twice.
    /// </remarks>
    /// <param name="overloadVariant">The <c>AddOpenApi</c> overload that the user code invoked.</param>
    /// <returns>
    /// The method source for the matching variant, or <see cref="string.Empty"/> when the variant is not
    /// one of the four handled overloads.
    /// </returns>
    internal static string GetAddOpenApiInterceptor(AddOpenApiOverloadVariant overloadVariant) => overloadVariant switch
    {
        AddOpenApiOverloadVariant.AddOpenApi => """
        public static IServiceCollection AddOpenApi(this IServiceCollection services)
                {
                    return OpenApiServiceCollectionExtensions.AddOpenApi(services, "v1", options =>
                    {
                        options.AddSchemaTransformer(new XmlCommentSchemaTransformer());
                        options.AddOperationTransformer(new XmlCommentOperationTransformer());
                    });
                }
        """,
        AddOpenApiOverloadVariant.AddOpenApiDocumentName => """
        public static IServiceCollection AddOpenApi(this IServiceCollection services, string documentName)
                {
                    return OpenApiServiceCollectionExtensions.AddOpenApi(services, documentName, options =>
                    {
                        options.AddSchemaTransformer(new XmlCommentSchemaTransformer());
                        options.AddOperationTransformer(new XmlCommentOperationTransformer());
                    });
                }
        """,
        AddOpenApiOverloadVariant.AddOpenApiConfigureOptions => """
        public static IServiceCollection AddOpenApi(this IServiceCollection services, Action<OpenApiOptions> configureOptions)
                {
                    return OpenApiServiceCollectionExtensions.AddOpenApi(services, "v1", options =>
                    {
                        options.AddSchemaTransformer(new XmlCommentSchemaTransformer());
                        options.AddOperationTransformer(new XmlCommentOperationTransformer());
                        configureOptions(options);
                    });
                }
        """,
        AddOpenApiOverloadVariant.AddOpenApiDocumentNameConfigureOptions => """
        public static IServiceCollection AddOpenApi(this IServiceCollection services, string documentName, Action<OpenApiOptions> configureOptions)
                {
                    // This overload is not intercepted.
                    return OpenApiServiceCollectionExtensions.AddOpenApi(services, documentName, options =>
                    {
                        options.AddSchemaTransformer(new XmlCommentSchemaTransformer());
                        options.AddOperationTransformer(new XmlCommentOperationTransformer());
                        configureOptions(options);
                    });
                }
        """,
        _ => string.Empty // Effectively no-op for AddOpenApi invocations that do not conform to a variant
    };
 
    /// <summary>
    /// Emits, for each grouped <c>AddOpenApi</c> invocation, one intercepts-location attribute line per
    /// non-null location followed by the interceptor method for that group's overload variant.
    /// </summary>
    /// <param name="groupedAddOpenApiInvocations">
    /// The invocation groups; <c>Source.Variant</c> selects the interceptor body and <c>Elements</c>
    /// supplies the locations to attribute. The <c>Index</c> element is not used here.
    /// </param>
    /// <returns>The accumulated source text, written through a <c>CodeWriter</c> with a base indent of 2.</returns>
    internal static string GenerateAddOpenApiInterceptions(ImmutableArray<(AddOpenApiInvocation Source, int Index, ImmutableArray<InterceptableLocation?> Elements)> groupedAddOpenApiInvocations)
    {
        var output = new StringWriter();
        var code = new CodeWriter(output, baseIndent: 2);

        foreach (var group in groupedAddOpenApiInvocations)
        {
            foreach (var interceptable in group.Elements)
            {
                // Locations that could not be resolved are represented as null and get no attribute.
                if (interceptable is null)
                {
                    continue;
                }

                code.WriteLine(interceptable.GetInterceptsLocationAttributeSyntax());
            }

            code.WriteLine(GetAddOpenApiInterceptor(group.Source.Variant));
        }

        return output.ToString();
    }
 
    /// <summary>
    /// Emits one <c>_cache.Add(new MemberKey(...), new XmlComment(...));</c> statement per entry that has a
    /// non-null comment.
    /// </summary>
    /// <param name="comments">Pairs of member key and comment; entries whose comment is <see langword="null"/> are skipped.</param>
    /// <param name="cancellationToken">Checked once per entry so a cancelled generator run stops emitting promptly.</param>
    /// <returns>The accumulated statements, written through a <c>CodeWriter</c> with a base indent of 3.</returns>
    /// <exception cref="System.OperationCanceledException">Thrown when <paramref name="cancellationToken"/> is cancelled while entries are being emitted.</exception>
    internal static string EmitCommentsCache(IEnumerable<(MemberKey MemberKey, XmlComment? Comment)> comments, CancellationToken cancellationToken)
    {
        var writer = new StringWriter();
        var codeWriter = new CodeWriter(writer, baseIndent: 3);
        foreach (var (memberKey, comment) in comments)
        {
            // The token was previously accepted but never observed; honor it on every entry.
            cancellationToken.ThrowIfCancellationRequested();

            if (comment is null)
            {
                continue;
            }

            // A key without parameters is emitted with an empty collection expression: "[]".
            var parameterList = memberKey.Parameters != null
                ? string.Join(", ", memberKey.Parameters.Select(p => SymbolDisplay.FormatLiteral(p, false)))
                : "";

            codeWriter.WriteLine($"_cache.Add(new MemberKey(" +
                $"{FormatLiteralOrNull(memberKey.DeclaringType)}, " +
                $"MemberType.{memberKey.MemberKind}, " +
                $"{FormatLiteralOrNull(memberKey.Name, true)}, " +
                $"{FormatLiteralOrNull(memberKey.ReturnType)}, " +
                $"[{parameterList}]), " +
                $"{EmitSourceGeneratedXmlComment(comment)});");
        }
        return writer.ToString();

        // Emits the C# keyword null for a missing value, otherwise the escaped (and optionally quoted) literal.
        static string FormatLiteralOrNull(string? input, bool quote = false)
        {
            return input == null ? "null" : SymbolDisplay.FormatLiteral(input, quote);
        }
    }
 
    /// <summary>
    /// Renders <paramref name="input"/> as C# source: the keyword <c>null</c> for a null value, otherwise a
    /// verbatim string literal (<c>@"..."</c>) with each embedded double quote doubled.
    /// </summary>
    /// <param name="input">The text to render, or <see langword="null"/>.</param>
    /// <returns>The source-code representation of <paramref name="input"/>.</returns>
    private static string FormatStringForCode(string? input)
        => input is null
            ? "null"
            : "@\"" + input.Replace("\"", "\"\"") + "\"";
 
    /// <summary>
    /// Renders <paramref name="comment"/> as the source text of a <c>new XmlComment(...)</c> expression.
    /// </summary>
    /// <remarks>
    /// Arguments are written positionally: summary, description, remarks, returns, value, deprecated flag,
    /// examples, parameters, responses. Each of the three lists is emitted as <c>null</c> when it is null or
    /// empty, otherwise as a collection expression. All text values go through
    /// <see cref="FormatStringForCode"/> so that embedded double quotes are escaped.
    /// </remarks>
    /// <param name="comment">The parsed XML comment to serialize.</param>
    /// <returns>The C# expression text.</returns>
    internal static string EmitSourceGeneratedXmlComment(XmlComment comment)
    {
        var writer = new StringWriter();
        var codeWriter = new CodeWriter(writer, baseIndent: 0);
        codeWriter.Write("new XmlComment(");
        codeWriter.Write(FormatStringForCode(comment.Summary) + ", ");
        codeWriter.Write(FormatStringForCode(comment.Description) + ", ");
        codeWriter.Write(FormatStringForCode(comment.Remarks) + ", ");
        codeWriter.Write(FormatStringForCode(comment.Returns) + ", ");
        codeWriter.Write(FormatStringForCode(comment.Value) + ", ");
        // The conditional is parenthesized on purpose: `+` binds tighter than `?:`, so without the
        // parentheses the ", " separator was appended only to the "false" branch and a deprecated
        // member produced uncompilable output such as `truenull, `.
        codeWriter.Write((comment.Deprecated == true ? "true" : "false") + ", ");

        if (comment.Examples is null || comment.Examples.Count == 0)
        {
            codeWriter.Write("null, ");
        }
        else
        {
            codeWriter.Write("[");
            for (int i = 0; i < comment.Examples.Count; i++)
            {
                if (i > 0)
                {
                    codeWriter.Write(", ");
                }
                codeWriter.Write(FormatStringForCode(comment.Examples[i]));
            }
            codeWriter.Write("], ");
        }

        if (comment.Parameters is null || comment.Parameters.Count == 0)
        {
            codeWriter.Write("null, ");
        }
        else
        {
            codeWriter.Write("[");
            for (int i = 0; i < comment.Parameters.Count; i++)
            {
                if (i > 0)
                {
                    codeWriter.Write(", ");
                }

                var parameter = comment.Parameters[i];
                // A missing or empty example is emitted as null rather than as an empty literal.
                var exampleLiteral = string.IsNullOrEmpty(parameter.Example)
                    ? "null"
                    : FormatStringForCode(parameter.Example!);
                // Name and description are escaped via FormatStringForCode; a null value is coalesced
                // to the empty string so it still emits @"" exactly as the raw interpolation did.
                var nameLiteral = FormatStringForCode(parameter.Name ?? string.Empty);
                var descriptionLiteral = FormatStringForCode(parameter.Description ?? string.Empty);
                var deprecatedLiteral = parameter.Deprecated == true ? "true" : "false";
                codeWriter.Write($"new XmlParameterComment({nameLiteral}, {descriptionLiteral}, {exampleLiteral}, {deprecatedLiteral})");
            }
            codeWriter.Write("], ");
        }

        if (comment.Responses is null || comment.Responses.Count == 0)
        {
            codeWriter.Write("null");
        }
        else
        {
            codeWriter.Write("[");
            for (int i = 0; i < comment.Responses.Count; i++)
            {
                if (i > 0)
                {
                    codeWriter.Write(", ");
                }

                var response = comment.Responses[i];
                // Same escaping and null-to-empty handling as for parameters; FormatStringForCode
                // already yields "null" for a null example.
                var codeLiteral = FormatStringForCode(response.Code ?? string.Empty);
                var descriptionLiteral = FormatStringForCode(response.Description ?? string.Empty);
                var exampleLiteral = FormatStringForCode(response.Example);
                codeWriter.Write($"new XmlResponseComment({codeLiteral}, {descriptionLiteral}, {exampleLiteral})");
            }
            codeWriter.Write("]");
        }

        codeWriter.Write(")");
        return writer.ToString();
    }
 
    /// <summary>
    /// Generates the XML comment support source and adds it to the compilation under the hint name
    /// <c>OpenApiXmlCommentSupport.generated.cs</c>.
    /// </summary>
    /// <param name="context">The source production context that receives the generated file.</param>
    /// <param name="commentsFromXmlFile">Forwarded unchanged to <see cref="GenerateXmlCommentSupportSource"/>.</param>
    /// <param name="commentsFromCompilation">Forwarded unchanged to <see cref="GenerateXmlCommentSupportSource"/>.</param>
    /// <param name="groupedAddOpenApiInvocations">Forwarded unchanged to <see cref="GenerateXmlCommentSupportSource"/>.</param>
    internal static void Emit(SourceProductionContext context,
        string commentsFromXmlFile,
        string commentsFromCompilation,
        ImmutableArray<(AddOpenApiInvocation Source, int Index, ImmutableArray<InterceptableLocation?> Elements)> groupedAddOpenApiInvocations)
    {
        var generatedSource = GenerateXmlCommentSupportSource(
            commentsFromXmlFile,
            commentsFromCompilation,
            groupedAddOpenApiInvocations);

        context.AddSource(hintName: "OpenApiXmlCommentSupport.generated.cs", source: generatedSource);
    }
}