// File: ExtractMethod\CSharpMethodExtractor.CSharpCodeGenerator.CallSiteContainerRewriter.cs
// Web Access
// Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.LanguageService;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.ExtractMethod;
 
internal sealed partial class CSharpMethodExtractor
{
    private abstract partial class CSharpCodeGenerator
    {
        private sealed class CallSiteContainerRewriter : CSharpSyntaxRewriter
        {
            // Root node handed to Visit by Generate(); the rewrite starts here.
            private readonly SyntaxNode _outmostCallSiteContainer;
            // Annotations marking variable declarators that VisitLocalDeclarationStatement drops.
            private readonly HashSet<SyntaxAnnotation> _variableToRemoveMap;
            // First node of the inclusive range that gets replaced; its Parent identifies the container.
            private readonly SyntaxNode _firstStatementOrFieldToReplace;
            // Last node of the inclusive range that gets replaced (may be the same node as the first).
            private readonly SyntaxNode _lastStatementOrFieldToReplace;
            // Nodes inserted where the replaced range used to be; never default or empty (checked in the constructor).
            private readonly ImmutableArray<SyntaxNode> _statementsOrMemberOrAccessorToInsert;
 
            /// <summary>
            /// Creates a rewriter that replaces the range from <paramref name="firstStatementOrFieldToReplace"/>
            /// through <paramref name="lastStatementOrFieldToReplace"/> with <paramref name="statementsOrFieldToInsert"/>
            /// while visiting <paramref name="outmostCallSiteContainer"/>.
            /// </summary>
            /// <param name="outmostCallSiteContainer">Node that <see cref="Generate"/> visits. Must not be null.</param>
            /// <param name="variableToRemoveMap">Annotations of the variable declarators to drop. Must not be null.</param>
            /// <param name="firstStatementOrFieldToReplace">First node of the replaced range. Must not be null.</param>
            /// <param name="lastStatementOrFieldToReplace">Last node of the replaced range. Must not be null.</param>
            /// <param name="statementsOrFieldToInsert">Replacement nodes. Must be neither default nor empty.</param>
            public CallSiteContainerRewriter(
                SyntaxNode outmostCallSiteContainer,
                HashSet<SyntaxAnnotation> variableToRemoveMap,
                SyntaxNode firstStatementOrFieldToReplace,
                SyntaxNode lastStatementOrFieldToReplace,
                ImmutableArray<SyntaxNode> statementsOrFieldToInsert)
            {
                // Validate every argument before touching any state.
                Contract.ThrowIfNull(outmostCallSiteContainer);
                Contract.ThrowIfNull(variableToRemoveMap);
                Contract.ThrowIfNull(firstStatementOrFieldToReplace);
                Contract.ThrowIfNull(lastStatementOrFieldToReplace);
                Contract.ThrowIfTrue(statementsOrFieldToInsert.IsDefaultOrEmpty);

                // Both ends of the range must either share a parent or be reported by the
                // syntax facts service as living in the same statement container.
                Contract.ThrowIfFalse(firstStatementOrFieldToReplace.Parent == lastStatementOrFieldToReplace.Parent
                    || CSharpSyntaxFacts.Instance.AreStatementsInSameContainer(firstStatementOrFieldToReplace, lastStatementOrFieldToReplace));

                _outmostCallSiteContainer = outmostCallSiteContainer;
                _variableToRemoveMap = variableToRemoveMap;

                _firstStatementOrFieldToReplace = firstStatementOrFieldToReplace;
                _lastStatementOrFieldToReplace = lastStatementOrFieldToReplace;

                _statementsOrMemberOrAccessorToInsert = statementsOrFieldToInsert;
            }
 
            /// <summary>
            /// Runs the rewriter over the outermost call-site container and returns the visited result.
            /// </summary>
            public SyntaxNode Generate()
            {
                return Visit(_outmostCallSiteContainer);
            }
 
            /// <summary>
            /// Parent of the first node to replace; the Visit overrides compare against it
            /// to decide whether they are looking at the node whose children must be replaced.
            /// </summary>
            private SyntaxNode ContainerOfStatementsOrFieldToReplace
            {
                get { return _firstStatementOrFieldToReplace.Parent; }
            }
 
