File: Syntax\SyntaxReplacer.cs
Web Access
Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.Syntax
{
    internal static class SyntaxReplacer
    {
        /// <summary>
        /// Applies the supplied replacement callbacks to the given nodes, tokens and trivia
        /// located under <paramref name="root"/>.
        /// </summary>
        /// <typeparam name="TNode">The node type handed to <paramref name="computeReplacementNode"/>.</typeparam>
        /// <param name="root">The node at which rewriting starts.</param>
        /// <param name="nodes">Nodes to replace, or <c>null</c> for none.</param>
        /// <param name="computeReplacementNode">Receives the original node and its rewritten form and returns the replacement.</param>
        /// <param name="tokens">Tokens to replace, or <c>null</c> for none.</param>
        /// <param name="computeReplacementToken">Receives the original token and its rewritten form and returns the replacement.</param>
        /// <param name="trivia">Trivia to replace, or <c>null</c> for none.</param>
        /// <param name="computeReplacementTrivia">Receives the original trivia and its rewritten form and returns the replacement.</param>
        /// <returns>
        /// The result of visiting <paramref name="root"/> with the replacer, or
        /// <paramref name="root"/> itself when no nodes, tokens or trivia were supplied.
        /// </returns>
        internal static SyntaxNode Replace<TNode>(
            SyntaxNode root,
            IEnumerable<TNode>? nodes = null,
            Func<TNode, TNode, SyntaxNode>? computeReplacementNode = null,
            IEnumerable<SyntaxToken>? tokens = null,
            Func<SyntaxToken, SyntaxToken, SyntaxToken>? computeReplacementToken = null,
            IEnumerable<SyntaxTrivia>? trivia = null,
            Func<SyntaxTrivia, SyntaxTrivia, SyntaxTrivia>? computeReplacementTrivia = null)
            where TNode : SyntaxNode
        {
            var rewriter = new Replacer<TNode>(
                nodes: nodes,
                computeReplacementNode: computeReplacementNode,
                tokens: tokens,
                computeReplacementToken: computeReplacementToken,
                trivia: trivia,
                computeReplacementTrivia: computeReplacementTrivia);

            // Nothing was requested for replacement: return the same root instance untouched.
            if (!rewriter.HasWork)
            {
                return root;
            }

            return rewriter.Visit(root);
        }
 
        /// <summary>
        /// Applies the supplied replacement callbacks to the given nodes, tokens and trivia,
        /// starting the rewrite at the token <paramref name="root"/>.
        /// </summary>
        /// <param name="root">The token at which rewriting starts.</param>
        /// <param name="nodes">Nodes to replace, or <c>null</c> for none.</param>
        /// <param name="computeReplacementNode">Receives the original node and its rewritten form and returns the replacement.</param>
        /// <param name="tokens">Tokens to replace, or <c>null</c> for none.</param>
        /// <param name="computeReplacementToken">Receives the original token and its rewritten form and returns the replacement.</param>
        /// <param name="trivia">Trivia to replace, or <c>null</c> for none.</param>
        /// <param name="computeReplacementTrivia">Receives the original trivia and its rewritten form and returns the replacement.</param>
        /// <returns>
        /// The result of visiting <paramref name="root"/> with the replacer, or
        /// <paramref name="root"/> itself when no nodes, tokens or trivia were supplied.
        /// </returns>
        internal static SyntaxToken Replace(
            SyntaxToken root,
            IEnumerable<SyntaxNode>? nodes = null,
            Func<SyntaxNode, SyntaxNode, SyntaxNode>? computeReplacementNode = null,
            IEnumerable<SyntaxToken>? tokens = null,
            Func<SyntaxToken, SyntaxToken, SyntaxToken>? computeReplacementToken = null,
            IEnumerable<SyntaxTrivia>? trivia = null,
            Func<SyntaxTrivia, SyntaxTrivia, SyntaxTrivia>? computeReplacementTrivia = null)
        {
            var rewriter = new Replacer<SyntaxNode>(
                nodes: nodes,
                computeReplacementNode: computeReplacementNode,
                tokens: tokens,
                computeReplacementToken: computeReplacementToken,
                trivia: trivia,
                computeReplacementTrivia: computeReplacementTrivia);

            // With no requested replacements the token is returned as-is.
            return rewriter.HasWork
                ? rewriter.VisitToken(root)
                : root;
        }
 
        /// <summary>
        /// Rewriter that hands each requested node, token or trivia to its replacement callback.
        /// Nodes, tokens and structured trivia whose full span does not intersect the full span
        /// of any requested item are returned without being passed to the base rewriter.
        /// </summary>
        /// <typeparam name="TNode">The node type passed to the node replacement callback.</typeparam>
        private class Replacer<TNode> : CSharpSyntaxRewriter where TNode : SyntaxNode
        {
            // Shared empty sets used when the caller supplies no items of a given kind.
            private static readonly HashSet<SyntaxNode> s_noNodes = new HashSet<SyntaxNode>();
            private static readonly HashSet<SyntaxToken> s_noTokens = new HashSet<SyntaxToken>();
            private static readonly HashSet<SyntaxTrivia> s_noTrivia = new HashSet<SyntaxTrivia>();

            // The items requested for replacement.
            private readonly HashSet<SyntaxNode> _nodeSet;
            private readonly HashSet<SyntaxToken> _tokenSet;
            private readonly HashSet<SyntaxTrivia> _triviaSet;

            // Full spans of every requested item, plus the single span that bounds them all.
            private readonly HashSet<TextSpan> _spanSet;
            private readonly TextSpan _totalSpan;

            // Callbacks producing the replacement from (original, rewritten).
            private readonly Func<TNode, TNode, SyntaxNode>? _computeReplacementNode;
            private readonly Func<SyntaxToken, SyntaxToken, SyntaxToken>? _computeReplacementToken;
            private readonly Func<SyntaxTrivia, SyntaxTrivia, SyntaxTrivia>? _computeReplacementTrivia;

            private readonly bool _visitIntoStructuredTrivia;
            private readonly bool _shouldVisitTrivia;

            /// <summary>
            /// Captures the requested items and callbacks and precomputes the span data used
            /// to decide which parts of the tree need to be visited.
            /// </summary>
            public Replacer(
                IEnumerable<TNode>? nodes,
                Func<TNode, TNode, SyntaxNode>? computeReplacementNode,
                IEnumerable<SyntaxToken>? tokens,
                Func<SyntaxToken, SyntaxToken, SyntaxToken>? computeReplacementToken,
                IEnumerable<SyntaxTrivia>? trivia,
                Func<SyntaxTrivia, SyntaxTrivia, SyntaxTrivia>? computeReplacementTrivia)
            {
                _computeReplacementNode = computeReplacementNode;
                _computeReplacementToken = computeReplacementToken;
                _computeReplacementTrivia = computeReplacementTrivia;

                _nodeSet = nodes is null ? s_noNodes : new HashSet<SyntaxNode>(nodes);
                _tokenSet = tokens is null ? s_noTokens : new HashSet<SyntaxToken>(tokens);
                _triviaSet = trivia is null ? s_noTrivia : new HashSet<SyntaxTrivia>(trivia);

                var nodeSpans = _nodeSet.Select(item => item.FullSpan);
                var tokenSpans = _tokenSet.Select(item => item.FullSpan);
                var triviaSpans = _triviaSet.Select(item => item.FullSpan);
                _spanSet = new HashSet<TextSpan>(nodeSpans.Concat(tokenSpans.Concat(triviaSpans)));

                _totalSpan = ComputeTotalSpan(_spanSet);

                // Structured trivia is only entered when at least one requested item lives inside it.
                _visitIntoStructuredTrivia =
                    _nodeSet.Any(item => item.IsPartOfStructuredTrivia()) ||
                    _tokenSet.Any(item => item.IsPartOfStructuredTrivia()) ||
                    _triviaSet.Any(item => item.IsPartOfStructuredTrivia());

                // Gates whether VisitToken delegates to the base implementation at all.
                _shouldVisitTrivia = _triviaSet.Count > 0 || _visitIntoStructuredTrivia;
            }

            /// <summary>
            /// True when at least one requested node, token or trivia is part of structured trivia.
            /// </summary>
            public override bool VisitIntoStructuredTrivia => _visitIntoStructuredTrivia;

            /// <summary>
            /// True when at least one node, token or trivia was supplied for replacement.
            /// </summary>
            public bool HasWork => _nodeSet.Count + _tokenSet.Count + _triviaSet.Count > 0;

            /// <summary>
            /// Computes the span running from the smallest start to the largest end among
            /// <paramref name="spans"/>; an empty sequence yields a zero-length span at position 0.
            /// </summary>
            private static TextSpan ComputeTotalSpan(IEnumerable<TextSpan> spans)
            {
                int lowest = 0;
                int highest = 0;
                bool seenAny = false;

                foreach (var current in spans)
                {
                    if (!seenAny)
                    {
                        // The first span seeds both bounds.
                        lowest = current.Start;
                        highest = current.End;
                        seenAny = true;
                        continue;
                    }

                    if (current.Start < lowest)
                    {
                        lowest = current.Start;
                    }

                    if (current.End > highest)
                    {
                        highest = current.End;
                    }
                }

                return new TextSpan(lowest, highest - lowest);
            }

            /// <summary>
            /// Determines whether something covering <paramref name="span"/> could contain
            /// one of the requested items.
            /// </summary>
            private bool ShouldVisit(TextSpan span)
            {
                // Cheap rejection first: anything outside the bounding span cannot contain a target.
                if (span.IntersectsWith(_totalSpan))
                {
                    // Inside the bounding span, so check against each individual target span.
                    foreach (var candidate in _spanSet)
                    {
                        if (span.IntersectsWith(candidate))
                        {
                            return true;
                        }
                    }
                }

                return false;
            }

            /// <summary>
            /// Rewrites <paramref name="node"/> through the base rewriter when its full span may
            /// contain a target, then applies the node callback if the node itself was requested.
            /// </summary>
            [return: NotNullIfNotNull(nameof(node))]
            public override SyntaxNode? Visit(SyntaxNode? node)
            {
                if (node is null)
                {
                    return node;
                }

                SyntaxNode? result = this.ShouldVisit(node.FullSpan)
                    ? base.Visit(node)
                    : node;

                if (_computeReplacementNode != null && _nodeSet.Contains(node))
                {
                    // The callback sees both the original node and its already-rewritten form.
                    result = _computeReplacementNode((TNode)node, (TNode)result!);
                }

                return result;
            }

            /// <summary>
            /// Rewrites <paramref name="token"/> through the base rewriter when trivia visiting is
            /// enabled and its full span may contain a target, then applies the token callback if
            /// the token itself was requested.
            /// </summary>
            public override SyntaxToken VisitToken(SyntaxToken token)
            {
                SyntaxToken result = (_shouldVisitTrivia && this.ShouldVisit(token.FullSpan))
                    ? base.VisitToken(token)
                    : token;

                if (_computeReplacementToken != null && _tokenSet.Contains(token))
                {
                    result = _computeReplacementToken(token, result);
                }

                return result;
            }

            /// <summary>
            /// Rewrites a trivia list element: structured trivia is entered only when that is
            /// enabled and its full span may contain a target; the trivia callback is then applied
            /// if the trivia itself was requested.
            /// </summary>
            public override SyntaxTrivia VisitListElement(SyntaxTrivia trivia)
            {
                bool descend =
                    this.VisitIntoStructuredTrivia &&
                    trivia.HasStructure &&
                    this.ShouldVisit(trivia.FullSpan);

                SyntaxTrivia result = descend ? this.VisitTrivia(trivia) : trivia;

                if (_computeReplacementTrivia != null && _triviaSet.Contains(trivia))
                {
                    result = _computeReplacementTrivia(trivia, result);
                }

                return result;
            }
        }
 
        /// <summary>
        /// Visits <paramref name="root"/> with a <see cref="NodeListEditor"/> in replace mode, so that
        /// <paramref name="originalNode"/> is replaced by <paramref name="newNodes"/> in the list containing it.
        /// </summary>
        internal static SyntaxNode ReplaceNodeInList(SyntaxNode root, SyntaxNode originalNode, IEnumerable<SyntaxNode> newNodes)
        {
            var editor = new NodeListEditor(originalNode, newNodes, ListEditKind.Replace);
            return editor.Visit(root);
        }
 
        /// <summary>
        /// Visits <paramref name="root"/> with a <see cref="NodeListEditor"/> in insert mode, so that
        /// <paramref name="nodesToInsert"/> are placed next to <paramref name="nodeInList"/> in its list.
        /// </summary>
        /// <param name="insertBefore"><c>true</c> to insert before <paramref name="nodeInList"/>; <c>false</c> to insert after it.</param>
        internal static SyntaxNode InsertNodeInList(SyntaxNode root, SyntaxNode nodeInList, IEnumerable<SyntaxNode> nodesToInsert, bool insertBefore)
        {
            ListEditKind kind = insertBefore ? ListEditKind.InsertBefore : ListEditKind.InsertAfter;
            var editor = new NodeListEditor(nodeInList, nodesToInsert, kind);
            return editor.Visit(root);
        }
 
        /// <summary>
        /// Visits <paramref name="root"/> with a <see cref="TokenListEditor"/> in replace mode, so that
        /// <paramref name="tokenInList"/> is replaced by <paramref name="newTokens"/> in the token list containing it.
        /// </summary>
        public static SyntaxNode ReplaceTokenInList(SyntaxNode root, SyntaxToken tokenInList, IEnumerable<SyntaxToken> newTokens)
        {
            var editor = new TokenListEditor(tokenInList, newTokens, ListEditKind.Replace);
            return editor.Visit(root);
        }
 
        /// <summary>
        /// Visits <paramref name="root"/> with a <see cref="TokenListEditor"/> in insert mode, so that
        /// <paramref name="newTokens"/> are placed next to <paramref name="tokenInList"/> in its token list.
        /// </summary>
        /// <param name="insertBefore"><c>true</c> to insert before <paramref name="tokenInList"/>; <c>false</c> to insert after it.</param>
        public static SyntaxNode InsertTokenInList(SyntaxNode root, SyntaxToken tokenInList, IEnumerable<SyntaxToken> newTokens, bool insertBefore)
        {
            ListEditKind kind = insertBefore ? ListEditKind.InsertBefore : ListEditKind.InsertAfter;
            var editor = new TokenListEditor(tokenInList, newTokens, kind);
            return editor.Visit(root);
        }
 
        /// <summary>
        /// Visits the node <paramref name="root"/> with a <see cref="TriviaListEditor"/> in replace mode, so that
        /// <paramref name="triviaInList"/> is replaced by <paramref name="newTrivia"/> in the trivia list containing it.
        /// </summary>
        public static SyntaxNode ReplaceTriviaInList(SyntaxNode root, SyntaxTrivia triviaInList, IEnumerable<SyntaxTrivia> newTrivia)
        {
            var editor = new TriviaListEditor(triviaInList, newTrivia, ListEditKind.Replace);
            return editor.Visit(root);
        }
 
        /// <summary>
        /// Visits the node <paramref name="root"/> with a <see cref="TriviaListEditor"/> in insert mode, so that
        /// <paramref name="newTrivia"/> is placed next to <paramref name="triviaInList"/> in its trivia list.
        /// </summary>
        /// <param name="insertBefore"><c>true</c> to insert before <paramref name="triviaInList"/>; <c>false</c> to insert after it.</param>
        public static SyntaxNode InsertTriviaInList(SyntaxNode root, SyntaxTrivia triviaInList, IEnumerable<SyntaxTrivia> newTrivia, bool insertBefore)
        {
            ListEditKind kind = insertBefore ? ListEditKind.InsertBefore : ListEditKind.InsertAfter;
            var editor = new TriviaListEditor(triviaInList, newTrivia, kind);
            return editor.Visit(root);
        }
 
        /// <summary>
        /// Visits the token <paramref name="root"/> with a <see cref="TriviaListEditor"/> in replace mode, so that
        /// <paramref name="triviaInList"/> is replaced by <paramref name="newTrivia"/> in the trivia list containing it.
        /// </summary>
        public static SyntaxToken ReplaceTriviaInList(SyntaxToken root, SyntaxTrivia triviaInList, IEnumerable<SyntaxTrivia> newTrivia)
        {
            var editor = new TriviaListEditor(triviaInList, newTrivia, ListEditKind.Replace);
            return editor.VisitToken(root);
        }
 
        /// <summary>
        /// Visits the token <paramref name="root"/> with a <see cref="TriviaListEditor"/> in insert mode, so that
        /// <paramref name="newTrivia"/> is placed next to <paramref name="triviaInList"/> in its trivia list.
        /// </summary>
        /// <param name="insertBefore"><c>true</c> to insert before <paramref name="triviaInList"/>; <c>false</c> to insert after it.</param>
        public static SyntaxToken InsertTriviaInList(SyntaxToken root, SyntaxTrivia triviaInList, IEnumerable<SyntaxTrivia> newTrivia, bool insertBefore)
        {
            ListEditKind kind = insertBefore ? ListEditKind.InsertBefore : ListEditKind.InsertAfter;
            var editor = new TriviaListEditor(triviaInList, newTrivia, kind);
            return editor.VisitToken(root);
        }
 
        /// <summary>
        /// Selects which list operation a list editor performs once it has located
        /// its target element.
        /// </summary>
        private enum ListEditKind
        {
            /// <summary>Insert the new items at the target element's index.</summary>
            InsertBefore,
            /// <summary>Insert the new items at the index following the target element.</summary>
            InsertAfter,
            /// <summary>Replace the target element with the new items.</summary>
            Replace
        }
 
        /// <summary>
        /// Creates the exception thrown when the item targeted by a list edit is encountered
        /// outside of a list, using the <c>MissingListItem</c> resource message.
        /// </summary>
        private static InvalidOperationException GetItemNotListElementException()
            => new InvalidOperationException(CodeAnalysisResources.MissingListItem);
 
        /// <summary>
        /// Common base for the list editors. It restricts the rewrite to the parts of the tree
        /// whose full span intersects the span of the single element being edited; derived
        /// classes override the relevant <c>VisitList</c> method to perform the edit.
        /// </summary>
        private abstract class BaseListEditor : CSharpSyntaxRewriter
        {
            /// <summary>The list operation derived editors should perform.</summary>
            protected readonly ListEditKind editKind;

            private readonly TextSpan _elementSpan;
            private readonly bool _visitTrivia;
            private readonly bool _visitIntoStructuredTrivia;

            /// <param name="elementSpan">Span of the element the edit is anchored on.</param>
            /// <param name="editKind">The list operation to perform.</param>
            /// <param name="visitTrivia">Whether tokens should be passed to the base rewriter.</param>
            /// <param name="visitIntoStructuredTrivia">Whether structured trivia should be entered; also forces <paramref name="visitTrivia"/> on.</param>
            public BaseListEditor(
                TextSpan elementSpan,
                ListEditKind editKind,
                bool visitTrivia,
                bool visitIntoStructuredTrivia)
            {
                this.editKind = editKind;
                _elementSpan = elementSpan;
                _visitIntoStructuredTrivia = visitIntoStructuredTrivia;

                // Entering structured trivia implies tokens must be visited as well.
                _visitTrivia = visitTrivia || visitIntoStructuredTrivia;
            }

            /// <summary>Returns the value supplied to the constructor.</summary>
            public override bool VisitIntoStructuredTrivia => _visitIntoStructuredTrivia;

            /// <summary>
            /// True when <paramref name="span"/> intersects the span of the element being edited,
            /// meaning the corresponding part of the tree has to be visited to find it.
            /// </summary>
            private bool ShouldVisit(TextSpan span) => span.IntersectsWith(_elementSpan);

            /// <summary>
            /// Passes <paramref name="node"/> to the base rewriter only when its full span
            /// intersects the target element; otherwise returns it unchanged.
            /// </summary>
            [return: NotNullIfNotNull(nameof(node))]
            public override SyntaxNode? Visit(SyntaxNode? node)
            {
                if (node is null || !this.ShouldVisit(node.FullSpan))
                {
                    return node;
                }

                return base.Visit(node);
            }

            /// <summary>
            /// Passes <paramref name="token"/> to the base rewriter only when token visiting is
            /// enabled and its full span intersects the target element; otherwise returns it unchanged.
            /// </summary>
            public override SyntaxToken VisitToken(SyntaxToken token)
            {
                bool descend = _visitTrivia && this.ShouldVisit(token.FullSpan);
                return descend ? base.VisitToken(token) : token;
            }

            /// <summary>
            /// Enters structured trivia only when that is enabled and the trivia's full span
            /// intersects the target element; otherwise returns the trivia unchanged.
            /// </summary>
            public override SyntaxTrivia VisitListElement(SyntaxTrivia trivia)
            {
                bool descend =
                    this.VisitIntoStructuredTrivia &&
                    trivia.HasStructure &&
                    this.ShouldVisit(trivia.FullSpan);

                return descend ? this.VisitTrivia(trivia) : trivia;
            }
        }
 
        /// <summary>
        /// List editor that replaces a node, or inserts nodes next to it, in the
        /// <see cref="SyntaxList{TNode}"/> or <see cref="SeparatedSyntaxList{TNode}"/> that contains it.
        /// </summary>
        private class NodeListEditor : BaseListEditor
        {
            private readonly SyntaxNode _originalNode;
            private readonly IEnumerable<SyntaxNode> _newNodes;

            /// <param name="originalNode">The node the edit is anchored on.</param>
            /// <param name="replacementNodes">The nodes to insert or to substitute for <paramref name="originalNode"/>.</param>
            /// <param name="editKind">The list operation to perform.</param>
            public NodeListEditor(
                SyntaxNode originalNode,
                IEnumerable<SyntaxNode> replacementNodes,
                ListEditKind editKind)
                : base(originalNode.Span, editKind, false, originalNode.IsPartOfStructuredTrivia())
            {
                _newNodes = replacementNodes;
                _originalNode = originalNode;
            }

            /// <summary>
            /// Visits <paramref name="node"/>, throwing if it is the anchor node itself: the
            /// list overrides below return without visiting elements once they find the anchor,
            /// so reaching it here means it was not handled as a list element.
            /// </summary>
            /// <exception cref="InvalidOperationException">The anchor node was reached through this method.</exception>
            [return: NotNullIfNotNull(nameof(node))]
            public override SyntaxNode? Visit(SyntaxNode? node)
            {
                if (node == _originalNode)
                {
                    throw GetItemNotListElementException();
                }

                return base.Visit(node);
            }

            /// <summary>
            /// Performs the edit on <paramref name="list"/> when it contains the anchor node;
            /// otherwise defers to the base implementation.
            /// </summary>
            public override SeparatedSyntaxList<TNode> VisitList<TNode>(SeparatedSyntaxList<TNode> list)
            {
                if (_originalNode is TNode anchor)
                {
                    int position = list.IndexOf(anchor);
                    bool found = position >= 0 && position < list.Count;

                    if (found)
                    {
                        switch (this.editKind)
                        {
                            case ListEditKind.InsertBefore:
                                return list.InsertRange(position, _newNodes.Cast<TNode>());

                            case ListEditKind.InsertAfter:
                                return list.InsertRange(position + 1, _newNodes.Cast<TNode>());

                            case ListEditKind.Replace:
                                return list.ReplaceRange(anchor, _newNodes.Cast<TNode>());
                        }
                    }
                }

                return base.VisitList<TNode>(list);
            }

            /// <summary>
            /// Performs the edit on <paramref name="list"/> when it contains the anchor node;
            /// otherwise defers to the base implementation.
            /// </summary>
            public override SyntaxList<TNode> VisitList<TNode>(SyntaxList<TNode> list)
            {
                if (_originalNode is TNode anchor)
                {
                    int position = list.IndexOf(anchor);
                    bool found = position >= 0 && position < list.Count;

                    if (found)
                    {
                        switch (this.editKind)
                        {
                            case ListEditKind.InsertBefore:
                                return list.InsertRange(position, _newNodes.Cast<TNode>());

                            case ListEditKind.InsertAfter:
                                return list.InsertRange(position + 1, _newNodes.Cast<TNode>());

                            case ListEditKind.Replace:
                                return list.ReplaceRange(anchor, _newNodes.Cast<TNode>());
                        }
                    }
                }

                return base.VisitList<TNode>(list);
            }
        }
 
        /// <summary>
        /// List editor that replaces a token, or inserts tokens next to it, in the
        /// <see cref="SyntaxTokenList"/> that contains it.
        /// </summary>
        private class TokenListEditor : BaseListEditor
        {
            private readonly SyntaxToken _originalToken;
            private readonly IEnumerable<SyntaxToken> _newTokens;

            /// <param name="originalToken">The token the edit is anchored on.</param>
            /// <param name="newTokens">The tokens to insert or to substitute for <paramref name="originalToken"/>.</param>
            /// <param name="editKind">The list operation to perform.</param>
            public TokenListEditor(
                SyntaxToken originalToken,
                IEnumerable<SyntaxToken> newTokens,
                ListEditKind editKind)
                : base(originalToken.Span, editKind, false, originalToken.IsPartOfStructuredTrivia())
            {
                _newTokens = newTokens;
                _originalToken = originalToken;
            }

            /// <summary>
            /// Visits <paramref name="token"/>, throwing if it is the anchor token itself.
            /// </summary>
            /// <exception cref="InvalidOperationException">The anchor token was reached through this method.</exception>
            public override SyntaxToken VisitToken(SyntaxToken token)
            {
                if (token == _originalToken)
                {
                    throw GetItemNotListElementException();
                }

                return base.VisitToken(token);
            }

            /// <summary>
            /// Performs the edit on <paramref name="list"/> when it contains the anchor token;
            /// otherwise defers to the base implementation.
            /// </summary>
            public override SyntaxTokenList VisitList(SyntaxTokenList list)
            {
                int position = list.IndexOf(_originalToken);

                // Anchor is not in this list: nothing to edit here.
                if (position < 0 || position >= list.Count)
                {
                    return base.VisitList(list);
                }

                switch (this.editKind)
                {
                    case ListEditKind.InsertBefore:
                        return list.InsertRange(position, _newTokens);

                    case ListEditKind.InsertAfter:
                        return list.InsertRange(position + 1, _newTokens);

                    case ListEditKind.Replace:
                        return list.ReplaceRange(_originalToken, _newTokens);

                    default:
                        return base.VisitList(list);
                }
            }
        }
 
        /// <summary>
        /// List editor that replaces a trivia, or inserts trivia next to it, in the
        /// <see cref="SyntaxTriviaList"/> that contains it. Unlike the node and token editors,
        /// it passes <c>true</c> for the base class's <c>visitTrivia</c> flag.
        /// </summary>
        private class TriviaListEditor : BaseListEditor
        {
            private readonly SyntaxTrivia _originalTrivia;
            private readonly IEnumerable<SyntaxTrivia> _newTrivia;

            /// <param name="originalTrivia">The trivia the edit is anchored on.</param>
            /// <param name="newTrivia">The trivia to insert or to substitute for <paramref name="originalTrivia"/>.</param>
            /// <param name="editKind">The list operation to perform.</param>
            public TriviaListEditor(
                SyntaxTrivia originalTrivia,
                IEnumerable<SyntaxTrivia> newTrivia,
                ListEditKind editKind)
                : base(originalTrivia.Span, editKind, true, originalTrivia.IsPartOfStructuredTrivia())
            {
                _newTrivia = newTrivia;
                _originalTrivia = originalTrivia;
            }

            /// <summary>
            /// Performs the edit on <paramref name="list"/> when it contains the anchor trivia;
            /// otherwise defers to the base implementation.
            /// </summary>
            public override SyntaxTriviaList VisitList(SyntaxTriviaList list)
            {
                int position = list.IndexOf(_originalTrivia);

                // Anchor is not in this list: nothing to edit here.
                if (position < 0 || position >= list.Count)
                {
                    return base.VisitList(list);
                }

                switch (this.editKind)
                {
                    case ListEditKind.InsertBefore:
                        return list.InsertRange(position, _newTrivia);

                    case ListEditKind.InsertAfter:
                        return list.InsertRange(position + 1, _newTrivia);

                    case ListEditKind.Replace:
                        return list.ReplaceRange(_originalTrivia, _newTrivia);

                    default:
                        return base.VisitList(list);
                }
            }
        }
    }
}