// File: IntroduceVariable\CSharpIntroduceVariableService_IntroduceLocal.cs
// Web Access
// Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.IntroduceVariable;
 
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
internal sealed partial class CSharpIntroduceVariableService
{
    /// <summary>
    /// Introduces a new local (or local constant) initialized with <paramref name="expression"/>.
    /// The work is dispatched on the nearest enclosing block, expression body (<c>=&gt;</c> clause),
    /// or lambda of the expression.
    /// </summary>
    /// <param name="document">The semantic document that contains <paramref name="expression"/>.</param>
    /// <param name="expression">The expression that becomes the initializer of the new local.</param>
    /// <param name="allOccurrences">Forwarded unchanged to the container-specific helpers.</param>
    /// <param name="isConstant">When <see langword="true"/>, the declaration gets a <c>const</c> modifier.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The document produced by the container-specific helper.</returns>
    protected override async Task<Document> IntroduceLocalAsync(
        SemanticDocument document,
        ExpressionSyntax expression,
        bool allOccurrences,
        bool isConstant,
        CancellationToken cancellationToken)
    {
        // Closest ancestor that either already holds statements or can be converted to a block.
        var container = expression.Ancestors().FirstOrDefault(
            ancestor => ancestor is BlockSyntax or ArrowExpressionClauseSyntax or LambdaExpressionSyntax);

        var localNameToken = GenerateUniqueLocalName(
            document, expression, isConstant, container, cancellationToken);
        var localNameReference = IdentifierName(localNameToken);

        SyntaxTokenList modifiers = default;
        if (isConstant)
            modifiers = TokenList(ConstKeyword);

        // Build "<modifiers> <type> <name> = <expression>;" with the name marked for rename.
        var typeSyntax = GetTypeSyntax(document, expression, cancellationToken);
        var declarator = VariableDeclarator(
            localNameToken.WithAdditionalAnnotations(RenameAnnotation.Create()),
            null,
            EqualsValueClause(expression.WithoutTrivia()));
        var declaration = LocalDeclarationStatement(
            modifiers,
            VariableDeclaration(typeSyntax, [declarator]));

        // When the container spans several lines, terminate the new declaration with a newline.
        // Otherwise it would share a line with the statement that follows it, and indentation
        // would be computed incorrectly.
        var text = await document.Document.GetValueTextAsync(cancellationToken).ConfigureAwait(false);
        var containerIsSingleLine = text.AreOnSameLine(container.GetFirstToken(), container.GetLastToken());
        if (!containerIsSingleLine)
            declaration = declaration.WithAppendedTrailingTrivia(ElasticCarriageReturnLineFeed);

        if (container is BlockSyntax block)
        {
            return await IntroduceLocalDeclarationIntoBlockAsync(
                document, block, expression, localNameReference, declaration, allOccurrences, cancellationToken).ConfigureAwait(false);
        }

        if (container is ArrowExpressionClauseSyntax arrowExpression)
        {
            // The declared symbol is not a method for expression-bodied properties and indexers
            // (individual getters and setters do have a method symbol). Those are shorthand for a
            // getter and therefore always return a value, so the default is to create a return.
            var semanticModel = document.SemanticModel;
            var createReturnStatement = true;
            if (semanticModel.GetDeclaredSymbol(arrowExpression.Parent, cancellationToken) is IMethodSymbol method)
                createReturnStatement = !(method.ReturnsVoid || method.IsAsyncReturningVoidTask(semanticModel.Compilation));

            return RewriteExpressionBodiedMemberAndIntroduceLocalDeclaration(
                document, arrowExpression, expression, localNameReference,
                declaration, allOccurrences, createReturnStatement, cancellationToken);
        }

        if (container is LambdaExpressionSyntax lambda)
        {
            return IntroduceLocalDeclarationIntoLambda(
                document, lambda, expression, localNameReference, declaration,
                allOccurrences, cancellationToken);
        }

        // The ancestor search above only admits the three kinds handled here.
        throw new InvalidOperationException();
    }
 
