File: InlineMethod\AbstractInlineMethodRefactoringProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Collections;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.InlineMethod;
 
/// <summary>
/// Language-agnostic core of the "Inline method" refactoring. This part of the class applies when the
/// refactoring is invoked on a call to a private, ordinary (or extension), non-vararg method that has exactly
/// one declaration. It offers code actions that replace the invocation with the expression returned by
/// GetRawInlineExpression. "Inline" also removes the callee declaration, while "Inline and keep" leaves it in
/// place.
/// </summary>
/// <typeparam name="TMethodDeclarationSyntax">The language's method declaration node type.</typeparam>
/// <typeparam name="TStatementSyntax">The language's statement node type.</typeparam>
/// <typeparam name="TExpressionSyntax">The language's expression node type.</typeparam>
/// <typeparam name="TInvocationSyntax">The language's invocation expression node type.</typeparam>
internal abstract partial class AbstractInlineMethodRefactoringProvider<
        TMethodDeclarationSyntax,
        TStatementSyntax,
        TExpressionSyntax,
        TInvocationSyntax>(
    ISyntaxFacts syntaxFacts,
    ISemanticFactsService semanticFactsService)
    : CodeRefactoringProvider
    where TMethodDeclarationSyntax : SyntaxNode
    where TStatementSyntax : SyntaxNode
    where TExpressionSyntax : SyntaxNode
    where TInvocationSyntax : TExpressionSyntax
{
    /// <summary>
    /// Preferred name for the local that is generated when the inlined expression cannot be used
    /// as the expression of an ExpressionStatement.
    /// Example:
    /// void Caller()
    /// {
    ///     Callee();
    /// }
    /// int Callee()
    /// {
    ///     return 1;
    /// }
    /// After inlining it becomes:
    /// void Caller()
    /// {
    ///     int temp = 1;
    /// }
    /// int Callee()
    /// {
    ///     return 1;
    /// }
    /// '1' is not a valid expression in an ExpressionStatement, so a declaration has to be generated.
    /// </summary>
    private const string TemporaryName = "temp";

    private readonly ISyntaxFacts _syntaxFacts = syntaxFacts;
    private readonly ISemanticFactsService _semanticFactsService = semanticFactsService;

    /// <summary>
    /// Returns the expression from <paramref name="calleeMethodDeclarationSyntaxNode"/>'s body that should replace
    /// the invocation, or null when the callee cannot be inlined (the refactoring is then not offered).
    /// </summary>
    protected abstract TExpressionSyntax? GetRawInlineExpression(TMethodDeclarationSyntax calleeMethodDeclarationSyntaxNode);

    /// <summary>
    /// Generates type syntax for <paramref name="symbol"/>. This class passes <c>allowVar: false</c> when it
    /// generates the cast that wraps an inlined expression.
    /// </summary>
    protected abstract SyntaxNode GenerateTypeSyntax(ITypeSymbol symbol, bool allowVar);

    /// <summary>
    /// Generates a literal expression of <paramref name="typeSymbol"/> holding <paramref name="value"/>.
    /// It is not used in this part of the class.
    /// </summary>
    protected abstract TExpressionSyntax GenerateLiteralExpression(ITypeSymbol typeSymbol, object? value);

    /// <summary>
    /// Returns true if <paramref name="node"/> is a field declaration. It is used while searching for the
    /// symbol that contains the invocation.
    /// </summary>
    protected abstract bool IsFieldDeclarationSyntax(SyntaxNode node);

    /// <summary>
    /// Check if <paramref name="expressionNode"/> could be used as an Expression in ExpressionStatement
    /// </summary>
    protected abstract bool IsValidExpressionUnderExpressionStatement(TExpressionSyntax expressionNode);

    /// <summary>
    /// Check if <paramref name="syntaxNode"/> could be replaced by ThrowExpression.
    /// For VB it always returns false because ThrowExpression doesn't exist.
    /// </summary>
    protected abstract bool CanBeReplacedByThrowExpression(SyntaxNode syntaxNode);

    internal override CodeRefactoringKind Kind => CodeRefactoringKind.Inline;

    public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, _, cancellationToken) = context;
        var calleeInvocationNode = await context.TryGetRelevantNodeAsync<TInvocationSyntax>().ConfigureAwait(false);
        if (calleeInvocationNode == null)
            return;

        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        if (semanticModel.GetSymbolInfo(calleeInvocationNode, cancellationToken).GetAnySymbol() is not IMethodSymbol calleeMethodSymbol)
            return;

        // For a partial method, work against the implementation part (the declaration that carries the body).
        calleeMethodSymbol = calleeMethodSymbol.PartialImplementationPart ?? calleeMethodSymbol;
        if (!calleeMethodSymbol.IsOrdinaryMethod() && !calleeMethodSymbol.IsExtensionMethod)
            return;

        if (calleeMethodSymbol.IsVararg)
            return;

        // Only private methods are offered for inlining.
        if (calleeMethodSymbol.DeclaredAccessibility != Accessibility.Private)
            return;

        // The callee must have exactly one declaration, and it must be a method declaration node.
        var symbolDeclarationService = document.GetRequiredLanguageService<ISymbolDeclarationService>();
        if (symbolDeclarationService.GetDeclarations(calleeMethodSymbol) is not [var calleeMethodDeclarationSyntaxReference])
            return;

        if (await calleeMethodDeclarationSyntaxReference.GetSyntaxAsync(cancellationToken).ConfigureAwait(false) is not TMethodDeclarationSyntax calleeMethodNode)
            return;

        var rawInlineExpression = GetRawInlineExpression(calleeMethodNode);

        // Special case 1: AwaitExpression
        if (_syntaxFacts.IsAwaitExpression(rawInlineExpression))
        {
            // 1. If Caller & callee both have 'await' make sure there is no duplicate 'await'
            // Example:
            // Before:
            // async Task Caller() => await Callee();
            // async Task Callee() => await Task.CompletedTask;
            // After:
            // async Task Caller() => await Task.CompletedTask;
            // async Task Callee() => await Task.CompletedTask;
            // The original inline expression in callee will be 'await Task.CompletedTask'
            // The caller just needs 'Task.CompletedTask' without the 'await'
            //
            // 2. If Caller doesn't have await but callee has.
            // Example:
            // Before:
            // void Caller() { Callee().Wait();}
            // async Task Callee() => await DoAsync();
            // After:
            // void Caller() { DoAsync().Wait(); }
            // async Task Callee() => await DoAsync();
            // What caller is expecting is an expression that returns 'Task', which doesn't include the 'await'
            rawInlineExpression = _syntaxFacts.GetExpressionOfAwaitExpression(rawInlineExpression) as TExpressionSyntax;
        }

        if (rawInlineExpression == null)
            return;

        // Special case 2: ThrowStatement & ThrowExpression
        if (_syntaxFacts.IsThrowStatement(rawInlineExpression.Parent) || _syntaxFacts.IsThrowExpression(rawInlineExpression))
        {
            // If this is a throw statement, then it is only valid
            // 1. If it is invoked as ExpressionStatement
            // Example:
            // Before:
            // void Caller() { Callee(); }
            // void Callee() { throw new Exception();}
            // After:
            // void Caller() { throw new Exception(); }
            // void Callee() { throw new Exception();}
            // 2. If it is invoked in a place that allows a throw expression
            // Example:
            // Before:
            // void Caller(bool flag) { var x = flag ? Callee() : 1; }
            // int Callee() { throw new Exception();}
            // After:
            // void Caller(bool flag) { var x = flag ? throw new Exception() : 1; }
            // int Callee() { throw new Exception();}
            // Note here throw statement is changed to throw expression after inlining
            // If this is a throw expression, the check is the same
            // 1. If it is invoked as ExpressionStatement
            // Example:
            // Before:
            // void Caller() { Callee(); }
            // void Callee() => throw new Exception();
            // After:
            // void Caller() { throw new Exception(); }
            // void Callee() => throw new Exception();
            // Note here throw expression is converted to throw statement
            // 2. If it is invoked in a place that allows a throw expression
            // Example:
            // Before:
            // void Caller(bool flag) { var x = flag ? Callee() : 1; }
            // int Callee() => throw new Exception();
            // After:
            // void Caller(bool flag) { var x = flag ? throw new Exception() : 1; }
            // int Callee() => throw new Exception();
            if (!CanBeReplacedByThrowExpression(calleeInvocationNode)
                && !_syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent))
            {
                return;
            }
        }

        var callerSymbol = GetCallerSymbol(calleeInvocationNode, semanticModel, cancellationToken);
        if (callerSymbol == null)
            return;

        if (symbolDeclarationService.GetDeclarations(callerSymbol) is not [var callerReference])
            return;

        var callerDeclarationNode = await callerReference.GetSyntaxAsync(cancellationToken).ConfigureAwait(false);
        if (semanticModel.GetOperation(calleeInvocationNode, cancellationToken) is not IInvocationOperation invocationOperation)
            return;

        var syntaxGenerator = SyntaxGenerator.GetGenerator(document);

        // Computed once and shared by the parent action title and the nested action titles.
        var calleeMethodName = calleeMethodSymbol.ToNameDisplayString();

        context.RegisterRefactoring(CodeAction.Create(
                string.Format(FeaturesResources.Inline_0, calleeMethodName),
                GenerateCodeActions(),
                isInlinable: true),
            calleeInvocationNode.Span);

        // Builds the nested actions: "Inline" (removes the callee) and "Inline and keep".
        ImmutableArray<CodeAction> GenerateCodeActions()
        {
            using var result = TemporaryArray<CodeAction>.Empty;

            // For recursive calls (caller and callee are the same method), we can't offer the
            // "Inline_" option because we can't remove a method while also modifying it.
            if (!SymbolEqualityComparer.Default.Equals(callerSymbol, calleeMethodSymbol))
            {
                result.Add(CodeAction.Create(
                    string.Format(FeaturesResources.Inline_0, calleeMethodName),
                    cancellationToken => InlineMethodAsync(
                        removeCalleeDeclarationNode: true,
                        cancellationToken)));
            }

            result.Add(CodeAction.Create(
                string.Format(FeaturesResources.Inline_and_keep_0, calleeMethodName),
                cancellationToken => InlineMethodAsync(
                    removeCalleeDeclarationNode: false,
                    cancellationToken)));

            return result.ToImmutableAndClear();
        }

        // Produces the updated solution. It rewrites the caller declaration and, if requested, removes the
        // callee declaration from its document.
        async Task<Solution> InlineMethodAsync(
            bool removeCalleeDeclarationNode,
            CancellationToken cancellationToken)
        {
            // Find the statement that contains the invocation. This happens when Callee is invoked in a block.
            // Example:
            // void Caller()
            // {
            //     Action a = () =>
            //     {
            //         var x = Callee();
            //     }
            // } (Local declaration x is the containing node)
            // Note: Stop searching when it hits a lambda or local function, because in the case below the
            // declaration of a must not be treated as the containing node
            // void Caller()
            // {
            //     Action a = () => Callee();
            // }
            // It could be null if the callee is invoked inside an expression-bodied (arrow) member
            var statementContainsInvocation = calleeInvocationNode.GetAncestors()
                .TakeWhile(node => !_syntaxFacts.IsAnonymousFunctionExpression(node) && !_syntaxFacts.IsLocalFunctionStatement(node))
                .FirstOrDefault(node => node is TStatementSyntax) as TStatementSyntax;

            var methodParametersInfo = await GetMethodParametersInfoAsync(
                document,
                calleeInvocationNode,
                calleeMethodNode,
                statementContainsInvocation,
                rawInlineExpression,
                invocationOperation, cancellationToken).ConfigureAwait(false);

            var inlineContext = await GetInlineMethodContextAsync(
                document,
                calleeMethodNode,
                calleeInvocationNode,
                calleeMethodSymbol,
                rawInlineExpression,
                methodParametersInfo,
                cancellationToken).ConfigureAwait(false);

            var solution = document.Project.Solution;
            var solutionEditor = new SolutionEditor(solution);
            if (removeCalleeDeclarationNode)
            {
                var calleeDocumentId = solution.GetDocumentId(calleeMethodNode.SyntaxTree);
                if (calleeDocumentId != null)
                {
                    var calleeDocumentEditor = await solutionEditor.GetDocumentEditorAsync(calleeDocumentId, cancellationToken).ConfigureAwait(false);
                    calleeDocumentEditor.RemoveNode(calleeMethodNode);
                }
            }

            // Rewriting the caller is purely syntactic (the semantic model is already available), so it is
            // done synchronously.
            var newCallerMethodNode = GetChangedCaller(
                statementContainsInvocation, methodParametersInfo, inlineContext, cancellationToken);

            var callerDocumentEditor = await solutionEditor.GetDocumentEditorAsync(document.Id, cancellationToken).ConfigureAwait(false);
            callerDocumentEditor.ReplaceNode(callerDeclarationNode, newCallerMethodNode);

            return solutionEditor.GetChangedSolution();
        }

        // Returns a rewritten copy of the caller declaration. The copy has:
        // * the async modifier added when needed,
        // * any statements from the inline context inserted before the statement that contains the invocation,
        // * the invocation (or its containing statement) replaced with the inlined content.
        SyntaxNode GetChangedCaller(
            TStatementSyntax? statementContainsInvocation,
            MethodParametersInfo methodParametersInfo,
            InlineMethodContext inlineMethodContext,
            CancellationToken cancellationToken)
        {
            var callerNodeEditor = new SyntaxEditor(callerDeclarationNode, syntaxGenerator);

            if (inlineMethodContext.ContainsAwaitExpression)
            {
                // If the inline content has an 'await' expression, make sure the caller is changed to an 'async'
                // method when it returns void or its return type is awaitable. In all other cases, do nothing.
                if (callerSymbol is IMethodSymbol { MethodKind: MethodKind.Ordinary, IsAsync: false } callerMethodSymbol
                    && (callerMethodSymbol.ReturnsVoid
                        || callerMethodSymbol.IsAwaitableNonDynamic(semanticModel, callerDeclarationNode.SpanStart)))
                {
                    var declarationModifiers = DeclarationModifiers.From(callerSymbol).WithAsync(true);
                    callerNodeEditor.SetModifiers(callerDeclarationNode, declarationModifiers);
                }
            }

            if (statementContainsInvocation != null)
            {
                foreach (var statement in inlineMethodContext.StatementsToInsertBeforeInvocationOfCallee)
                {
                    // Add a CarriageReturn to make sure for VB the statement would be in different line.
                    callerNodeEditor.InsertBefore(statementContainsInvocation,
                        statement.WithAppendedTrailingTrivia(_syntaxFacts.ElasticCarriageReturnLineFeed));
                }
            }

            var (nodeToReplace, inlineNode) = GetInlineNode(
                semanticModel,
                statementContainsInvocation,
                methodParametersInfo,
                inlineMethodContext,
                cancellationToken);
            callerNodeEditor.ReplaceNode(nodeToReplace, (node, generator) => inlineNode);

            return callerNodeEditor.GetChangedRoot();
        }

        // Decides which node in the caller is replaced (the containing statement or the invocation itself)
        // and builds its replacement.
        (SyntaxNode nodeToReplace, SyntaxNode inlineNode) GetInlineNode(
            SemanticModel semanticModel,
            TStatementSyntax? statementContainsInvocation,
            MethodParametersInfo methodParametersInfo,
            InlineMethodContext inlineMethodContext,
            CancellationToken cancellationToken)
        {
            if (statementContainsInvocation != null)
            {
                if (methodParametersInfo.MergeInlineContentAndVariableDeclarationArgument)
                {
                    var rightHandSideValue = _syntaxFacts.GetRightHandSideOfAssignment(inlineMethodContext.InlineExpression);
                    var (parameterSymbol, name) = methodParametersInfo.ParametersWithVariableDeclarationArgument.Single();
                    var declarationNode = (TStatementSyntax)syntaxGenerator
                        .LocalDeclarationStatement(parameterSymbol.Type, name, rightHandSideValue);
                    return (statementContainsInvocation, declarationNode.WithTriviaFrom(statementContainsInvocation));
                }

                if (_syntaxFacts.IsThrowStatement(rawInlineExpression.Parent)
                    && _syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent))
                {
                    var throwStatement = (TStatementSyntax)syntaxGenerator
                        .ThrowStatement(inlineMethodContext.InlineExpression);
                    return (statementContainsInvocation, throwStatement.WithTriviaFrom(statementContainsInvocation));
                }

                if (_syntaxFacts.IsThrowExpression(rawInlineExpression)
                    && _syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent))
                {
                    // Example:
                    // Before:
                    // void Caller() { Callee(); }
                    // void Callee() => throw new Exception();
                    // After:
                    // void Caller() { throw new Exception(); }
                    // void Callee() => throw new Exception();
                    // Note: Throw expression is converted to throw statement
                    var throwStatement = (TStatementSyntax)syntaxGenerator
                        .ThrowStatement(_syntaxFacts.GetExpressionOfThrowExpression(inlineMethodContext.InlineExpression));
                    return (statementContainsInvocation, throwStatement.WithTriviaFrom(statementContainsInvocation));
                }

                if (_syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent)
                    && !calleeMethodSymbol.ReturnsVoid
                    && !IsValidExpressionUnderExpressionStatement(inlineMethodContext.InlineExpression))
                {
                    // If the callee is invoked as ExpressionStatement, but the inlined expression in the callee can't be
                    // placed under ExpressionStatement
                    // Example:
                    // void Caller()
                    // {
                    //     Callee();
                    // }
                    // int Callee()
                    // {
                    //     return 1;
                    // }
                    // After it should be:
                    // void Caller()
                    // {
                    //     int temp = 1;
                    // }
                    // int Callee()
                    // {
                    //     return 1;
                    // }
                    // One variable declaration needs to be generated.
                    var unusedLocalName =
                        _semanticFactsService.GenerateUniqueLocalName(
                            semanticModel,
                            calleeInvocationNode,
                            container: null,
                            TemporaryName,
                            cancellationToken);

                    var localDeclarationNode = (TStatementSyntax)syntaxGenerator
                        .LocalDeclarationStatement(calleeMethodSymbol.ReturnType, unusedLocalName.Text,
                            inlineMethodContext.InlineExpression);
                    return (statementContainsInvocation, localDeclarationNode.WithTriviaFrom(statementContainsInvocation));
                }
            }

            if (_syntaxFacts.IsThrowStatement(rawInlineExpression.Parent))
            {
                // Example:
                // Before:
                // void Caller() => Callee();
                // void Callee() { throw new Exception(); }
                // After:
                // void Caller() => throw new Exception();
                // void Callee() { throw new Exception(); }
                // Note: Throw statement is converted to throw expression
                if (CanBeReplacedByThrowExpression(calleeInvocationNode))
                {
                    var throwExpression = (TExpressionSyntax)syntaxGenerator
                        .ThrowExpression(inlineMethodContext.InlineExpression);
                    return (calleeInvocationNode, throwExpression.WithTriviaFrom(calleeInvocationNode));
                }
            }

            var inlineExpression = inlineMethodContext.InlineExpression;
            if (!_syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent)
                && !calleeMethodSymbol.ReturnsVoid
                && !_syntaxFacts.IsThrowExpression(inlineMethodContext.InlineExpression))
            {
                // Add type cast and parenthesis to the inline expression.
                // It is required to cover cases like:
                // Case 1 (parenthesis added):
                // Before:
                // void Caller() { var x = 3 * Callee(); }
                // int Callee() { return 1 + 2; }
                //
                // After
                // void Caller() { var x = 3 * (1 + 2); }
                // int Callee() { return 1 + 2; }
                //
                // Case 2 (type cast)
                // Before:
                // void Caller() { var x = Callee(); }
                // long Callee() { return 1; }
                //
                // After
                // void Caller() { var x = (long)1; }
                // long Callee() { return 1; }
                //
                // Case 3 (type cast & additional parenthesis)
                // Before:
                // void Caller() { var x = Callee()(); }
                // Func<int> Callee() { return () => 1; }
                // After:
                // void Caller() { var x = ((Func<int>)(() => 1))(); }
                // Func<int> Callee() { return () => 1; }
                //
                // Also, ensure that the node is formatted properly at the destination location. This is needed as the
                // location of the destination node might be very different (indentation/nesting wise) from the original
                // method where the inlined code is coming from.
                inlineExpression = (TExpressionSyntax)syntaxGenerator.AddParentheses(
                    syntaxGenerator.CastExpression(
                        GenerateTypeSyntax(calleeMethodSymbol.ReturnType, allowVar: false),
                        syntaxGenerator.AddParentheses(inlineMethodContext.InlineExpression.WithAdditionalAnnotations(Formatter.Annotation))));
            }

            return (calleeInvocationNode, inlineExpression.WithTriviaFrom(calleeInvocationNode));
        }
    }

    /// <summary>
    /// Walks outward from <paramref name="calleeMethodInvocationNode"/> and returns the symbol of the nearest
    /// enclosing construct that contains the invocation. That is one of:
    /// * a declared property, method or event,
    /// * a field whose initializer contains the invocation,
    /// * an enclosing anonymous function.
    /// Returns null when no such construct is found.
    /// </summary>
    private ISymbol? GetCallerSymbol(
        TInvocationSyntax calleeMethodInvocationNode,
        SemanticModel semanticModel,
        CancellationToken cancellationToken)
    {
        for (SyntaxNode? node = calleeMethodInvocationNode; node != null; node = node.Parent)
        {
            var declaredSymbol = semanticModel.GetDeclaredSymbol(node, cancellationToken);
            if (declaredSymbol?.Kind is SymbolKind.Property or SymbolKind.Method or SymbolKind.Event)
                return declaredSymbol;

            if (IsFieldDeclarationSyntax(node))
            {
                // A field declaration may declare several variables; pick the one whose initializer contains
                // the invocation.
                foreach (var declarator in node.DescendantNodes().Where(n => _syntaxFacts.IsVariableDeclarator(n)))
                {
                    var initializer = _syntaxFacts.GetInitializerOfVariableDeclarator(declarator);

                    // SyntaxNode.Contains walks up from the invocation instead of enumerating every descendant
                    // of the initializer.
                    if (initializer?.Contains(calleeMethodInvocationNode) is true &&
                        semanticModel.GetDeclaredSymbol(declarator, cancellationToken) is IFieldSymbol fieldSymbol)
                    {
                        return fieldSymbol;
                    }
                }

                // Fallback for the VB case. Only accept the result when the declaration yields exactly one field.
                // SingleOrDefault() would throw InvalidOperationException when GetAllDeclaredSymbols returns several
                // symbols (e.g. a declaration with multiple names); in that case decline instead of crashing.
                if (semanticModel.GetAllDeclaredSymbols(node, cancellationToken).OfType<IFieldSymbol>().ToImmutableArray() is [var fieldSymbolFallBack])
                    return fieldSymbolFallBack;
            }

            if (_syntaxFacts.IsAnonymousFunctionExpression(node))
                return semanticModel.GetSymbolInfo(node, cancellationToken).Symbol;
        }

        return null;
    }
}