File: ExtractMethod\CSharpMethodExtractor.CSharpCodeGenerator.CallSiteContainerRewriter.cs
Web Access
Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.LanguageService;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.ExtractMethod;
 
internal sealed partial class CSharpExtractMethodService
{
    internal sealed partial class CSharpMethodExtractor
    {
        private abstract partial class CSharpCodeGenerator
        {
            private sealed class CallSiteContainerRewriter : CSharpSyntaxRewriter
            {
                private readonly SyntaxNode _outmostCallSiteContainer;
                private readonly HashSet<SyntaxAnnotation> _variableToRemoveMap;
                private readonly SyntaxNode _firstStatementOrFieldToReplace;
                private readonly SyntaxNode _lastStatementOrFieldToReplace;
                private readonly ImmutableArray<SyntaxNode> _statementsOrMemberOrAccessorToInsert;
 
                public CallSiteContainerRewriter(
                    SyntaxNode outmostCallSiteContainer,
                    HashSet<SyntaxAnnotation> variableToRemoveMap,
                    SyntaxNode firstStatementOrFieldToReplace,
                    SyntaxNode lastStatementOrFieldToReplace,
                    ImmutableArray<SyntaxNode> statementsOrFieldToInsert)
                {
                    Contract.ThrowIfNull(outmostCallSiteContainer);
                    Contract.ThrowIfNull(variableToRemoveMap);
                    Contract.ThrowIfNull(firstStatementOrFieldToReplace);
                    Contract.ThrowIfNull(lastStatementOrFieldToReplace);
                    Contract.ThrowIfTrue(statementsOrFieldToInsert.IsDefaultOrEmpty);
 
                    _outmostCallSiteContainer = outmostCallSiteContainer;
 
                    _variableToRemoveMap = variableToRemoveMap;
                    _firstStatementOrFieldToReplace = firstStatementOrFieldToReplace;
                    _lastStatementOrFieldToReplace = lastStatementOrFieldToReplace;
                    _statementsOrMemberOrAccessorToInsert = statementsOrFieldToInsert;
                }
 
                public SyntaxNode Generate()
                    => Visit(_outmostCallSiteContainer);
 
                private SyntaxNode ContainerOfStatementsOrFieldToReplace => _firstStatementOrFieldToReplace.Parent;
 
                public override SyntaxNode VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
                {
                    node = (LocalDeclarationStatementSyntax)base.VisitLocalDeclarationStatement(node);
                    var list = new List<VariableDeclaratorSyntax>();
                    var triviaList = new List<SyntaxTrivia>();
                    // go through each var decls in decl statement
                    foreach (var variable in node.Declaration.Variables)
                    {
                        if (_variableToRemoveMap.HasSyntaxAnnotation(variable))
                        {
                            // if it had initialization, it shouldn't reach here.
                            Contract.ThrowIfFalse(variable.Initializer == null);
 
                            // we don't remove trivia around tokens we remove
                            triviaList.AddRange(variable.GetLeadingTrivia());
                            triviaList.AddRange(variable.GetTrailingTrivia());
                            continue;
                        }
 
                        if (triviaList.Count > 0)
                        {
                            list.Add(variable.WithPrependedLeadingTrivia(triviaList));
                            triviaList.Clear();
                            continue;
                        }
 
                        list.Add(variable);
                    }
 
                    if (list.Count == 0)
                    {
                        // nothing has survived. remove this from the list
                        if (triviaList.Count == 0)
                        {
                            return null;
                        }
 
                        // well, there are trivia associated with the node.
                        // we can't just delete the node since then, we will lose
                        // the trivia. unfortunately, it is not easy to attach the trivia
                        // to next token. for now, create an empty statement and associate the
                        // trivia to the statement
 
                        // TODO : think about a way to move the trivia to next token.
                        return SyntaxFactory.EmptyStatement(SyntaxFactory.Token([.. triviaList], SyntaxKind.SemicolonToken, [SyntaxFactory.ElasticMarker]));
                    }
 
                    if (list.Count == node.Declaration.Variables.Count)
                    {
                        // nothing has changed, return as it is
                        return node;
                    }
 
                    // TODO : fix how it manipulate trivia later
 
                    // if there is left over syntax trivia, it will be attached to leading trivia
                    // of semicolon
                    return
                        SyntaxFactory.LocalDeclarationStatement(
                            node.Modifiers,
                            SyntaxFactory.VariableDeclaration(
                                node.Declaration.Type,
                                [.. list]),
                            node.SemicolonToken.WithPrependedLeadingTrivia(triviaList));
                }
 