    /// <summary>
    /// Converts an expression-bodied lambda into a block-bodied lambda whose block starts with
    /// <paramref name="declarationStatement"/>, and returns the document with the new lambda.
    /// </summary>
    /// <param name="document">The semantic document that contains <paramref name="oldLambda"/>.</param>
    /// <param name="oldLambda">The lambda to rewrite; its body is cast to an expression.</param>
    /// <param name="expression">The expression the new local was created from.</param>
    /// <param name="newLocalName">Name syntax referring to the new local.</param>
    /// <param name="declarationStatement">The declaration to place first in the new block.</param>
    /// <param name="allOccurrences">Forwarded to the rewrite of the lambda body.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    private Document IntroduceLocalDeclarationIntoLambda(
        SemanticDocument document,
        LambdaExpressionSyntax oldLambda,
        ExpressionSyntax expression,
        IdentifierNameSyntax newLocalName,
        LocalDeclarationStatementSyntax declarationStatement,
        bool allOccurrences,
        CancellationToken cancellationToken)
    {
        var expressionBody = (ExpressionSyntax)oldLambda.Body;

        // True when the selection (ignoring enclosing parentheses) is the whole lambda body.
        var wholeBodySelected = expressionBody.Equals(expression.WalkUpParentheses());

        var bodyAfterRewrite = Rewrite(
            document, expression, newLocalName, document, expressionBody, allOccurrences, cancellationToken);

        var blockBody = GetNewBlockBodyForLambda(
            declarationStatement,
            wholeBodySelected,
            bodyAfterRewrite,
            ShouldIncludeReturnStatement(document, oldLambda, cancellationToken));

        // An elastic newline after '{' lets the formatter spread the new body over several lines.
        var openBrace = blockBody.OpenBraceToken.WithAppendedTrailingTrivia(ElasticCarriageReturnLineFeed);
        blockBody = blockBody.WithOpenBraceToken(openBrace)
                             .WithAdditionalAnnotations(Formatter.Annotation);

        var updatedRoot = document.Root.ReplaceNode(oldLambda, oldLambda.WithBody(blockBody));
        return document.Document.WithSyntaxRoot(updatedRoot);
    }
 
    /// <summary>
    /// Decides whether the block body generated for <paramref name="oldLambda"/> needs a
    /// <c>return</c> statement. Returns <see langword="false"/> when the lambda converts to a
    /// delegate that returns void, or when it is an async lambda whose delegate returns the
    /// compilation's Task or ValueTask type; otherwise returns <see langword="true"/>.
    /// </summary>
    private static bool ShouldIncludeReturnStatement(
        SemanticDocument document,
        LambdaExpressionSyntax oldLambda,
        CancellationToken cancellationToken)
    {
        var semanticModel = document.SemanticModel;
        var convertedType = semanticModel.GetTypeInfo(oldLambda, cancellationToken).ConvertedType;

        // Without a delegate invoke method there is nothing to inspect, so default to returning.
        if (convertedType is not INamedTypeSymbol { DelegateInvokeMethod: { } invokeMethod })
            return true;

        if (invokeMethod.ReturnsVoid)
            return false;

        var delegateReturnType = invokeMethod.ReturnType;
        if (oldLambda.AsyncKeyword == default || delegateReturnType == null)
            return true;

        // An async lambda typed as Task or ValueTask produces no value, e.g.:
        //     Func<int, Task> f = async x => await M2();
        //
        // which becomes:
        //     Func<int, Task> f = async x =>
        //     {
        //         Task task = M2();
        //         await task;
        //     };
        var compilation = semanticModel.Compilation;

        var taskType = compilation.TaskType();
        if (taskType != null && delegateReturnType.Equals(taskType))
            return false;

        var valueTaskType = compilation.ValueTaskType();
        if (valueTaskType != null && delegateReturnType.Equals(valueTaskType))
            return false;

        return true;
    }
 
    /// <summary>
    /// Builds the block that replaces an expression-bodied lambda's body. The block always starts
    /// with <paramref name="declarationStatement"/>; what follows depends on whether a value must
    /// be returned and on whether the whole lambda body was selected.
    /// </summary>
    /// <param name="declarationStatement">Declaration of the newly introduced local.</param>
    /// <param name="isEntireLambdaBodySelected">Whether the selection was the entire lambda body.</param>
    /// <param name="rewrittenBody">The lambda body after the rewrite step.</param>
    /// <param name="includeReturnStatement">Whether the rewritten body must be returned.</param>
    private static BlockSyntax GetNewBlockBodyForLambda(
        LocalDeclarationStatementSyntax declarationStatement,
        bool isEntireLambdaBodySelected,
        ExpressionSyntax rewrittenBody,
        bool includeReturnStatement)
    {
        if (!includeReturnStatement)
        {
            if (isEntireLambdaBodySelected)
            {
                // No return needed and the whole body was selected: the declaration already
                // contains the original body, so nothing else is emitted.
                //     Action<int> goo = x => [|x.ToString()|];
                // becomes:
                //     Action<int> goo = x =>
                //     {
                //         string v = x.ToString();
                //     };
                return Block(declarationStatement);
            }

            // No return needed and only part of the body was selected: keep the rest of the body
            // as an expression statement after the declaration.
            //     Task.Run(() => File.Copy("src", [|Path.Combine("dir", "file")|]));
            // becomes:
            //     Task.Run(() =>
            //     {
            //         string destFileName = Path.Combine("dir", "file");
            //         File.Copy("src", destFileName);
            //     });
            return Block(
                declarationStatement,
                ExpressionStatement(rewrittenBody, SemicolonToken));
        }

        // A value must be returned: return the rewritten body after the declaration.
        //     Func<int, int> f = x => [|x + 1|];
        // becomes:
        //     Func<int, int> f = x =>
        //     {
        //         var v = x + 1;
        //         return v;
        //     };
        return Block(declarationStatement, ReturnStatement(rewrittenBody));
    }
 
    /// <summary>
    /// Produces type syntax from the type symbol that <c>GetTypeSymbol</c> reports for
    /// <paramref name="expression"/>.
    /// </summary>
    private static TypeSyntax GetTypeSyntax(SemanticDocument document, ExpressionSyntax expression, CancellationToken cancellationToken)
        => GetTypeSymbol(document, expression, cancellationToken).GenerateTypeSyntax();
 
    /// <summary>
    /// Replaces the parent of <paramref name="arrowExpression"/> with a block-bodied equivalent.
    /// The block holds <paramref name="declarationStatement"/> followed by the rewritten
    /// expression, wrapped in a return statement or an expression statement.
    /// </summary>
    /// <param name="document">The semantic document being edited.</param>
    /// <param name="arrowExpression">The <c>=&gt;</c> clause to convert.</param>
    /// <param name="expression">The expression the new local was created from.</param>
    /// <param name="newLocalName">Name syntax referring to the new local.</param>
    /// <param name="declarationStatement">The declaration to place first in the new block.</param>
    /// <param name="allOccurrences">Forwarded to the rewrite of the clause's expression.</param>
    /// <param name="createReturnStatement">
    /// <see langword="true"/> to emit <c>return expr;</c>, <see langword="false"/> to emit <c>expr;</c>.
    /// </param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    private Document RewriteExpressionBodiedMemberAndIntroduceLocalDeclaration(
        SemanticDocument document,
        ArrowExpressionClauseSyntax arrowExpression,
        ExpressionSyntax expression,
        NameSyntax newLocalName,
        LocalDeclarationStatementSyntax declarationStatement,
        bool allOccurrences,
        bool createReturnStatement,
        CancellationToken cancellationToken)
    {
        var owningNode = arrowExpression.Parent;

        // The block inherits the clause's leading trivia plus whatever followed the '=>' token.
        var blockLeadingTrivia = arrowExpression.GetLeadingTrivia()
                                                .AddRange(arrowExpression.ArrowToken.TrailingTrivia);

        var rewrittenExpression = Rewrite(
            document, expression, newLocalName, document, arrowExpression.Expression, allOccurrences, cancellationToken);

        StatementSyntax trailingStatement;
        if (createReturnStatement)
            trailingStatement = ReturnStatement(rewrittenExpression);
        else
            trailingStatement = ExpressionStatement(rewrittenExpression);

        var blockBody = Block(declarationStatement, trailingStatement)
            .WithLeadingTrivia(blockLeadingTrivia)
            .WithTrailingTrivia(arrowExpression.GetTrailingTrivia());

        // An elastic newline after '{' lets the formatter spread the new block over several lines.
        var openBrace = blockBody.OpenBraceToken.WithAppendedTrailingTrivia(ElasticCarriageReturnLineFeed);
        blockBody = blockBody.WithOpenBraceToken(openBrace)
                             .WithAdditionalAnnotations(Formatter.Annotation);

        var updatedRoot = document.Root.ReplaceNode(owningNode, WithBlockBody(owningNode, blockBody));
        return document.Document.WithSyntaxRoot(updatedRoot);
    }
 
    /// <summary>
    /// Returns <paramref name="node"/> with its expression body and trailing semicolon removed
    /// and <paramref name="body"/> installed as its block body, keeping the node's outer trivia.
    /// </summary>
    /// <param name="node">
    /// A base property declaration, accessor, base method declaration, or local function.
    /// </param>
    /// <param name="body">The block to use as the new body.</param>
    /// <returns>The updated node. Any other node kind results in an "unexpected value" exception.</returns>
    private static SyntaxNode WithBlockBody(SyntaxNode node, BlockSyntax body)
    {
        if (node is BasePropertyDeclarationSyntax baseProperty)
        {
            // The block becomes the body of a single 'get' accessor.
            return baseProperty
                .TryWithExpressionBody(null)
                .WithAccessorList(AccessorList(
                    [AccessorDeclaration(SyntaxKind.GetAccessorDeclaration, body)]))
                .TryWithSemicolonToken(Token(SyntaxKind.None))
                .WithTriviaFrom(baseProperty);
        }

        if (node is AccessorDeclarationSyntax accessor)
        {
            return accessor
                .WithExpressionBody(null)
                .WithBody(body)
                .WithSemicolonToken(Token(SyntaxKind.None))
                .WithTriviaFrom(accessor);
        }

        if (node is BaseMethodDeclarationSyntax baseMethod)
        {
            return baseMethod
                .WithExpressionBody(null)
                .WithBody(body)
                .WithSemicolonToken(Token(SyntaxKind.None))
                .WithTriviaFrom(baseMethod);
        }

        if (node is LocalFunctionStatementSyntax localFunction)
        {
            return localFunction
                .WithExpressionBody(null)
                .WithBody(body)
                .WithSemicolonToken(Token(SyntaxKind.None))
                .WithTriviaFrom(localFunction);
        }

        throw ExceptionUtilities.UnexpectedValue(node);
    }
 
    /// <summary>
    /// Inserts <paramref name="declarationStatement"/> into the innermost block-like node (block or
    /// switch section) that contains every matched occurrence, after running <c>Rewrite</c> over
    /// that node, and returns the resulting document.
    /// </summary>
    /// <param name="document">The semantic document being edited.</param>
    /// <param name="block">The block that encloses <paramref name="expression"/>.</param>
    /// <param name="expression">The expression the new local was created from.</param>
    /// <param name="newLocalName">Name syntax referring to the new local.</param>
    /// <param name="declarationStatement">The declaration to insert.</param>
    /// <param name="allOccurrences">Forwarded to <c>FindMatches</c> and <c>Rewrite</c>.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    private async Task<Document> IntroduceLocalDeclarationIntoBlockAsync(
        SemanticDocument document,
        BlockSyntax block,
        ExpressionSyntax expression,
        NameSyntax newLocalName,
        LocalDeclarationStatementSyntax declarationStatement,
        bool allOccurrences,
        CancellationToken cancellationToken)
    {
        declarationStatement = declarationStatement.WithAdditionalAnnotations(Formatter.Annotation);

        SyntaxNode scope = block;

        // If we're within a non-static local function, our scope for the new local declaration is expanded to include the enclosing member.
        var localFunction = block.GetAncestor<LocalFunctionStatementSyntax>();
        if (localFunction != null && !localFunction.Modifiers.Any(SyntaxKind.StaticKeyword))
        {
            scope = block.GetAncestor<MemberDeclarationSyntax>();
        }

        // Collect the occurrences to replace within the chosen scope; the selected expression
        // itself is expected to be among them.
        var matches = FindMatches(document, expression, document, scope, allOccurrences, cancellationToken);
        Debug.Assert(matches.Contains(expression));

        // This step yields both a new document and a new set of matches; from here on
        // 'document' and 'matches' refer to those results.
        (document, matches) = await ComplexifyParentingStatementsAsync(document, matches, cancellationToken).ConfigureAwait(false);

        // Our original expression should have been one of the matches, which were tracked as part
        // of complexification, so we can retrieve the latest version of the expression here.
        expression = document.Root.GetCurrentNode(expression);

        var root = document.Root;

        // Every applicable statement ancestor of every match (not just the nearest one).
        ISet<StatementSyntax> allAffectedStatements = new HashSet<StatementSyntax>(matches.SelectMany(expr => GetApplicableStatementAncestors(expr)));

        SyntaxNode innermostCommonBlock;

        // The nearest applicable statement around each match.
        var innermostStatements = new HashSet<StatementSyntax>(matches.Select(expr => GetApplicableStatementAncestors(expr).First()));
        if (innermostStatements.Count == 1)
        {
            // if there was only one match, or all the matches came from the same statement
            var statement = innermostStatements.Single();

            // and the statement is an embedded statement without a block, we want to generate one
            // around this statement rather than continue going up to find an actual block
            if (!IsBlockLike(statement.Parent))
            {
                // Track the nodes we still need so they can be located again after the
                // replacement below produces a new root.
                root = root.TrackNodes(allAffectedStatements.Concat(new SyntaxNode[] { expression, statement }));
                root = root.ReplaceNode(root.GetCurrentNode(statement),
                    Block(root.GetCurrentNode(statement)).WithAdditionalAnnotations(Formatter.Annotation));

                // Re-resolve everything against the new root.
                expression = root.GetCurrentNode(expression);
                allAffectedStatements = allAffectedStatements.Select(root.GetCurrentNode).ToSet();

                statement = root.GetCurrentNode(statement);
            }

            innermostCommonBlock = statement.Parent;
        }
        else
        {
            innermostCommonBlock = innermostStatements.FindInnermostCommonNode(IsBlockLike);
        }

        // Start from the first statement of the common block that is an affected statement, then
        // let GetFirstStatementAffectedIndex move the insertion point earlier if local-function
        // calls require it.
        // NOTE(review): 'matches' is not re-resolved after the block-wrapping replacement above,
        // while 'innermostCommonBlock' may belong to the updated root -- confirm the span
        // comparisons inside GetFirstStatementAffectedIndex are still valid in that case.
        var firstStatementAffectedIndex = GetFirstStatementAffectedIndex(innermostCommonBlock, matches, GetStatements(innermostCommonBlock).IndexOf(allAffectedStatements.Contains));

        var newInnerMostBlock = Rewrite(
            document, expression, newLocalName, document, innermostCommonBlock, allOccurrences, cancellationToken);

        // Insert the declaration at the computed index, then swap the rewritten block into the root.
        var statements = InsertWithinTriviaOfNext(GetStatements(newInnerMostBlock), declarationStatement, firstStatementAffectedIndex);
        var finalInnerMostBlock = WithStatements(newInnerMostBlock, statements);

        var newRoot = root.ReplaceNode(innermostCommonBlock, finalInnerMostBlock);
        return document.Document.WithSyntaxRoot(newRoot);
    }
 
    /// <summary>
    /// Lazily yields the statements that enclose <paramref name="expr"/>, innermost first,
    /// skipping any <c>if</c> statement that is the direct child of an <c>else</c> clause.
    /// </summary>
    private static IEnumerable<StatementSyntax> GetApplicableStatementAncestors(ExpressionSyntax expr)
    {
        foreach (var candidate in expr.GetAncestorsOrThis<StatementSyntax>())
        {
            // A local must never be placed between the `else` and the `if` of an
            // `else if` chain, so that inner `if` is not a valid anchor.
            var isIfOfElseIf = candidate is IfStatementSyntax { Parent: ElseClauseSyntax };
            if (!isIfOfElseIf)
                yield return candidate;
        }
    }
 
    /// <summary>
    /// Adjusts the insertion index for the new declaration so that it also precedes any call to a
    /// local function that will use the new local.
    /// </summary>
    /// <param name="innermostCommonBlock">The block-like node the declaration is inserted into.</param>
    /// <param name="matches">The expressions that will be replaced by the new local.</param>
    /// <param name="firstStatementAffectedIndex">
    /// Index of the first statement that contains a match; returned unchanged when no relevant
    /// local-function call is found earlier.
    /// </param>
    /// <returns>
    /// The smaller of <paramref name="firstStatementAffectedIndex"/> and the index of the statement
    /// holding the earliest relevant local-function call.
    /// </returns>
    private static int GetFirstStatementAffectedIndex(SyntaxNode innermostCommonBlock, ISet<ExpressionSyntax> matches, int firstStatementAffectedIndex)
    {
        // If a local function is involved, we have to make sure the new declaration is placed:
        //     1. Before all calls to local functions that use the variable.
        //     2. Before the local function(s) themselves.
        //     3. Before all matches, i.e. places in the code where the new declaration will replace existing code.
        // Cases (2) and (3) are already covered by the 'firstStatementAffectedIndex' parameter. Thus, all we have to do is ensure we consider (1) when
        // determining where to place our new declaration.

        // Names of the local functions within the scope that will use the new declaration.
        // Materialized once into a set: the name lookup below runs for every invocation in the
        // block, and a deferred query would re-walk the whole tree on each lookup.
        var localFunctionNames = new HashSet<string>(
            innermostCommonBlock.DescendantNodes()
                .OfType<LocalFunctionStatementSyntax>()
                .Where(localFunction => matches.Any(match => match.Span.OverlapsWith(localFunction.Span)))
                .Select(localFunction => localFunction.Identifier.ValueText));

        if (localFunctionNames.Count == 0)
        {
            return firstStatementAffectedIndex;
        }

        // Start positions of all calls to the applicable local functions within the scope.
        // Member-access invocations (x.F()) are excluded because they cannot name a local function.
        var localFunctionCallStarts = innermostCommonBlock.DescendantNodes()
            .OfType<InvocationExpressionSyntax>()
            .Where(invocation =>
            {
                var callee = invocation.Expression;
                if (callee.IsKind(SyntaxKind.SimpleMemberAccessExpression))
                    return false;

                var calleeName = callee.GetRightmostName();
                return calleeName != null && localFunctionNames.Contains(calleeName.Identifier.ValueText);
            })
            .Select(invocation => invocation.SpanStart)
            .ToList();

        if (localFunctionCallStarts.Count == 0)
        {
            return firstStatementAffectedIndex;
        }

        // Find which call is the earliest.
        var earliestLocalFunctionCall = localFunctionCallStarts.Min();

        var statementsInBlock = GetStatements(innermostCommonBlock);

        // Check if our earliest call is before all local function declarations and all matches, and if so, place our new declaration there.
        var earliestLocalFunctionCallIndex = statementsInBlock.IndexOf(s => s.Span.Contains(earliestLocalFunctionCall));

        // The call may sit outside every statement of the block-like node (for example in a
        // `case ... when` label of a switch section, which is a descendant but not a statement).
        // In that case there is no statement index to move to, so keep the caller's index rather
        // than letting -1 win the Math.Min below and become the insertion index.
        if (earliestLocalFunctionCallIndex < 0)
        {
            return firstStatementAffectedIndex;
        }

        return Math.Min(earliestLocalFunctionCallIndex, firstStatementAffectedIndex);
    }
 
    /// <summary>
    /// Inserts <paramref name="newStatement"/> at <paramref name="statementIndex"/>. When a
    /// statement already exists at that index, its leading trivia up to and including the last
    /// end-of-line is moved onto the new statement; the rest stays with the existing statement.
    /// </summary>
    /// <param name="oldStatements">The statement list to insert into.</param>
    /// <param name="newStatement">The statement to insert.</param>
    /// <param name="statementIndex">The position at which to insert.</param>
    private static SyntaxList<StatementSyntax> InsertWithinTriviaOfNext(
        SyntaxList<StatementSyntax> oldStatements,
        StatementSyntax newStatement,
        int statementIndex)
    {
        var followingStatement = oldStatements.ElementAtOrDefault(statementIndex);
        if (followingStatement == null)
            return oldStatements.Insert(statementIndex, newStatement);

        var leadingTrivia = followingStatement.GetLeadingTrivia();

        // Scan backwards for the last end-of-line; everything up to and including it sits on
        // lines above the following statement. Zero means no end-of-line was found.
        var splitIndex = 0;
        for (var i = leadingTrivia.Count - 1; i >= 0; i--)
        {
            if (leadingTrivia[i].Kind() == SyntaxKind.EndOfLineTrivia)
            {
                splitIndex = i + 1;
                break;
            }
        }

        if (splitIndex == 0)
        {
            // Nothing to move: place the new statement directly in front, trivia untouched.
            return oldStatements.ReplaceRange(
                followingStatement, new[] { newStatement, followingStatement });
        }

        var movedToNewStatement = leadingTrivia.Take(splitIndex);
        var keptOnFollowingStatement = leadingTrivia.Skip(splitIndex);

        return oldStatements.ReplaceRange(
            followingStatement,
            new[]
            {
                newStatement.WithLeadingTrivia(movedToNewStatement),
                followingStatement.WithLeadingTrivia(keptOnFollowingStatement),
            });
    }
 
    /// <summary>
    /// Returns whether <paramref name="node"/> is a block or a switch section, the two node
    /// kinds this file treats as statement containers.
    /// </summary>
    private static bool IsBlockLike(SyntaxNode node)
    {
        return node is BlockSyntax || node is SwitchSectionSyntax;
    }
 
    /// <summary>
    /// Returns the statement list of a block or switch section; any other node results in an
    /// "unexpected value" exception.
    /// </summary>
    private static SyntaxList<StatementSyntax> GetStatements(SyntaxNode blockLike)
    {
        if (blockLike is BlockSyntax block)
            return block.Statements;

        if (blockLike is SwitchSectionSyntax switchSection)
            return switchSection.Statements;

        throw ExceptionUtilities.UnexpectedValue(blockLike);
    }
 
    /// <summary>
    /// Returns a copy of the block or switch section with <paramref name="statements"/> as its
    /// statement list; any other node results in an "unexpected value" exception.
    /// </summary>
    private static SyntaxNode WithStatements(SyntaxNode blockLike, SyntaxList<StatementSyntax> statements)
    {
        if (blockLike is BlockSyntax block)
            return block.WithStatements(statements);

        if (blockLike is SwitchSectionSyntax switchSection)
            return switchSection.WithStatements(statements);

        throw ExceptionUtilities.UnexpectedValue(blockLike);
    }
}