            /// <summary>
            /// Removes the variable declarators annotated in <c>_variableToRemoveMap</c> from a local
            /// declaration statement while keeping the trivia that surrounded them.
            /// </summary>
            /// <param name="node">The local declaration statement being visited.</param>
            /// <returns>
            /// <list type="bullet">
            /// <item><description>the visited node itself when no declarator was removed;</description></item>
            /// <item><description><see langword="null"/> when every declarator was removed and no trivia is left over;</description></item>
            /// <item><description>an empty statement carrying the left-over trivia when every declarator was removed;</description></item>
            /// <item><description>otherwise the statement reduced to the surviving declarators.</description></item>
            /// </list>
            /// </returns>
            public override SyntaxNode VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
            {
                node = (LocalDeclarationStatementSyntax)base.VisitLocalDeclarationStatement(node);

                var survivingVariables = new List<VariableDeclaratorSyntax>();

                // Trivia collected from removed declarators that still has to be re-attached somewhere.
                var pendingTrivia = new List<SyntaxTrivia>();

                foreach (var variable in node.Declaration.Variables)
                {
                    if (_variableToRemoveMap.HasSyntaxAnnotation(variable))
                    {
                        // A declarator with an initializer must never be marked for removal.
                        Contract.ThrowIfFalse(variable.Initializer == null);

                        // Keep the trivia around the declarator even though the declarator itself goes away.
                        pendingTrivia.AddRange(variable.GetLeadingTrivia());
                        pendingTrivia.AddRange(variable.GetTrailingTrivia());
                        continue;
                    }

                    if (pendingTrivia.Count > 0)
                    {
                        // Hand the trivia of the removed declarator(s) to the next surviving one.
                        survivingVariables.Add(variable.WithPrependedLeadingTrivia(pendingTrivia));
                        pendingTrivia.Clear();
                        continue;
                    }

                    survivingVariables.Add(variable);
                }

                if (survivingVariables.Count == 0)
                {
                    // Nothing survived: drop the whole statement when there is no trivia to keep.
                    if (pendingTrivia.Count == 0)
                    {
                        return null;
                    }

                    // Deleting the node would lose the trivia, and moving it to the next token is not
                    // easy from here, so park it on an empty statement instead.
                    // TODO : think about a way to move the trivia to next token.
                    return SyntaxFactory.EmptyStatement(SyntaxFactory.Token([.. pendingTrivia], SyntaxKind.SemicolonToken, [SyntaxFactory.ElasticMarker]));
                }

                if (survivingVariables.Count == node.Declaration.Variables.Count)
                {
                    // Nothing was removed; return the visited node untouched.
                    return node;
                }

                // TODO : fix how it manipulate trivia later

                // Derive the result from the existing node rather than building a fresh statement with
                // SyntaxFactory, so attribute lists, the 'await'/'using' keywords and annotations on the
                // statement are kept. Any left-over trivia ends up as leading trivia of the semicolon.
                return node
                    .WithDeclaration(node.Declaration.WithVariables([.. survivingVariables]))
                    .WithSemicolonToken(node.SemicolonToken.WithPrependedLeadingTrivia(pendingTrivia));
            }
 
            // for every kind of extract methods
            /// <summary>
            /// Replaces the target statement range when <paramref name="node"/> is the container,
            /// then visits the resulting statements; any other block is visited normally.
            /// </summary>
            public override SyntaxNode VisitBlock(BlockSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    var replaced = ReplaceStatements(node.Statements);
                    var visited = VisitList(replaced);
                    return node.WithStatements([.. visited]);
                }

                // Some other block: keep descending so nested nodes are still visited.
                return base.VisitBlock(node);
            }
 
            /// <summary>
            /// Replaces the target statement range when <paramref name="node"/> is the container,
            /// then visits the resulting statements; any other switch section is visited normally.
            /// </summary>
            public override SyntaxNode VisitSwitchSection(SwitchSectionSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    var replaced = ReplaceStatements(node.Statements);
                    var visited = VisitList(replaced);
                    return node.WithStatements([.. visited]);
                }

                // Some other switch section: keep descending so nested nodes are still visited.
                return base.VisitSwitchSection(node);
            }
 
            // only for single statement or expression
            /// <summary>
            /// Swaps the labeled statement's inner statement when <paramref name="node"/> is the container;
            /// otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitLabeledStatement(LabeledStatementSyntax node)
                => node == ContainerOfStatementsOrFieldToReplace
                    ? node.WithStatement(ReplaceStatementIfNeeded(node.Statement))
                    : base.VisitLabeledStatement(node);
 
