File: AddResponseTypeAttributeCodeFixAction.cs
Web Access
Project: src\src\Mvc\Mvc.Api.Analyzers\src\Microsoft.AspNetCore.Mvc.Api.Analyzers.csproj (Microsoft.AspNetCore.Mvc.Api.Analyzers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Simplification;
 
namespace Microsoft.AspNetCore.Mvc.Api.Analyzers;
 
/// <summary>
/// A <see cref="CodeAction"/> that adds one or more <c>ProducesResponseType</c> attributes on the action.
/// 1) It gets status codes from ProducesResponseType, ProducesDefaultResponseType, and conventions applied to the action to get the declared metadata.
/// 2) It inspects return statements to get actual metadata.
/// Diffing the two gets us a list of undocumented status codes.
/// We'll attempt to generate a [ProducesResponseType(typeof(SomeModel), 4xx)] if
///     a) the status code is 400 or greater.
///     b) the return statement included a return type.
///     c) the return type wasn't the error type (specified by ProducesErrorResponseType or implicit ProblemDetails)
/// In all other cases, we generate [ProducesResponseType(StatusCode)]
/// </summary>
internal sealed class AddResponseTypeAttributeCodeFixAction : CodeAction
{
    private readonly Document _document;
    private readonly Diagnostic _diagnostic;

    /// <summary>
    /// Creates a code action that documents the response types of the method containing <paramref name="diagnostic"/>.
    /// </summary>
    /// <param name="document">The document containing the diagnostic.</param>
    /// <param name="diagnostic">The diagnostic whose location identifies the action method to fix.</param>
    public AddResponseTypeAttributeCodeFixAction(Document document, Diagnostic diagnostic)
    {
        _document = document;
        _diagnostic = diagnostic;
    }

    /// <summary>
    /// Uses the diagnostic's location as the equivalence key, so actions created for different
    /// diagnostic locations are distinct.
    /// </summary>
    public override string EquivalenceKey => _diagnostic.Location.ToString();

    public override string Title => "Add ProducesResponseType attributes.";

    /// <summary>
    /// Adds the missing <c>ProducesResponseType</c> attributes (and a <c>ProducesDefaultResponseType</c> attribute
    /// when the method does not already declare one), removes a method-level <c>ApiConventionMethod</c> attribute,
    /// and adds a <c>using Microsoft.AspNetCore.Http;</c> directive when a <c>StatusCodes</c> constant was emitted.
    /// Returns the original document unchanged when there is nothing to fix.
    /// </summary>
    protected override async Task<Document> GetChangedDocumentAsync(CancellationToken cancellationToken)
    {
        const string httpNamespace = "Microsoft.AspNetCore.Http";

        var nullableContext = await CreateCodeActionContext(cancellationToken).ConfigureAwait(false);
        if (nullableContext == null)
        {
            return _document;
        }

        var context = nullableContext.Value;

        var declaredResponseMetadata = SymbolApiResponseMetadataProvider.GetDeclaredResponseMetadata(context.SymbolCache, context.Method);
        var errorResponseType = SymbolApiResponseMetadataProvider.GetErrorResponseType(context.SymbolCache, context.Method);

        var results = CalculateStatusCodesToApply(context, declaredResponseMetadata);
        if (results.Count == 0)
        {
            return _document;
        }

        var documentEditor = await DocumentEditor.CreateAsync(_document, cancellationToken).ConfigureAwait(false);

        var addUsingDirective = false;
        foreach (var (statusCode, returnType) in results.OrderBy(s => s.statusCode))
        {
            AttributeSyntax attributeSyntax;
            bool addUsing;

            if (statusCode >= 400 && returnType != null && !SymbolEqualityComparer.Default.Equals(returnType, errorResponseType))
            {
                // If a returnType was discovered and is different from the errorResponseType, use it in the result.
                attributeSyntax = CreateProducesResponseTypeAttribute(context, statusCode, returnType, out addUsing);
            }
            else
            {
                attributeSyntax = CreateProducesResponseTypeAttribute(context, statusCode, out addUsing);
            }

            documentEditor.AddAttribute(context.MethodSyntax, attributeSyntax);
            addUsingDirective |= addUsing;
        }

        if (!declaredResponseMetadata.Any(m => m.IsDefault && SymbolEqualityComparer.Default.Equals(m.AttributeSource, context.Method)))
        {
            // Add a ProducesDefaultResponseTypeAttribute if the method does not already have one.
            documentEditor.AddAttribute(context.MethodSyntax, CreateProducesDefaultResponseTypeAttribute());
        }

        var apiConventionMethodAttribute = context.Method.GetAttributes(context.SymbolCache.ApiConventionMethodAttribute).FirstOrDefault();

        // ApplicationSyntaxReference is null for attributes that have no source syntax; only source attributes can be removed.
        if (apiConventionMethodAttribute?.ApplicationSyntaxReference is { } conventionAttributeReference)
        {
            // Remove [ApiConventionMethodAttribute] declared on the method since it's no longer required
            var conventionAttributeSyntax = await conventionAttributeReference
                .GetSyntaxAsync(cancellationToken)
                .ConfigureAwait(false);

            documentEditor.RemoveNode(conventionAttributeSyntax);
        }

        var document = documentEditor.GetChangedDocument();

        var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        if (root is null)
        {
            return document;
        }

        if (addUsingDirective &&
            root is CompilationUnitSyntax compilationUnit &&
            !HasUsingDirective(compilationUnit, httpNamespace))
        {
            root = compilationUnit.AddUsings(SyntaxFactory.UsingDirective(SyntaxFactory.ParseName(httpNamespace)));
        }

        return document.WithSyntaxRoot(root);
    }

    /// <summary>
    /// Determines whether <paramref name="compilationUnit"/> has a top-level, non-alias <c>using</c> directive
    /// for <paramref name="namespaceName"/>.
    /// </summary>
    private static bool HasUsingDirective(CompilationUnitSyntax compilationUnit, string namespaceName)
    {
        foreach (var usingDirective in compilationUnit.Usings)
        {
            // An alias directive (using X = Some.Namespace;) does not bring the namespace's types into scope.
            // Newer Roslyn versions also report a null Name for C# 12 alias-any-type directives (using P = (int, int);).
            if (usingDirective.Alias != null)
            {
                continue;
            }

            if (string.Equals(usingDirective.Name?.ToString(), namespaceName, StringComparison.Ordinal))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Resolves the method declaration that contains the diagnostic and gathers the symbols needed by the fix.
    /// Returns <see langword="null"/> when the syntax tree, semantic model, method declaration, method symbol,
    /// or API controller symbol cache cannot be obtained.
    /// </summary>
    private async Task<CodeActionContext?> CreateCodeActionContext(CancellationToken cancellationToken)
    {
        var root = await _document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        var semanticModel = await _document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        if (root is null || semanticModel is null)
        {
            return null;
        }

        var diagnosticNode = root.FindNode(_diagnostic.Location.SourceSpan);
        var methodSyntax = diagnosticNode.FirstAncestorOrSelf<MethodDeclarationSyntax>();
        if (methodSyntax is null)
        {
            // The diagnostic is not inside a method declaration, so there is no action to annotate.
            return null;
        }

        var method = semanticModel.GetDeclaredSymbol(methodSyntax, cancellationToken);
        if (method is null)
        {
            return null;
        }

        if (!ApiControllerSymbolCache.TryCreate(semanticModel.Compilation, out var symbolCache))
        {
            return null;
        }

        var statusCodesType = semanticModel.Compilation.GetTypeByMetadataName(ApiSymbolNames.HttpStatusCodes);
        var statusCodeConstants = GetStatusCodeConstants(statusCodesType);

        return new CodeActionContext(semanticModel, symbolCache, method, methodSyntax, statusCodeConstants, cancellationToken);
    }

    /// <summary>
    /// Maps each <see cref="int"/> constant named <c>Status*</c> on <paramref name="statusCodesType"/> to its field name.
    /// Returns an empty dictionary when the type is not available in the compilation.
    /// If several constants share a value, the last one enumerated wins.
    /// </summary>
    private static Dictionary<int, string> GetStatusCodeConstants(INamedTypeSymbol? statusCodesType)
    {
        var statusCodeConstants = new Dictionary<int, string>();
        if (statusCodesType is null)
        {
            return statusCodeConstants;
        }

        foreach (var member in statusCodesType.GetMembers())
        {
            if (member is IFieldSymbol field &&
                field.Type.SpecialType == SpecialType.System_Int32 &&
                field.Name.StartsWith("Status", StringComparison.Ordinal) &&
                field.HasConstantValue &&
                field.ConstantValue is int statusCode)
            {
                statusCodeConstants[statusCode] = field.Name;
            }
        }

        return statusCodeConstants;
    }

    /// <summary>
    /// Computes the (status code, return type) pairs that the method returns but does not document with an
    /// attribute declared on the method itself. Default responses are treated as status code 200.
    /// Returns an empty collection when the method body cannot be analyzed.
    /// </summary>
    private static ICollection<(int statusCode, ITypeSymbol? typeSymbol)> CalculateStatusCodesToApply(in CodeActionContext context, IList<DeclaredApiResponseMetadata> declaredResponseMetadata)
    {
        if (context.SemanticModel.GetOperation(context.MethodSyntax, context.CancellationToken) is not IMethodBodyBaseOperation operation)
        {
            // No analyzable body (for example, a method without an implementation).
            return Array.Empty<(int, ITypeSymbol?)>();
        }

        if (!ActualApiResponseMetadataFactory.TryGetActualResponseMetadata(context.SymbolCache, operation, out var actualResponseMetadata))
        {
            // If we cannot parse metadata correctly, don't offer fixes.
            return Array.Empty<(int, ITypeSymbol?)>();
        }

        var statusCodes = new Dictionary<int, (int, ITypeSymbol?)>();
        foreach (var metadata in actualResponseMetadata)
        {
            if (DeclaredApiResponseMetadata.TryGetDeclaredMetadata(declaredResponseMetadata, metadata, result: out var declaredMetadata) &&
                SymbolEqualityComparer.Default.Equals(declaredMetadata.AttributeSource, context.Method))
            {
                // A ProducesResponseType attribute is declared on the method for the current status code.
                continue;
            }

            var statusCode = metadata.IsDefaultResponse ? 200 : metadata.StatusCode;
            if (!statusCodes.ContainsKey(statusCode))
            {
                // If a status code appears multiple times in the actual metadata, pick the first one to
                // appear in the codefix
                statusCodes.Add(statusCode, (statusCode, metadata.ReturnType));
            }
        }

        return statusCodes.Values;
    }

    /// <summary>
    /// Creates <c>[ProducesResponseType(statusCode)]</c>, using a <c>StatusCodes</c> constant when one exists.
    /// </summary>
    /// <param name="addUsingDirective">Set to <see langword="true"/> when a <c>StatusCodes</c> constant was emitted.</param>
    private static AttributeSyntax CreateProducesResponseTypeAttribute(in CodeActionContext context, int statusCode, out bool addUsingDirective)
    {
        // [ProducesResponseType(StatusCodes.Status404NotFound)]
        var statusCodeSyntax = CreateStatusCodeSyntax(context, statusCode, out addUsingDirective);

        return SyntaxFactory.Attribute(
            SyntaxFactory.ParseName(ApiSymbolNames.ProducesResponseTypeAttribute)
                .WithAdditionalAnnotations(Simplifier.Annotation),
            SyntaxFactory.AttributeArgumentList().AddArguments(
                SyntaxFactory.AttributeArgument(statusCodeSyntax)));
    }

    /// <summary>
    /// Creates <c>[ProducesResponseType(typeof(T), statusCode)]</c>, using a <c>StatusCodes</c> constant when one exists.
    /// The type name is emitted fully qualified and annotated for simplification.
    /// </summary>
    /// <param name="addUsingDirective">Set to <see langword="true"/> when a <c>StatusCodes</c> constant was emitted.</param>
    private static AttributeSyntax CreateProducesResponseTypeAttribute(in CodeActionContext context, int statusCode, ITypeSymbol typeSymbol, out bool addUsingDirective)
    {
        // [ProducesResponseType(typeof(ReturnType), StatusCodes.Status404NotFound)]
        var statusCodeSyntax = CreateStatusCodeSyntax(context, statusCode, out addUsingDirective);
        var responseTypeAttribute = SyntaxFactory.TypeOfExpression(
            SyntaxFactory.ParseTypeName(typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
                .WithAdditionalAnnotations(Simplifier.Annotation));

        return SyntaxFactory.Attribute(
            SyntaxFactory.ParseName(ApiSymbolNames.ProducesResponseTypeAttribute)
                .WithAdditionalAnnotations(Simplifier.Annotation),
            SyntaxFactory.AttributeArgumentList().AddArguments(
                SyntaxFactory.AttributeArgument(responseTypeAttribute),
                SyntaxFactory.AttributeArgument(statusCodeSyntax)));
    }

    /// <summary>
    /// Returns <c>StatusCodes.StatusNNN...</c> when a matching constant is known, otherwise a numeric literal.
    /// </summary>
    /// <param name="addUsingDirective">
    /// <see langword="true"/> when the constant form was produced, so the caller can import its namespace.
    /// </param>
    private static ExpressionSyntax CreateStatusCodeSyntax(in CodeActionContext context, int statusCode, out bool addUsingDirective)
    {
        if (context.StatusCodeConstants.TryGetValue(statusCode, out var constantName))
        {
            addUsingDirective = true;
            return SyntaxFactory.MemberAccessExpression(
                SyntaxKind.SimpleMemberAccessExpression,
                SyntaxFactory.ParseTypeName(ApiSymbolNames.HttpStatusCodes)
                    .WithAdditionalAnnotations(Simplifier.Annotation),
                SyntaxFactory.IdentifierName(constantName));
        }

        addUsingDirective = false;
        return SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(statusCode));
    }

    /// <summary>
    /// Creates an argument-less <c>[ProducesDefaultResponseType]</c> attribute annotated for simplification.
    /// </summary>
    private static AttributeSyntax CreateProducesDefaultResponseTypeAttribute()
    {
        return SyntaxFactory.Attribute(
            SyntaxFactory.ParseName(ApiSymbolNames.ProducesDefaultResponseTypeAttribute)
                .WithAdditionalAnnotations(Simplifier.Annotation));
    }

    /// <summary>
    /// Immutable bundle of the semantic state shared by the helpers of this code action.
    /// </summary>
    private readonly struct CodeActionContext
    {
        public CodeActionContext(SemanticModel semanticModel,
            ApiControllerSymbolCache symbolCache,
            IMethodSymbol method,
            MethodDeclarationSyntax methodSyntax,
            Dictionary<int, string> statusCodeConstants,
            CancellationToken cancellationToken)
        {
            SemanticModel = semanticModel;
            SymbolCache = symbolCache;
            Method = method;
            MethodSyntax = methodSyntax;
            StatusCodeConstants = statusCodeConstants;
            CancellationToken = cancellationToken;
        }

        /// <summary>The declaration of the action method being fixed.</summary>
        public MethodDeclarationSyntax MethodSyntax { get; }

        /// <summary>Status code value to <c>StatusCodes</c> field name.</summary>
        public Dictionary<int, string> StatusCodeConstants { get; }

        /// <summary>The symbol for <see cref="MethodSyntax"/>.</summary>
        public IMethodSymbol Method { get; }

        public SemanticModel SemanticModel { get; }

        public ApiControllerSymbolCache SymbolCache { get; }

        public CancellationToken CancellationToken { get; }
    }
}