File: CodeRefactorings\AddAwait\AbstractAddAwaitCodeRefactoringProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.CodeRefactorings.AddAwait;
 
/// <summary>
/// Offers to prefix an awaitable expression with <c>await</c>. For example, refactors
/// <code>
///     var x = GetAsync();
/// </code>
/// into
/// <code>
///     var x = await GetAsync();
/// </code>
/// or
/// <code>
///     var x = await GetAsync().ConfigureAwait(false);
/// </code>
/// </summary>
/// <typeparam name="TExpressionSyntax">The language-specific base type of expression syntax nodes.</typeparam>
internal abstract class AbstractAddAwaitCodeRefactoringProvider<TExpressionSyntax> : CodeRefactoringProvider
    where TExpressionSyntax : SyntaxNode
{
    /// <summary>Title (also used as the equivalence key) of the action that inserts a plain <c>await</c>.</summary>
    protected abstract string GetTitle();

    /// <summary>
    /// Title (also used as the equivalence key) of the action that inserts <c>await</c> and appends
    /// <c>.ConfigureAwait(false)</c>.
    /// </summary>
    protected abstract string GetTitleWithConfigureAwait();

    /// <summary>
    /// Language-specific check on the node found at the selection span; when it returns <see langword="false"/>,
    /// no actions are offered at all. Presumably reports whether <c>await</c> is legal at that position.
    /// </summary>
    protected abstract bool IsInAsyncContext(SyntaxNode node);

    /// <summary>
    /// For every relevant expression under the selection whose type is awaitable, registers two actions:
    /// one that adds <c>await</c>, and one that adds <c>await</c> together with <c>.ConfigureAwait(false)</c>.
    /// </summary>
    public sealed override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var document = context.Document;
        var cancellationToken = context.CancellationToken;

        var syntaxRoot = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        var selectedNode = syntaxRoot.FindNode(context.Span);
        if (!IsInAsyncContext(selectedNode))
            return;

        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var candidates = await context.GetRelevantNodesAsync<TExpressionSyntax>().ConfigureAwait(false);

        // Candidates are visited last-to-first. For each valid one, the plain `await` action is registered
        // before its `ConfigureAwait(false)` counterpart.
        var index = candidates.Length;
        while (index-- > 0)
        {
            var candidate = candidates[index];
            if (!IsValidAwaitableExpression(semanticModel, syntaxFacts, candidate, cancellationToken))
                continue;

            Register(GetTitle(), candidate, withConfigureAwait: false);
            Register(GetTitleWithConfigureAwait(), candidate, withConfigureAwait: true);
        }

        // Registers a single add-await action for 'target', using 'title' as both title and equivalence key.
        void Register(string title, TExpressionSyntax target, bool withConfigureAwait)
            => context.RegisterRefactoring(
                CodeAction.Create(
                    title,
                    ct => AddAwaitAsync(document, target, withConfigureAwait, ct),
                    title),
                target.Span);
    }

    /// <summary>
    /// Decides whether <paramref name="node"/> is an expression that can sensibly be prefixed with <c>await</c>.
    /// </summary>
    /// <returns>
    /// <see langword="true"/> only if all of the following hold:
    /// <list type="bullet">
    /// <item>the node is not already awaited;</item>
    /// <item>it is not the receiver of a well-typed call;</item>
    /// <item>it does not start with a conditional-access binding;</item>
    /// <item>it does not name a type;</item>
    /// <item>its type is awaitable (non-dynamic).</item>
    /// </list>
    /// </returns>
    private static bool IsValidAwaitableExpression(
        SemanticModel model, ISyntaxFactsService syntaxFacts, SyntaxNode node, CancellationToken cancellationToken)
    {
        // `MethodAsync()$$.ConfigureAwait()` already binds to a real call, so nothing is offered there.
        // `MethodAsync()$$.Invalid()` binds to an error type, so adding `await` may still help.
        if (syntaxFacts.IsExpressionOfInvocationExpression(node.Parent))
        {
            var enclosingInvocation = node.GetRequiredParent().GetRequiredParent();
            var invocationType = model.GetTypeInfo(enclosingInvocation, cancellationToken).Type;
            if (!invocationType.IsErrorType())
                return false;
        }

        // Already awaited: there is nothing to add.
        if (syntaxFacts.IsExpressionOfAwaitExpression(node))
            return false;

        if (StartsWithConditionalBinding(syntaxFacts, node))
            return false;

        // A node that refers to a type itself (literally `Task`, say) is not a value and cannot be awaited.
        // Only expressions whose *type* is awaitable qualify.
        if (model.GetSymbolInfo(node, cancellationToken).GetAnySymbol() is ITypeSymbol)
            return false;

        return model.GetTypeInfo(node, cancellationToken).Type is { } type &&
               type.IsAwaitableNonDynamic(model, node.SpanStart);
    }

    /// <summary>
    /// Follows the chain of leftmost children starting at <paramref name="node"/>. The walk stops as soon as the
    /// leftmost child is not a <typeparamref name="TExpressionSyntax"/>. Reports whether any node on that chain is
    /// a member-binding or element-binding expression.
    /// </summary>
    /// <remarks>
    /// `await` cannot be placed on the `.X` of `a?.X`, and would not be wanted there anyway: such an expression may
    /// produce null, and awaiting null throws. This may be worth revisiting if the language ever gains `await?`.
    /// </remarks>
    private static bool StartsWithConditionalBinding(ISyntaxFactsService syntaxFacts, SyntaxNode node)
    {
        SyntaxNode? current = node;
        while (current is not null)
        {
            if (syntaxFacts.IsMemberBindingExpression(current) || syntaxFacts.IsElementBindingExpression(current))
                return true;

            current = current.ChildNodesAndTokens().FirstOrDefault().AsNode() as TExpressionSyntax;
        }

        return false;
    }

    /// <summary>
    /// Produces a document in which <paramref name="expression"/> is replaced by <c>await expression</c>.
    /// When <paramref name="withConfigureAwait"/> is set, the replacement is <c>await expression.ConfigureAwait(false)</c>.
    /// </summary>
    private static Task<Document> AddAwaitAsync(
        Document document,
        TExpressionSyntax expression,
        bool withConfigureAwait,
        CancellationToken cancellationToken)
    {
        var generator = SyntaxGenerator.GetGenerator(document);

        // Strip trivia and parenthesize, so that the appended `.ConfigureAwait(false)` and the `await` apply to the
        // entire original expression.
        var operand = (TExpressionSyntax)generator.AddParentheses(expression.WithoutTrivia());
        if (withConfigureAwait)
        {
            var configureAwaitAccess = generator.MemberAccessExpression(operand, nameof(Task.ConfigureAwait));
            operand = (TExpressionSyntax)generator.InvocationExpression(
                configureAwaitAccess,
                generator.FalseLiteralExpression());
        }

        // Parenthesize the `await` so it composes with whatever surrounds it, then restore the original trivia.
        var replacement = generator
            .AddParentheses(generator.AwaitExpression(operand))
            .WithTriviaFrom(expression);

        return document.ReplaceNodeAsync(expression, replacement, cancellationToken);
    }
}