File: InlineMethod\AbstractInlineMethodRefactoringProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.InlineMethod;
 
internal abstract partial class AbstractInlineMethodRefactoringProvider<
        TMethodDeclarationSyntax,
        TStatementSyntax,
        TExpressionSyntax,
        TInvocationSyntax>
    : CodeRefactoringProvider
    where TMethodDeclarationSyntax : SyntaxNode
    where TStatementSyntax : SyntaxNode
    where TExpressionSyntax : SyntaxNode
    where TInvocationSyntax : TExpressionSyntax
{
    /// <summary>
    /// The preferred name used to generate a local declaration when the
    /// inlined method body is not a valid expression in an ExpressionStatement.
    /// Example:
    /// void Caller()
    /// {
    ///     Callee();
    /// }
    /// int Callee()
    /// {
    ///     return 1;
    /// }
    /// After inlining it should be:
    /// void Caller()
    /// {
    ///     int temp = 1;
    /// }
    /// int Callee()
    /// {
    ///     return 1;
    /// }
    /// '1' is not a valid expression in an ExpressionStatement, so a declaration has to be generated.
    /// </summary>
    private const string TemporaryName = "temp";

    // Language-agnostic syntax queries (await/throw/expression-statement checks, trivia helpers).
    private readonly ISyntaxFacts _syntaxFacts;
    // Used to generate a unique local name based on TemporaryName.
    private readonly ISemanticFactsService _semanticFactsService;

    /// <summary>
    /// Gets the expression of the callee declaration that is the candidate to be inlined at the call site.
    /// A <see langword="null"/> result means no such expression is available, in which case no refactoring is offered.
    /// </summary>
    protected abstract TExpressionSyntax? GetRawInlineExpression(TMethodDeclarationSyntax calleeMethodDeclarationSyntaxNode);

    /// <summary>
    /// Generates the language-specific type syntax for <paramref name="symbol"/>. It is used here to build the
    /// cast placed around an inlined expression; that call site passes <paramref name="allowVar"/> as false.
    /// NOTE(review): presumably <paramref name="allowVar"/> controls whether an implicit type ('var') may be
    /// produced -- confirm against the language-specific implementations.
    /// </summary>
    protected abstract SyntaxNode GenerateTypeSyntax(ITypeSymbol symbol, bool allowVar);

    /// <summary>
    /// Generates a literal expression of type <paramref name="typeSymbol"/> for <paramref name="value"/>.
    /// Not called from this part of the partial class.
    /// </summary>
    protected abstract TExpressionSyntax GenerateLiteralExpression(ITypeSymbol typeSymbol, object? value);

    /// <summary>
    /// Check if <paramref name="node"/> is a field declaration. Used while searching for the caller so that an
    /// invocation located inside a field declaration resolves to the field symbol.
    /// </summary>
    protected abstract bool IsFieldDeclarationSyntax(SyntaxNode node);

    /// <summary>
    /// Check if <paramref name="expressionNode"/> could be used as an Expression in ExpressionStatement
    /// </summary>
    protected abstract bool IsValidExpressionUnderExpressionStatement(TExpressionSyntax expressionNode);

    /// <summary>
    /// Check if <paramref name="syntaxNode"/> could be replaced by ThrowExpression.
    /// For VB it always returns false because ThrowExpression doesn't exist.
    /// </summary>
    protected abstract bool CanBeReplacedByThrowExpression(SyntaxNode syntaxNode);
 
    /// <summary>
    /// Stores the language services that the language-agnostic inlining logic depends on.
    /// </summary>
    protected AbstractInlineMethodRefactoringProvider(
        ISyntaxFacts syntaxFacts,
        ISemanticFactsService semanticFactsService)
        => (_syntaxFacts, _semanticFactsService) = (syntaxFacts, semanticFactsService);
 
    /// <summary>
    /// Registers the nested "Inline ..." refactoring when the relevant node is an invocation of a private,
    /// non-vararg ordinary or extension method that has exactly one declaration and an expression that can be
    /// inlined at the call site. Returns without registering anything as soon as any precondition fails.
    /// </summary>
    public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, _, cancellationToken) = context;

        var invocation = await context.TryGetRelevantNodeAsync<TInvocationSyntax>().ConfigureAwait(false);
        if (invocation is null)
        {
            return;
        }

        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        if (semanticModel.GetSymbolInfo(invocation, cancellationToken).GetAnySymbol() is not IMethodSymbol calleeMethodSymbol)
        {
            return;
        }

        // For a partial method, work against the implementation part.
        calleeMethodSymbol = calleeMethodSymbol.PartialImplementationPart ?? calleeMethodSymbol;

        // Only private, non-vararg ordinary/extension methods are candidates.
        var isSupportedMethodKind = calleeMethodSymbol.IsOrdinaryMethod() || calleeMethodSymbol.IsExtensionMethod;
        if (!isSupportedMethodKind
            || calleeMethodSymbol.IsVararg
            || calleeMethodSymbol.DeclaredAccessibility != Accessibility.Private)
        {
            return;
        }

        // The callee must have exactly one declaration, and it must be a method declaration node.
        var symbolDeclarationService = document.GetRequiredLanguageService<ISymbolDeclarationService>();
        var calleeDeclarations = symbolDeclarationService.GetDeclarations(calleeMethodSymbol);
        if (calleeDeclarations is not [var calleeDeclaration])
        {
            return;
        }

        if (await calleeDeclaration.GetSyntaxAsync(cancellationToken).ConfigureAwait(false) is not TMethodDeclarationSyntax calleeMethodNode)
        {
            return;
        }

        var inlineExpression = GetRawInlineExpression(calleeMethodNode);

        // Special case 1: AwaitExpression.
        // The callee's own 'await' is stripped; only its operand is inlined.
        //
        //   a) Caller and callee both await -- avoid a duplicated 'await':
        //        async Task Caller() => await Callee();            becomes
        //        async Task Caller() => await Task.CompletedTask;
        //        (callee: async Task Callee() => await Task.CompletedTask;)
        //
        //   b) Only the callee awaits -- the caller expects an expression returning 'Task':
        //        void Caller() { Callee().Wait(); }                becomes
        //        void Caller() { DoAsync().Wait(); }
        //        (callee: async Task Callee() => await DoAsync();)
        if (_syntaxFacts.IsAwaitExpression(inlineExpression))
        {
            inlineExpression = _syntaxFacts.GetExpressionOfAwaitExpression(inlineExpression) as TExpressionSyntax;
        }

        if (inlineExpression is null)
        {
            return;
        }

        // Special case 2: ThrowStatement & ThrowExpression.
        // Whether the callee is 'void Callee() { throw new Exception(); }' (throw statement) or
        // 'void Callee() => throw new Exception();' (throw expression), inlining is only valid when:
        //
        //   1. The callee is invoked as an ExpressionStatement:
        //        void Caller() { Callee(); }                       becomes
        //        void Caller() { throw new Exception(); }
        //      (a throw expression in the callee is converted to a throw statement), or
        //
        //   2. The callee is invoked where a throw expression is allowed:
        //        void Caller(bool flag) { var x = flag ? Callee() : 1; }                  becomes
        //        void Caller(bool flag) { var x = flag ? throw new Exception() : 1; }
        //      (a throw statement in the callee is converted to a throw expression).
        var inlinesThrow = _syntaxFacts.IsThrowStatement(inlineExpression.Parent)
            || _syntaxFacts.IsThrowExpression(inlineExpression);
        if (inlinesThrow
            && !CanBeReplacedByThrowExpression(invocation)
            && !_syntaxFacts.IsExpressionStatement(invocation.Parent))
        {
            return;
        }

        // The caller must be resolvable to a symbol with exactly one declaration.
        var callerSymbol = GetCallerSymbol(invocation, semanticModel, cancellationToken);
        if (callerSymbol is null)
        {
            return;
        }

        var callerReferences = symbolDeclarationService.GetDeclarations(callerSymbol);
        if (callerReferences.Length != 1)
        {
            return;
        }

        var callerDeclarationNode = await callerReferences[0].GetSyntaxAsync(cancellationToken).ConfigureAwait(false);
        if (semanticModel.GetOperation(invocation, cancellationToken) is not IInvocationOperation invocationOperation)
        {
            return;
        }

        var codeActions = GenerateCodeActions(
            document,
            invocation,
            calleeMethodSymbol,
            calleeMethodNode,
            callerSymbol,
            callerDeclarationNode,
            inlineExpression,
            invocationOperation);

        var title = string.Format(FeaturesResources.Inline_0, calleeMethodSymbol.ToNameDisplayString());
        context.RegisterRefactoring(
            CodeAction.Create(title, codeActions, isInlinable: true),
            invocation.Span);
    }
 
    /// <summary>
    /// Builds the two code actions offered to the user: one that inlines the callee and removes its
    /// declaration, and one that inlines the callee but keeps its declaration.
    /// </summary>
    private ImmutableArray<CodeAction> GenerateCodeActions(Document document,
        TInvocationSyntax calleeMethodInvocationNode,
        IMethodSymbol calleeMethodSymbol,
        TMethodDeclarationSyntax calleeMethodNode,
        ISymbol callerSymbol,
        SyntaxNode callerMethodNode,
        TExpressionSyntax inlineExpression,
        IInvocationOperation invocationOperation)
    {
        var calleeMethodName = calleeMethodSymbol.ToNameDisplayString();

        return
        [
            CreateInlineAction(FeaturesResources.Inline_0, nameof(FeaturesResources.Inline_0), removeCallee: true),
            CreateInlineAction(FeaturesResources.Inline_and_keep_0, nameof(FeaturesResources.Inline_and_keep_0), removeCallee: false),
        ];

        // Both actions differ only in their title format, their equivalence-key prefix and whether the
        // callee declaration is removed.
        CodeAction CreateInlineAction(string titleFormat, string resourceName, bool removeCallee)
            => CodeAction.Create(
                string.Format(titleFormat, calleeMethodName),
                cancellationToken => InlineMethodAsync(
                    document,
                    calleeMethodInvocationNode,
                    calleeMethodSymbol,
                    calleeMethodNode,
                    callerSymbol,
                    callerMethodNode,
                    inlineExpression,
                    invocationOperation,
                    removeCallee,
                    cancellationToken),
                resourceName + "_" + calleeMethodName);
    }
 
    /// <summary>
    /// Produces the changed solution in which the invocation of the callee is replaced by the inlined content
    /// and, when <paramref name="removeCalleeDeclarationNode"/> is true, the callee declaration is removed.
    /// </summary>
    private async Task<Solution> InlineMethodAsync(Document document,
        TInvocationSyntax calleeInvocationNode,
        IMethodSymbol calleeMethodSymbol,
        TMethodDeclarationSyntax calleeMethodNode,
        ISymbol callerSymbol,
        SyntaxNode callerNode,
        TExpressionSyntax rawInlineExpression,
        IInvocationOperation invocationOperation,
        bool removeCalleeDeclarationNode,
        CancellationToken cancellationToken)
    {
        // Locate the statement that contains the invocation, which exists when the callee is invoked inside
        // a block, e.g.
        //
        //   void Caller()
        //   {
        //       Action a = () =>
        //       {
        //           var x = Callee();   // <- the local declaration of x is the containing statement
        //       }
        //   }
        //
        // The walk stops at the first lambda or local function, so that in
        //
        //   void Caller()
        //   {
        //       Action a = () => Callee();
        //   }
        //
        // the declaration of 'a' is NOT treated as the containing statement. The result is null when the
        // callee is invoked from an expression body (arrow function).
        var statementContainsInvocation = calleeInvocationNode.GetAncestors()
            .TakeWhile(ancestor =>
                !(_syntaxFacts.IsAnonymousFunctionExpression(ancestor) || _syntaxFacts.IsLocalFunctionStatement(ancestor)))
            .OfType<TStatementSyntax>()
            .FirstOrDefault();

        var methodParametersInfo = await GetMethodParametersInfoAsync(
            document,
            calleeInvocationNode,
            calleeMethodNode,
            statementContainsInvocation,
            rawInlineExpression,
            invocationOperation,
            cancellationToken).ConfigureAwait(false);

        var inlineContext = await GetInlineMethodContextAsync(
            document,
            calleeMethodNode,
            calleeInvocationNode,
            calleeMethodSymbol,
            rawInlineExpression,
            methodParametersInfo,
            cancellationToken).ConfigureAwait(false);

        var originalSolution = document.Project.Solution;
        var editor = new SolutionEditor(originalSolution);

        // Optionally delete the callee declaration from whichever document declares it.
        if (removeCalleeDeclarationNode
            && originalSolution.GetDocumentId(calleeMethodNode.SyntaxTree) is { } calleeDocumentId)
        {
            var calleeEditor = await editor.GetDocumentEditorAsync(calleeDocumentId, cancellationToken).ConfigureAwait(false);
            calleeEditor.RemoveNode(calleeMethodNode);
        }

        var updatedCallerNode = await GetChangedCallerAsync(
            document,
            calleeInvocationNode,
            calleeMethodSymbol,
            callerSymbol,
            callerNode,
            statementContainsInvocation,
            rawInlineExpression,
            methodParametersInfo,
            inlineContext,
            cancellationToken).ConfigureAwait(false);

        // Swap the caller declaration for its rewritten version.
        var callerEditor = await editor.GetDocumentEditorAsync(document.Id, cancellationToken).ConfigureAwait(false);
        callerEditor.ReplaceNode(callerNode, updatedCallerNode);

        return editor.GetChangedSolution();
    }
 
    /// <summary>
    /// Rewrites the caller declaration: marks it 'async' when required, inserts the statements that must
    /// precede the invocation, and replaces the invocation (or its containing statement) with the inlined node.
    /// </summary>
    /// <returns>The changed root of the caller declaration.</returns>
    private async Task<SyntaxNode> GetChangedCallerAsync(Document document,
        TInvocationSyntax calleeInvocationNode,
        IMethodSymbol calleeMethodSymbol,
        ISymbol callerSymbol,
        SyntaxNode callerDeclarationNode,
        TStatementSyntax? statementContainsInvocation,
        TExpressionSyntax rawInlineExpression,
        MethodParametersInfo methodParametersInfo,
        InlineMethodContext inlineMethodContext,
        CancellationToken cancellationToken)
    {
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var syntaxGenerator = SyntaxGenerator.GetGenerator(document);
        var callerNodeEditor = new SyntaxEditor(callerDeclarationNode, syntaxGenerator);

        // When the inlined content contains an 'await', a non-async ordinary caller method becomes 'async',
        // provided it returns void or an awaitable type. In every other case the modifiers are left alone.
        if (inlineMethodContext.ContainsAwaitExpression
            && callerSymbol is IMethodSymbol callerMethodSymbol
            && callerMethodSymbol.IsOrdinaryMethod()
            && !callerMethodSymbol.IsAsync
            && (callerMethodSymbol.ReturnsVoid
                || callerMethodSymbol.IsAwaitableNonDynamic(semanticModel, callerDeclarationNode.SpanStart)))
        {
            callerNodeEditor.SetModifiers(
                callerDeclarationNode,
                DeclarationModifiers.From(callerSymbol).WithAsync(true));
        }

        if (statementContainsInvocation is not null)
        {
            foreach (var precedingStatement in inlineMethodContext.StatementsToInsertBeforeInvocationOfCallee)
            {
                // The trailing CarriageReturn makes sure that, for VB, each statement ends up on its own line.
                var statementWithLineBreak = precedingStatement.WithAppendedTrailingTrivia(_syntaxFacts.ElasticCarriageReturnLineFeed);
                callerNodeEditor.InsertBefore(statementContainsInvocation, statementWithLineBreak);
            }
        }

        var replacement = GetInlineNode(
            calleeInvocationNode,
            calleeMethodSymbol,
            statementContainsInvocation,
            rawInlineExpression,
            methodParametersInfo,
            inlineMethodContext,
            semanticModel,
            syntaxGenerator,
            cancellationToken);
        callerNodeEditor.ReplaceNode(replacement.nodeToReplace, (_, _) => replacement.inlineNode);

        return callerNodeEditor.GetChangedRoot();
    }
 
    /// <summary>
    /// Determines which node in the caller has to be replaced and what it is replaced with.
    /// </summary>
    /// <remarks>
    /// When the invocation lives inside a statement, the whole statement may be replaced (merged variable
    /// declaration, throw statement, or a generated temporary local declaration). Otherwise only the
    /// invocation expression is replaced, by a throw expression or by the inlined expression, which is
    /// wrapped in a cast and parentheses when its value is consumed by the caller.
    /// </remarks>
    /// <returns>
    /// The node to replace in the caller and the node to put in its place. The replacement always carries the
    /// trivia of the node it replaces.
    /// </returns>
    private (SyntaxNode nodeToReplace, SyntaxNode inlineNode) GetInlineNode(
        TInvocationSyntax calleeInvocationNode,
        IMethodSymbol calleeMethodSymbol,
        TStatementSyntax? statementContainsInvocation,
        TExpressionSyntax rawInlineExpression,
        MethodParametersInfo methodParametersInfo,
        InlineMethodContext inlineMethodContext,
        SemanticModel semanticModel,
        SyntaxGenerator syntaxGenerator,
        CancellationToken cancellationToken)
    {
        if (statementContainsInvocation != null)
        {
            if (methodParametersInfo.MergeInlineContentAndVariableDeclarationArgument)
            {
                // The inlined assignment and the variable declared in the argument are merged into a single
                // local declaration initialized with the right-hand side of the assignment.
                var rightHandSideValue = _syntaxFacts.GetRightHandSideOfAssignment(inlineMethodContext.InlineExpression);
                var (parameterSymbol, name) = methodParametersInfo.ParametersWithVariableDeclarationArgument.Single();
                var declarationNode = (TStatementSyntax)syntaxGenerator
                    .LocalDeclarationStatement(parameterSymbol.Type, name, rightHandSideValue);
                return (statementContainsInvocation, declarationNode.WithTriviaFrom(statementContainsInvocation));
            }

            if (_syntaxFacts.IsThrowStatement(rawInlineExpression.Parent)
                && _syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent))
            {
                // Throw statement in the callee, invoked as a statement: stays a throw statement.
                var throwStatement = (TStatementSyntax)syntaxGenerator
                    .ThrowStatement(inlineMethodContext.InlineExpression);
                return (statementContainsInvocation, throwStatement.WithTriviaFrom(statementContainsInvocation));
            }

            if (_syntaxFacts.IsThrowExpression(rawInlineExpression)
                && _syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent))
            {
                // Example:
                // Before:
                // void Caller() { Callee(); }
                // void Callee() => throw new Exception();
                // After:
                // void Caller() { throw new Exception(); }
                // void Callee() => throw new Exception();
                // Note: Throw expression is converted to throw statement
                var throwStatement = (TStatementSyntax)syntaxGenerator
                    .ThrowStatement(_syntaxFacts.GetExpressionOfThrowExpression(inlineMethodContext.InlineExpression));
                return (statementContainsInvocation, throwStatement.WithTriviaFrom(statementContainsInvocation));
            }

            if (_syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent)
                && !calleeMethodSymbol.ReturnsVoid
                && !IsValidExpressionUnderExpressionStatement(inlineMethodContext.InlineExpression))
            {
                // If the callee is invoked as ExpressionStatement, but the inlined expression in the callee can't be
                // placed under ExpressionStatement
                // Example:
                // void Caller()
                // {
                //     Callee();
                // }
                // int Callee()
                // {
                //     return 1;
                // }
                // After it should be:
                // void Caller()
                // {
                //     int temp = 1;
                // }
                // int Callee()
                // {
                //     return 1;
                // }
                // One variable declaration needs to be generated.
                var unusedLocalName =
                    _semanticFactsService.GenerateUniqueLocalName(
                        semanticModel,
                        calleeInvocationNode,
                        container: null,
                        TemporaryName,
                        cancellationToken);

                var localDeclarationNode = (TStatementSyntax)syntaxGenerator
                    .LocalDeclarationStatement(calleeMethodSymbol.ReturnType, unusedLocalName.Text,
                        inlineMethodContext.InlineExpression);
                return (statementContainsInvocation, localDeclarationNode.WithTriviaFrom(statementContainsInvocation));
            }
        }

        if (_syntaxFacts.IsThrowStatement(rawInlineExpression.Parent))
        {
            // Example:
            // Before:
            // void Caller() => Callee();
            // void Callee() { throw new Exception(); }
            // After:
            // void Caller() => throw new Exception();
            // void Callee() { throw new Exception(); }
            // Note: Throw statement is converted to throw expression
            if (CanBeReplacedByThrowExpression(calleeInvocationNode))
            {
                // The trivia of the invocation is applied exactly once, on the final node.
                var throwExpression = (TExpressionSyntax)syntaxGenerator
                    .ThrowExpression(inlineMethodContext.InlineExpression);
                return (calleeInvocationNode, throwExpression.WithTriviaFrom(calleeInvocationNode));
            }
        }

        var inlineExpression = inlineMethodContext.InlineExpression;
        if (!_syntaxFacts.IsExpressionStatement(calleeInvocationNode.Parent)
            && !calleeMethodSymbol.ReturnsVoid
            && !_syntaxFacts.IsThrowExpression(inlineExpression))
        {
            // Add type cast and parenthesis to the inline expression.
            // It is required to cover cases like:
            // Case 1 (parenthesis added):
            // Before:
            // void Caller() { var x = 3 * Callee(); }
            // int Callee() { return 1 + 2; }
            //
            // After
            // void Caller() { var x = 3 * (1 + 2); }
            // int Callee() { return 1 + 2; }
            //
            // Case 2 (type cast)
            // Before:
            // void Caller() { var x = Callee(); }
            // long Callee() { return 1; }
            //
            // After
            // void Caller() { var x = (long)1; }
            // long Callee() { return 1; }
            //
            // Case 3 (type cast & additional parenthesis)
            // Before:
            // void Caller() { var x = Callee()(); }
            // Func<int> Callee() { return () => 1; }
            // After:
            // void Caller() { var x = ((Func<int>)(() => 1))(); }
            // Func<int> Callee() { return () => 1; }
            //
            // Also, ensure that the node is formatted properly at the destination location. This is needed as the
            // location of the destination node might be very different (indentation/nesting wise) from the original
            // method where the inlined code is coming from.
            inlineExpression = (TExpressionSyntax)syntaxGenerator.AddParentheses(
                syntaxGenerator.CastExpression(
                    GenerateTypeSyntax(calleeMethodSymbol.ReturnType, allowVar: false),
                    syntaxGenerator.AddParentheses(inlineExpression.WithAdditionalAnnotations(Formatter.Annotation))));
        }

        return (calleeInvocationNode, inlineExpression.WithTriviaFrom(calleeInvocationNode));
    }
 
    /// <summary>
    /// Walks from the invocation up through its parents and returns the symbol of the first enclosing
    /// property, method, event, field declaration or anonymous function. Returns <see langword="null"/>
    /// when none of them is found (or when the anonymous function has no resolved symbol).
    /// </summary>
    private ISymbol? GetCallerSymbol(
        TInvocationSyntax calleeMethodInvocationNode,
        SemanticModel semanticModel,
        CancellationToken cancellationToken)
    {
        // ascendOutOfTrivia: false restricts the walk to plain Parent links.
        foreach (var current in calleeMethodInvocationNode.AncestorsAndSelf(ascendOutOfTrivia: false))
        {
            var declared = semanticModel.GetDeclaredSymbol(current, cancellationToken);
            var isMemberSymbol = declared.IsKind(SymbolKind.Property)
                || declared.IsKind(SymbolKind.Method)
                || declared.IsKind(SymbolKind.Event);
            if (isMemberSymbol)
            {
                return declared;
            }

            // A field declaration counts as the caller only when it yields exactly one field symbol.
            if (IsFieldDeclarationSyntax(current)
                && semanticModel.GetExistingSymbols(current, cancellationToken).SingleOrDefault() is IFieldSymbol enclosingField)
            {
                return enclosingField;
            }

            if (_syntaxFacts.IsAnonymousFunctionExpression(current))
            {
                return semanticModel.GetSymbolInfo(current, cancellationToken).Symbol;
            }
        }

        return null;
    }
}