// File: src\Analyzers\Core\CodeFixes\RemoveAsyncModifier\AbstractRemoveAsyncModifierCodeFixProvider.cs
// Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Host;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.RemoveAsyncModifier;
 
/// <summary>
/// Language-agnostic core of the "Remove async modifier" code fix. It removes the <c>async</c> modifier from a
/// function-like node (as identified by <see cref="IsAsyncSupportingFunctionSyntax"/>) whose return type is
/// <c>Task</c>, <c>ValueTask</c>, <c>Task&lt;T&gt;</c> or <c>ValueTask&lt;T&gt;</c>, then rewrites the body so it
/// still produces a task: <c>return;</c> becomes <c>return Task.CompletedTask;</c> (or <c>return new ValueTask();</c>)
/// and <c>return expr;</c> becomes <c>return Task.FromResult&lt;T&gt;(expr);</c> (or
/// <c>return new ValueTask&lt;T&gt;(expr);</c>).
/// </summary>
/// <typeparam name="TReturnStatementSyntax">The language's return statement node type.</typeparam>
/// <typeparam name="TExpressionSyntax">The language's expression node type.</typeparam>
internal abstract class AbstractRemoveAsyncModifierCodeFixProvider<TReturnStatementSyntax, TExpressionSyntax> : SyntaxEditorBasedCodeFixProvider
    where TReturnStatementSyntax : SyntaxNode
    where TExpressionSyntax : SyntaxNode
{
    /// <summary>
    /// Returns <see langword="true"/> if <paramref name="node"/> is a function-like node that may carry an
    /// <c>async</c> modifier (the kind of node this fix operates on).
    /// </summary>
    protected abstract bool IsAsyncSupportingFunctionSyntax(SyntaxNode node);

    /// <summary>Returns <paramref name="methodLikeNode"/> with its <c>async</c> modifier removed.</summary>
    protected abstract SyntaxNode RemoveAsyncModifier(SyntaxGenerator generator, SyntaxNode methodLikeNode);

    /// <summary>
    /// Converts the expression-bodied <paramref name="node"/> (whose body is <paramref name="expressionBody"/>) into a
    /// block-bodied node, or returns <see langword="null"/> when that conversion cannot be performed.
    /// </summary>
    protected abstract SyntaxNode? ConvertToBlockBody(SyntaxNode node, TExpressionSyntax expressionBody);

    /// <summary>
    /// Registers the fix when the first diagnostic in <paramref name="context"/> maps to an async-capable function
    /// whose return type is one of the supported task types.
    /// </summary>
    public override async Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        var document = context.Document;
        var cancellationToken = context.CancellationToken;

        var diagnostic = context.Diagnostics.First();
        var token = diagnostic.Location.FindToken(cancellationToken);
        var node = token.GetAncestor(IsAsyncSupportingFunctionSyntax);
        if (node == null)
        {
            return;
        }

        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var methodSymbol = GetMethodSymbol(node, semanticModel, cancellationToken);
        if (methodSymbol == null)
        {
            return;
        }

        // The semantic model already exposes the compilation, so reuse it rather than requesting it separately
        // (and only once we know there is something to fix).
        var knownTypes = new KnownTaskTypes(semanticModel.Compilation);
        if (ShouldOfferFix(methodSymbol.ReturnType, knownTypes))
        {
            context.RegisterCodeFix(
                CodeAction.Create(
                    CodeFixesResources.Remove_async_modifier,
                    GetDocumentUpdater(context),
                    nameof(CodeFixesResources.Remove_async_modifier)),
                context.Diagnostics);
        }
    }

    /// <summary>
    /// Applies the fix for every diagnostic in <paramref name="diagnostics"/> by registering node replacements on
    /// <paramref name="editor"/>.
    /// </summary>
    protected sealed override async Task FixAllAsync(
        Document document, ImmutableArray<Diagnostic> diagnostics,
        SyntaxEditor editor, CancellationToken cancellationToken)
    {
        var solutionServices = document.Project.Solution.Services;
        var generator = editor.Generator;
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var knownTypes = new KnownTaskTypes(semanticModel.Compilation);

        // For fix all we need to do nested locals or lambdas first, so order the diagnostics by location descending.
        foreach (var diagnostic in diagnostics.OrderByDescending(d => d.Location.SourceSpan.Start))
        {
            var token = diagnostic.Location.FindToken(cancellationToken);
            var node = token.GetAncestor(IsAsyncSupportingFunctionSyntax);
            if (node == null)
            {
                Debug.Fail("We should always be able to find the node from the diagnostic.");
                continue;
            }

            var methodSymbol = GetMethodSymbol(node, semanticModel, cancellationToken);
            if (methodSymbol == null)
            {
                Debug.Fail("We should always be able to find the method symbol for the diagnostic.");
                continue;
            }

            var returnType = methodSymbol.ReturnType;

            // We might need to perform control flow analysis as part of the fix, and it has to run on the original
            // node, so do it up front. Nothing in the fixer changes the reachability of the end of the method, so
            // this is safe.
            var controlFlow = GetControlFlowAnalysis(generator, semanticModel, node);

            // Add a trailing return unless we positively know the end point is unreachable. If there were no
            // statements to analyze (e.g. an empty block), or the analysis did not succeed, we conservatively add
            // one: an unneeded return is at worst unreachable code, while a missing one leaves a path that does not
            // return a value.
            var needsReturnStatementAdded = controlFlow is null || !controlFlow.Succeeded || controlFlow.EndPointIsReachable;

            editor.ReplaceNode(node,
                (currentNode, currentGenerator) => RemoveAsyncModifier(
                    solutionServices, syntaxFacts, currentGenerator, currentNode, returnType, knownTypes, needsReturnStatementAdded));
        }
    }

    /// <summary>
    /// Gets the method symbol for <paramref name="node"/>: first via symbol info (which is what a lambda binds to),
    /// then via the declared symbol (methods and local functions).
    /// </summary>
    private static IMethodSymbol? GetMethodSymbol(SyntaxNode node, SemanticModel semanticModel, CancellationToken cancellationToken)
        => semanticModel.GetSymbolInfo(node, cancellationToken).Symbol as IMethodSymbol ??
           semanticModel.GetDeclaredSymbol(node, cancellationToken) as IMethodSymbol;

    /// <summary>The fix is offered for Task, ValueTask, Task&lt;T&gt; and ValueTask&lt;T&gt; return types.</summary>
    private static bool ShouldOfferFix(ITypeSymbol returnType, KnownTaskTypes knownTypes)
        => IsTaskType(returnType, knownTypes)
            || returnType.OriginalDefinition.Equals(knownTypes.TaskOfTType)
            || returnType.OriginalDefinition.Equals(knownTypes.ValueTaskOfTType);

    /// <summary>True for the non-generic Task and ValueTask types, i.e. task types that carry no result value.</summary>
    private static bool IsTaskType(ITypeSymbol returnType, KnownTaskTypes knownTypes)
        => returnType.OriginalDefinition.Equals(knownTypes.TaskType)
            || returnType.OriginalDefinition.Equals(knownTypes.ValueTaskType);

    /// <summary>
    /// Produces the rewritten form of <paramref name="node"/>: removes <c>async</c>, adds a trailing return where
    /// needed, and wraps returned values in a task of <paramref name="returnType"/>.
    /// </summary>
    private SyntaxNode RemoveAsyncModifier(
        SolutionServices solutionServices,
        ISyntaxFacts syntaxFacts,
        SyntaxGenerator generator,
        SyntaxNode node,
        ITypeSymbol returnType,
        KnownTaskTypes knownTypes,
        bool needsReturnStatementAdded)
    {
        node = RemoveAsyncModifier(generator, node);

        var expression = generator.GetExpression(node);
        if (expression is TExpressionSyntax expressionBody)
        {
            if (IsTaskType(returnType, knownTypes))
            {
                // We need to add a `return Task.CompletedTask;` so we have to convert to a block body.
                var blockBodiedNode = ConvertToBlockBody(node, expressionBody);

                // Expression bodied members can't have return statements, so if we can't convert to a block
                // body then we've done all we can.
                if (blockBodiedNode != null)
                {
                    node = AddReturnStatement(generator, blockBodiedNode);
                }
            }
            else
            {
                // For Task<T> returning expression bodied methods we can just wrap the whole expression.
                var newExpressionBody = WrapExpressionWithTaskFromResult(generator, expressionBody, returnType, knownTypes);
                node = generator.WithExpression(node, newExpressionBody);
            }
        }
        else if (IsTaskType(returnType, knownTypes) && needsReturnStatementAdded)
        {
            // If the end of the method is reachable, or there were no statements to analyze, then we need to add
            // an explicit return (converted below into `return Task.CompletedTask;` / `return new ValueTask();`).
            node = AddReturnStatement(generator, node);
        }

        return ChangeReturnStatements(solutionServices, syntaxFacts, generator, node, returnType, knownTypes);
    }

    /// <summary>
    /// Runs control flow analysis over all statements of <paramref name="node"/>, or returns
    /// <see langword="null"/> if it has no statements.
    /// </summary>
    private static ControlFlowAnalysis? GetControlFlowAnalysis(SyntaxGenerator generator, SemanticModel semanticModel, SyntaxNode node)
    {
        var statements = generator.GetStatements(node);
        if (statements.Count > 0)
        {
            return semanticModel.AnalyzeControlFlow(statements[0], statements[statements.Count - 1]);
        }

        return null;
    }

    /// <summary>Appends a bare <c>return;</c> to the statements of <paramref name="node"/>.</summary>
    private static SyntaxNode AddReturnStatement(SyntaxGenerator generator, SyntaxNode node)
        => generator.WithStatements(node, generator.GetStatements(node).Concat(generator.ReturnStatement()));

    /// <summary>
    /// Rewrites the return statements that belong directly to <paramref name="node"/> so that they return a task of
    /// <paramref name="returnType"/>.
    /// </summary>
    private SyntaxNode ChangeReturnStatements(
        SolutionServices solutionServices,
        ISyntaxFacts syntaxFacts,
        SyntaxGenerator generator,
        SyntaxNode node,
        ITypeSymbol returnType,
        KnownTaskTypes knownTypes)
    {
        var editor = new SyntaxEditor(node, solutionServices);

        // Look for all return statements, but if we find a new node that can have the async modifier we stop,
        // because that will have its own diagnostic and fix, if applicable.
        var returns = node.DescendantNodes(n => n == node || !IsAsyncSupportingFunctionSyntax(n)).OfType<TReturnStatementSyntax>();

        foreach (var returnSyntax in returns)
        {
            var returnExpression = syntaxFacts.GetExpressionOfReturnStatement(returnSyntax);
            if (returnExpression is null)
            {
                // Convert return; into return Task.CompletedTask;
                var returnTaskCompletedTask = GetReturnTaskCompletedTaskStatement(generator, returnType, knownTypes);
                editor.ReplaceNode(returnSyntax, returnTaskCompletedTask);
            }
            else
            {
                // Convert return <expr>; into return Task.FromResult(<expr>);
                var newExpression = WrapExpressionWithTaskFromResult(generator, returnExpression, returnType, knownTypes);
                editor.ReplaceNode(returnExpression, newExpression);
            }
        }

        return editor.GetChangedRoot();
    }

    /// <summary>
    /// Builds <c>return Task.CompletedTask;</c> when <paramref name="returnType"/> is Task, and otherwise
    /// <c>return new ValueTask();</c> (a bare <c>return;</c> is presumably only possible for Task/ValueTask).
    /// </summary>
    private static SyntaxNode GetReturnTaskCompletedTaskStatement(SyntaxGenerator generator, ITypeSymbol returnType, KnownTaskTypes knownTypes)
    {
        SyntaxNode invocation;
        if (returnType.OriginalDefinition.Equals(knownTypes.TaskType))
        {
            var taskTypeExpression = TypeExpressionForStaticMemberAccess(generator, knownTypes.TaskType);
            invocation = generator.MemberAccessExpression(taskTypeExpression, nameof(Task.CompletedTask));
        }
        else
        {
            invocation = generator.ObjectCreationExpression(knownTypes.ValueTaskType!);
        }

        return generator.ReturnStatement(invocation);
    }

    /// <summary>
    /// Wraps <paramref name="expression"/> as <c>Task.FromResult&lt;T&gt;(expression)</c> for Task&lt;T&gt;, or as
    /// <c>new ValueTask&lt;T&gt;(expression)</c> otherwise. In both cases the expression's leading/trailing trivia
    /// is moved onto the new outer node so comments and indentation stay outside the wrapper.
    /// </summary>
    private static SyntaxNode WrapExpressionWithTaskFromResult(SyntaxGenerator generator, SyntaxNode expression, ITypeSymbol returnType, KnownTaskTypes knownTypes)
    {
        if (returnType.OriginalDefinition.Equals(knownTypes.TaskOfTType))
        {
            var taskTypeExpression = TypeExpressionForStaticMemberAccess(generator, knownTypes.TaskType!);
            var unwrappedReturnType = returnType.GetTypeArguments()[0];
            var memberName = generator.GenericName(nameof(Task.FromResult), unwrappedReturnType);
            var taskFromResult = generator.MemberAccessExpression(taskTypeExpression, memberName);
            return generator.InvocationExpression(taskFromResult, expression.WithoutTrivia()).WithTriviaFrom(expression);
        }
        else
        {
            // Handle trivia the same way as the Task<T> branch above.
            return generator.ObjectCreationExpression(returnType, expression.WithoutTrivia()).WithTriviaFrom(expression);
        }
    }

    // Workaround for https://github.com/dotnet/roslyn/issues/43950
    // Copied from https://github.com/dotnet/roslyn-analyzers/blob/f24a5b42c85be6ee572f3a93bef223767fbefd75/src/Utilities/Workspaces/SyntaxGeneratorExtensions.cs#L68-L74
    /// <summary>
    /// Returns a type expression for <paramref name="typeSymbol"/> that is usable as the receiver of a static member
    /// access: any qualified-name nodes produced by <see cref="SyntaxGenerator.TypeExpression(ITypeSymbol)"/> are
    /// rebuilt as member-access expressions.
    /// </summary>
    private static SyntaxNode TypeExpressionForStaticMemberAccess(SyntaxGenerator generator, INamedTypeSymbol typeSymbol)
    {
        // Discover the language-specific RawKind values by generating throwaway nodes of each kind.
        var qualifiedNameSyntaxKind = generator.QualifiedName(generator.IdentifierName("ignored"), generator.IdentifierName("ignored")).RawKind;
        var memberAccessExpressionSyntaxKind = generator.MemberAccessExpression(generator.IdentifierName("ignored"), "ignored").RawKind;

        var typeExpression = generator.TypeExpression(typeSymbol);
        return QualifiedNameToMemberAccess(qualifiedNameSyntaxKind, memberAccessExpressionSyntaxKind, typeExpression, generator);

        // Local function: recursively converts the left side of each qualified name, keeping the right-most child.
        static SyntaxNode QualifiedNameToMemberAccess(int qualifiedNameSyntaxKind, int memberAccessExpressionSyntaxKind, SyntaxNode expression, SyntaxGenerator generator)
        {
            if (expression.RawKind == qualifiedNameSyntaxKind)
            {
                var left = QualifiedNameToMemberAccess(qualifiedNameSyntaxKind, memberAccessExpressionSyntaxKind, expression.ChildNodes().First(), generator);
                var right = expression.ChildNodes().Last();
                return generator.MemberAccessExpression(left, right);
            }

            return expression;
        }
    }
}