File: src\Analyzers\CSharp\CodeFixes\MakeMethodSynchronous\CSharpMakeMethodSynchronousCodeFixProvider.cs
Web Access
Project: src\src\CodeStyle\CSharp\CodeFixes\Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes.csproj (Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.RemoveAsyncModifier;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.MakeMethodSynchronous;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.CSharp.MakeMethodSynchronous;
 
using static CSharpSyntaxTokens;
 
[ExportCodeFixProvider(LanguageNames.CSharp, Name = PredefinedCodeFixProviderNames.MakeMethodSynchronous), Shared]
[ExtensionOrder(After = PredefinedCodeFixProviderNames.AddImport)]
[method: ImportingConstructor]
[method: SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
internal sealed class CSharpMakeMethodSynchronousCodeFixProvider() : AbstractMakeMethodSynchronousCodeFixProvider
{
    private const string CS1998 = nameof(CS1998); // This async method lacks 'await' operators and will run synchronously.
 
    public override ImmutableArray<string> FixableDiagnosticIds { get; } = [CS1998];
 
    protected override bool IsAsyncSupportingFunctionSyntax(SyntaxNode node)
        => node.IsAsyncSupportingFunctionSyntax();
 
    protected override SyntaxNode RemoveAsyncTokenAndFixReturnType(IMethodSymbol methodSymbol, SyntaxNode node, KnownTaskTypes knownTypes)
    {
        switch (node)
        {
            case MethodDeclarationSyntax method: return FixMethod(methodSymbol, method, knownTypes);
            case LocalFunctionStatementSyntax localFunction: return FixLocalFunction(methodSymbol, localFunction, knownTypes);
            case AnonymousMethodExpressionSyntax method: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(method);
            case ParenthesizedLambdaExpressionSyntax lambda: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(lambda);
            case SimpleLambdaExpressionSyntax lambda: return RemoveAsyncModifierHelpers.WithoutAsyncModifier(lambda);
            default: return node;
        }
    }
    private static SyntaxNode FixMethod(IMethodSymbol methodSymbol, MethodDeclarationSyntax method, KnownTaskTypes knownTypes)
    {
        var newReturnType = FixMethodReturnType(methodSymbol, method.ReturnType, knownTypes);
        return RemoveAsyncModifierHelpers.WithoutAsyncModifier(method, newReturnType);
    }
 
    private static SyntaxNode FixLocalFunction(IMethodSymbol methodSymbol, LocalFunctionStatementSyntax localFunction, KnownTaskTypes knownTypes)
    {
        var newReturnType = FixMethodReturnType(methodSymbol, localFunction.ReturnType, knownTypes);
        return RemoveAsyncModifierHelpers.WithoutAsyncModifier(localFunction, newReturnType);
    }
 
    private static TypeSyntax FixMethodReturnType(IMethodSymbol methodSymbol, TypeSyntax returnTypeSyntax, KnownTaskTypes knownTypes)
    {
        var newReturnType = returnTypeSyntax;
 
        var returnType = methodSymbol.ReturnType;
        if (returnType.OriginalDefinition.Equals(knownTypes.TaskType))
        {
            // If the return type is Task, then make the new return type "void".
            newReturnType = SyntaxFactory.PredefinedType(VoidKeyword).WithTriviaFrom(returnTypeSyntax);
        }
        else if (returnType.OriginalDefinition.Equals(knownTypes.TaskOfTType))
        {
            // If the return type is Task<T>, then make the new return type "T".
            newReturnType = returnType.GetTypeArguments()[0].GenerateTypeSyntax().WithTriviaFrom(returnTypeSyntax);
        }
        else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTType) &&
            knownTypes.IEnumerableOfTType != null)
        {
            // If the return type is IAsyncEnumerable<T>, then make the new return type IEnumerable<T>.
            newReturnType = knownTypes.IEnumerableOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
        }
        else if (returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTType) &&
            knownTypes.IEnumeratorOfTType != null)
        {
            // If the return type is IAsyncEnumerator<T>, then make the new return type IEnumerator<T>.
            newReturnType = knownTypes.IEnumeratorOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
        }
 
        return newReturnType;
    }
}