                // for every kind of extract methods
                public override SyntaxNode VisitBlock(BlockSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        // make sure we visit nodes under the block
                        return base.VisitBlock(node);
                    }
 
                    return node.WithStatements([.. VisitList(ReplaceStatements(node.Statements))]);
                }
 
                public override SyntaxNode VisitSwitchSection(SwitchSectionSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        // make sure we visit nodes under the switch section
                        return base.VisitSwitchSection(node);
                    }
 
                    return node.WithStatements([.. VisitList(ReplaceStatements(node.Statements))]);
                }
 
                // only for single statement or expression
                public override SyntaxNode VisitLabeledStatement(LabeledStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitLabeledStatement(node);
                    }
 
                    return node.WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitElseClause(ElseClauseSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitElseClause(node);
                    }
 
                    return node.WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitIfStatement(IfStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitIfStatement(node);
                    }
 
                    return node.WithCondition(VisitNode(node.Condition))
                               .WithStatement(ReplaceStatementIfNeeded(node.Statement))
                               .WithElse(VisitNode(node.Else));
                }
 
                public override SyntaxNode VisitLockStatement(LockStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitLockStatement(node);
                    }
 
                    return node.WithExpression(VisitNode(node.Expression))
                               .WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitFixedStatement(FixedStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitFixedStatement(node);
                    }
 
                    return node.WithDeclaration(VisitNode(node.Declaration))
                               .WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitUsingStatement(UsingStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitUsingStatement(node);
                    }
 
                    return node.WithDeclaration(VisitNode(node.Declaration))
                               .WithExpression(VisitNode(node.Expression))
                               .WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitForEachStatement(ForEachStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitForEachStatement(node);
                    }
 
                    return node.WithExpression(VisitNode(node.Expression))
                               .WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitForEachVariableStatement(ForEachVariableStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitForEachVariableStatement(node);
                    }
 
                    return node.WithExpression(VisitNode(node.Expression))
                               .WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitForStatement(ForStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitForStatement(node);
                    }
 
                    return node.WithDeclaration(VisitNode(node.Declaration))
                               .WithInitializers(VisitList(node.Initializers))
                               .WithCondition(VisitNode(node.Condition))
                               .WithIncrementors(VisitList(node.Incrementors))
                               .WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitDoStatement(DoStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitDoStatement(node);
                    }
 