            /// <summary>
            /// Swaps the else clause's statement when <paramref name="node"/> is the container;
            /// otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitElseClause(ElseClauseSyntax node)
                => node == ContainerOfStatementsOrFieldToReplace
                    ? node.WithStatement(ReplaceStatementIfNeeded(node.Statement))
                    : base.VisitElseClause(node);
 
            /// <summary>
            /// When <paramref name="node"/> is the container: visits the condition, replaces the embedded
            /// statement if it is the target, and visits the else clause. Otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitIfStatement(IfStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    // Same evaluation order as the parts appear in source: condition, body, else.
                    var condition = VisitNode(node.Condition);
                    var body = ReplaceStatementIfNeeded(node.Statement);
                    var elseClause = VisitNode(node.Else);

                    return node.WithCondition(condition)
                               .WithStatement(body)
                               .WithElse(elseClause);
                }

                return base.VisitIfStatement(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the container: visits the lock expression and replaces the
            /// embedded statement if it is the target. Otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitLockStatement(LockStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    var expression = VisitNode(node.Expression);
                    var body = ReplaceStatementIfNeeded(node.Statement);

                    return node.WithExpression(expression).WithStatement(body);
                }

                return base.VisitLockStatement(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the container: visits the declaration and replaces the
            /// embedded statement if it is the target. Otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitFixedStatement(FixedStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    var declaration = VisitNode(node.Declaration);
                    var body = ReplaceStatementIfNeeded(node.Statement);

                    return node.WithDeclaration(declaration).WithStatement(body);
                }

                return base.VisitFixedStatement(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the container: visits the declaration and the expression, and
            /// replaces the embedded statement if it is the target. Otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitUsingStatement(UsingStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    var declaration = VisitNode(node.Declaration);
                    var expression = VisitNode(node.Expression);
                    var body = ReplaceStatementIfNeeded(node.Statement);

                    return node.WithDeclaration(declaration)
                               .WithExpression(expression)
                               .WithStatement(body);
                }

                return base.VisitUsingStatement(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the container: visits the iterated expression and replaces the
            /// embedded statement if it is the target. Otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitForEachStatement(ForEachStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    var expression = VisitNode(node.Expression);
                    var body = ReplaceStatementIfNeeded(node.Statement);

                    return node.WithExpression(expression).WithStatement(body);
                }

                return base.VisitForEachStatement(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the container: visits the iterated expression and replaces the
            /// embedded statement if it is the target. Otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitForEachVariableStatement(ForEachVariableStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    var expression = VisitNode(node.Expression);
                    var body = ReplaceStatementIfNeeded(node.Statement);

                    return node.WithExpression(expression).WithStatement(body);
                }

                return base.VisitForEachVariableStatement(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the container: visits the declaration, initializers, condition
            /// and incrementors, and replaces the embedded statement if it is the target. Otherwise defers to
            /// the base rewriter.
            /// </summary>
            public override SyntaxNode VisitForStatement(ForStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    // Each part is taken from the original node and processed in header order.
                    var declaration = VisitNode(node.Declaration);
                    var initializers = VisitList(node.Initializers);
                    var condition = VisitNode(node.Condition);
                    var incrementors = VisitList(node.Incrementors);
                    var body = ReplaceStatementIfNeeded(node.Statement);

                    return node.WithDeclaration(declaration)
                               .WithInitializers(initializers)
                               .WithCondition(condition)
                               .WithIncrementors(incrementors)
                               .WithStatement(body);
                }

                return base.VisitForStatement(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the container: replaces the embedded statement if it is the
            /// target and then visits the condition. Otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitDoStatement(DoStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    // The body precedes the condition in a do-while, and is handled first here too.
                    var body = ReplaceStatementIfNeeded(node.Statement);
                    var condition = VisitNode(node.Condition);

                    return node.WithStatement(body).WithCondition(condition);
                }

                return base.VisitDoStatement(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the container: visits the condition and replaces the embedded
            /// statement if it is the target. Otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitWhileStatement(WhileStatementSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    var condition = VisitNode(node.Condition);
                    var body = ReplaceStatementIfNeeded(node.Statement);

                    return node.WithCondition(condition).WithStatement(body);
                }

                return base.VisitWhileStatement(node);
            }
 
            /// <summary>
            /// Visits <paramref name="node"/> and casts the result back to <typeparamref name="TNode"/>.
            /// </summary>
            private TNode VisitNode<TNode>(TNode node) where TNode : SyntaxNode
            {
                var visited = Visit(node);
                return (TNode)visited;
            }
 
            /// <summary>
            /// Returns the replacement for <paramref name="statement"/> when it is both the first and the last
            /// node to replace; otherwise returns <paramref name="statement"/> unchanged.
            /// </summary>
            /// <param name="statement">Embedded statement of the container. Must not be null.</param>
            /// <returns>
            /// The single inserted statement, a block wrapping all inserted statements when there are
            /// several, or the original <paramref name="statement"/> when it is not the target.
            /// </returns>
            private StatementSyntax ReplaceStatementIfNeeded(StatementSyntax statement)
            {
                Contract.ThrowIfNull(statement);

                // Only act when the statement, the first target and the last target are all the same node.
                if ((statement == _firstStatementOrFieldToReplace) && (_firstStatementOrFieldToReplace == _lastStatementOrFieldToReplace))
                {
                    if (_statementsOrMemberOrAccessorToInsert.Length != 1)
                    {
                        // One statement becomes several: wrap them in a block (see bug # 6310).
                        return SyntaxFactory.Block(_statementsOrMemberOrAccessorToInsert.CastArray<StatementSyntax>());
                    }

                    // One statement becomes exactly one other statement.
                    return (StatementSyntax)_statementsOrMemberOrAccessorToInsert[0];
                }

                return statement;
            }
 
            /// <summary>
            /// Returns a copy of <paramref name="list"/> in which the inclusive range from the first to the
            /// last node to replace is swapped for the nodes to insert.
            /// </summary>
            /// <typeparam name="TSyntax">Element type of the list; inserted nodes are cast to it.</typeparam>
            /// <param name="list">List that must contain both the first and the last node to replace, in that order.</param>
            private SyntaxList<TSyntax> ReplaceList<TSyntax>(SyntaxList<TSyntax> list)
                where TSyntax : SyntaxNode
            {
                // Work on a mutable copy of the container's children.
                var items = new List<TSyntax>(list);

                var start = items.FindIndex(item => item == _firstStatementOrFieldToReplace);
                Contract.ThrowIfFalse(start >= 0);

                var end = items.FindIndex(item => item == _lastStatementOrFieldToReplace);
                Contract.ThrowIfFalse(end >= 0);

                Contract.ThrowIfFalse(start <= end);

                // Cut out the old range ...
                var removedCount = end - start + 1;
                items.RemoveRange(start, removedCount);

                // ... and splice the replacement nodes in at the same position.
                var replacements = _statementsOrMemberOrAccessorToInsert.Cast<TSyntax>();
                items.InsertRange(start, replacements);

                return [.. items];
            }
 
            /// <summary>
            /// Statement-typed wrapper over <see cref="ReplaceList{TSyntax}"/>.
            /// </summary>
            private SyntaxList<StatementSyntax> ReplaceStatements(SyntaxList<StatementSyntax> statements)
            {
                return ReplaceList<StatementSyntax>(statements);
            }
 
            /// <summary>
            /// Accessor-typed wrapper over <see cref="ReplaceList{TSyntax}"/>.
            /// </summary>
            private SyntaxList<AccessorDeclarationSyntax> ReplaceAccessors(SyntaxList<AccessorDeclarationSyntax> accessors)
            {
                return ReplaceList<AccessorDeclarationSyntax>(accessors);
            }
 
            /// <summary>
            /// Returns a copy of <paramref name="members"/> in which the inclusive range of members to
            /// replace is swapped for the nodes to insert.
            /// </summary>
            /// <param name="members">Member list of the container.</param>
            /// <param name="global">
            /// When <see langword="true"/>, the range is located through the parents of the first/last node
            /// to replace, and each inserted statement is wrapped in a global statement. When
            /// <see langword="false"/>, the first/last nodes are looked up directly and the inserted nodes
            /// are cast to members.
            /// </param>
            private SyntaxList<MemberDeclarationSyntax> ReplaceMembers(SyntaxList<MemberDeclarationSyntax> members, bool global)
            {
                var items = new List<MemberDeclarationSyntax>(members);

                // Pick what to look for in the member list once, up front.
                var firstTarget = global ? _firstStatementOrFieldToReplace.Parent : _firstStatementOrFieldToReplace;
                var lastTarget = global ? _lastStatementOrFieldToReplace.Parent : _lastStatementOrFieldToReplace;

                var start = items.FindIndex(member => member == firstTarget);
                Contract.ThrowIfFalse(start >= 0);

                var end = items.FindIndex(member => member == lastTarget);
                Contract.ThrowIfFalse(end >= 0);

                Contract.ThrowIfFalse(start <= end);

                // Cut out the old range ...
                var removedCount = end - start + 1;
                items.RemoveRange(start, removedCount);

                // ... and splice the replacements in, converting each one to a member.
                MemberDeclarationSyntax ToMember(SyntaxNode inserted)
                {
                    if (global)
                    {
                        return SyntaxFactory.GlobalStatement((StatementSyntax)inserted);
                    }

                    return (MemberDeclarationSyntax)inserted;
                }

                items.InsertRange(start, _statementsOrMemberOrAccessorToInsert.Select(inserted => ToMember(inserted)));

                return [.. items];
            }
 
            /// <summary>
            /// Swaps the global statement's inner statement when <paramref name="node"/> is the container;
            /// otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitGlobalStatement(GlobalStatementSyntax node)
                => node == ContainerOfStatementsOrFieldToReplace
                    ? node.WithStatement(ReplaceStatementIfNeeded(node.Statement))
                    : base.VisitGlobalStatement(node);
 
            /// <summary>
            /// Replaces members of the interface when <paramref name="node"/> is the container;
            /// otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitInterfaceDeclaration(InterfaceDeclarationSyntax node)
                => node == ContainerOfStatementsOrFieldToReplace
                    ? GetUpdatedTypeDeclaration(node)
                    : base.VisitInterfaceDeclaration(node);
 
            /// <summary>
            /// When <paramref name="node"/> is the container, sets its initializer to the single inserted
            /// node (cast to a constructor initializer); otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitConstructorDeclaration(ConstructorDeclarationSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    // Exactly one node is being replaced in this case.
                    Contract.ThrowIfFalse(_firstStatementOrFieldToReplace == _lastStatementOrFieldToReplace);

                    var initializer = (ConstructorInitializerSyntax)_statementsOrMemberOrAccessorToInsert.Single();
                    return node.WithInitializer(initializer);
                }

                return base.VisitConstructorDeclaration(node);
            }
 
            /// <summary>
            /// Replaces members of the class when <paramref name="node"/> is the container;
            /// otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitClassDeclaration(ClassDeclarationSyntax node)
                => node == ContainerOfStatementsOrFieldToReplace
                    ? GetUpdatedTypeDeclaration(node)
                    : base.VisitClassDeclaration(node);
 
            /// <summary>
            /// Replaces members of the record when <paramref name="node"/> is the container;
            /// otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitRecordDeclaration(RecordDeclarationSyntax node)
                => node == ContainerOfStatementsOrFieldToReplace
                    ? GetUpdatedTypeDeclaration(node)
                    : base.VisitRecordDeclaration(node);
 
            /// <summary>
            /// Replaces members of the struct when <paramref name="node"/> is the container;
            /// otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitStructDeclaration(StructDeclarationSyntax node)
                => node == ContainerOfStatementsOrFieldToReplace
                    ? GetUpdatedTypeDeclaration(node)
                    : base.VisitStructDeclaration(node);
 
            /// <summary>
            /// Replaces the target accessor range when <paramref name="node"/> is the container, then visits
            /// the resulting accessors; any other accessor list is visited normally.
            /// </summary>
            public override SyntaxNode VisitAccessorList(AccessorListSyntax node)
            {
                if (node == ContainerOfStatementsOrFieldToReplace)
                {
                    return node.WithAccessors(VisitList(ReplaceAccessors(node.Accessors)));
                }

                return base.VisitAccessorList(node);
            }
 
            /// <summary>
            /// When <paramref name="node"/> is the parent of the container, replaces the affected members
            /// (in global mode) and visits the resulting member list; otherwise defers to the base rewriter.
            /// </summary>
            public override SyntaxNode VisitCompilationUnit(CompilationUnitSyntax node)
            {
                // The comparison is against the container's parent, one level above the container itself.
                if (node == ContainerOfStatementsOrFieldToReplace.Parent)
                {
                    return node.WithMembers(VisitList(ReplaceMembers(node.Members, global: true)));
                }

                // Not that compilation unit: keep descending so nested nodes are still visited.
                return base.VisitCompilationUnit(node);
            }
 
            /// <summary>
            /// Replaces the target member range of <paramref name="node"/> (non-global mode), visits the
            /// resulting members, and returns the type declaration with that member list.
            /// </summary>
            private SyntaxNode GetUpdatedTypeDeclaration(TypeDeclarationSyntax node)
                => node.WithMembers(VisitList(ReplaceMembers(node.Members, global: false)));
        }
    }
}