                    return node.WithStatement(ReplaceStatementIfNeeded(node.Statement))
                               .WithCondition(VisitNode(node.Condition));
                }
 
                public override SyntaxNode VisitWhileStatement(WhileStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitWhileStatement(node);
                    }
 
                    return node.WithCondition(VisitNode(node.Condition))
                               .WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                private TNode VisitNode<TNode>(TNode node) where TNode : SyntaxNode
                    => (TNode)Visit(node);
 
                private StatementSyntax ReplaceStatementIfNeeded(StatementSyntax statement)
                {
                    Contract.ThrowIfNull(statement);
 
                    // if all three same
                    if ((statement != _firstStatementOrFieldToReplace) || (_firstStatementOrFieldToReplace != _lastStatementOrFieldToReplace))
                        return statement;
 
                    // replace one statement with another
                    if (_statementsOrMemberOrAccessorToInsert.Length == 1)
                        return _statementsOrMemberOrAccessorToInsert.Cast<StatementSyntax>().Single();
 
                    // replace one statement with multiple statements (see bug # 6310)
                    var statements = _statementsOrMemberOrAccessorToInsert.CastArray<StatementSyntax>();
                    return SyntaxFactory.Block(statements);
                }
 
                private SyntaxList<TSyntax> ReplaceList<TSyntax>(SyntaxList<TSyntax> list)
                    where TSyntax : SyntaxNode
                {
                    // okay, this visit contains the statement
                    var newList = new List<TSyntax>(list);
 
                    var firstIndex = newList.FindIndex(s => s == _firstStatementOrFieldToReplace);
                    Contract.ThrowIfFalse(firstIndex >= 0);
 
                    var lastIndex = newList.FindIndex(s => s == _lastStatementOrFieldToReplace);
                    Contract.ThrowIfFalse(lastIndex >= 0);
 
                    Contract.ThrowIfFalse(firstIndex <= lastIndex);
 
                    // remove statement that must be removed
                    newList.RemoveRange(firstIndex, lastIndex - firstIndex + 1);
 
                    // add new statements to replace
                    newList.InsertRange(firstIndex, _statementsOrMemberOrAccessorToInsert.Cast<TSyntax>());
 
                    return [.. newList];
                }
 
                private SyntaxList<StatementSyntax> ReplaceStatements(SyntaxList<StatementSyntax> statements)
                    => ReplaceList(statements);
 
                private SyntaxList<AccessorDeclarationSyntax> ReplaceAccessors(SyntaxList<AccessorDeclarationSyntax> accessors)
                    => ReplaceList(accessors);
 
                private SyntaxList<MemberDeclarationSyntax> ReplaceMembers(SyntaxList<MemberDeclarationSyntax> members, bool global)
                {
                    // okay, this visit contains the statement
                    var newMembers = new List<MemberDeclarationSyntax>(members);
 
                    var firstMemberIndex = newMembers.FindIndex(s => s == (global ? _firstStatementOrFieldToReplace.Parent : _firstStatementOrFieldToReplace));
                    Contract.ThrowIfFalse(firstMemberIndex >= 0);
 
                    var lastMemberIndex = newMembers.FindIndex(s => s == (global ? _lastStatementOrFieldToReplace.Parent : _lastStatementOrFieldToReplace));
                    Contract.ThrowIfFalse(lastMemberIndex >= 0);
 
                    Contract.ThrowIfFalse(firstMemberIndex <= lastMemberIndex);
 
                    // remove statement that must be removed
                    newMembers.RemoveRange(firstMemberIndex, lastMemberIndex - firstMemberIndex + 1);
 
                    // add new statements to replace
                    newMembers.InsertRange(firstMemberIndex,
                        _statementsOrMemberOrAccessorToInsert.Select(s => global ? SyntaxFactory.GlobalStatement((StatementSyntax)s) : (MemberDeclarationSyntax)s));
 
                    return [.. newMembers];
                }
 
                public override SyntaxNode VisitGlobalStatement(GlobalStatementSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitGlobalStatement(node);
                    }
 
                    return node.WithStatement(ReplaceStatementIfNeeded(node.Statement));
                }
 
                public override SyntaxNode VisitInterfaceDeclaration(InterfaceDeclarationSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitInterfaceDeclaration(node);
                    }
 
                    return GetUpdatedTypeDeclaration(node);
                }
 
                public override SyntaxNode VisitConstructorDeclaration(ConstructorDeclarationSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitConstructorDeclaration(node);
                    }
 
                    Contract.ThrowIfFalse(_firstStatementOrFieldToReplace == _lastStatementOrFieldToReplace);
                    return node.WithInitializer((ConstructorInitializerSyntax)_statementsOrMemberOrAccessorToInsert.Single());
                }
 
                public override SyntaxNode VisitClassDeclaration(ClassDeclarationSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitClassDeclaration(node);
                    }
 
                    return GetUpdatedTypeDeclaration(node);
                }
 
                public override SyntaxNode VisitRecordDeclaration(RecordDeclarationSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitRecordDeclaration(node);
                    }
 
                    return GetUpdatedTypeDeclaration(node);
                }
 
                public override SyntaxNode VisitStructDeclaration(StructDeclarationSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitStructDeclaration(node);
                    }
 
                    return GetUpdatedTypeDeclaration(node);
                }
 
                public override SyntaxNode VisitAccessorList(AccessorListSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                    {
                        return base.VisitAccessorList(node);
                    }
 
                    var newAccessors = VisitList(ReplaceAccessors(node.Accessors));
                    return node.WithAccessors(newAccessors);
                }
 
                public override SyntaxNode VisitCompilationUnit(CompilationUnitSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace.Parent)
                    {
                        // make sure we visit nodes under the block
                        return base.VisitCompilationUnit(node);
                    }
 
                    var newMembers = VisitList(ReplaceMembers(node.Members, global: true));
                    return node.WithMembers(newMembers);
                }
 
                public override SyntaxNode VisitBaseList(BaseListSyntax node)
                {
                    if (node != ContainerOfStatementsOrFieldToReplace)
                        return base.VisitBaseList(node);
 
                    var primaryConstructorBase = (PrimaryConstructorBaseTypeSyntax)_statementsOrMemberOrAccessorToInsert.Single();
                    return node.WithTypes(node.Types.Replace(node.Types[0], primaryConstructorBase));
                }
 
                private SyntaxNode GetUpdatedTypeDeclaration(TypeDeclarationSyntax node)
                {
                    var newMembers = VisitList(ReplaceMembers(node.Members, global: false));
                    return node.WithMembers(newMembers);
                }
            }
        }
    }
}