File: Language\Syntax\SyntaxReplacer.cs
Web Access
Project: src\src\Razor\src\Compiler\Microsoft.CodeAnalysis.Razor.Compiler\src\Microsoft.CodeAnalysis.Razor.Compiler.csproj (Microsoft.CodeAnalysis.Razor.Compiler)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.CodeAnalysis.Text;
 
namespace Microsoft.AspNetCore.Razor.Language.Syntax;
 
/// <summary>
/// Helpers that rewrite a syntax tree by replacing selected nodes/tokens, or by editing the
/// syntax list that contains a given node or token.
/// </summary>
internal static class SyntaxReplacer
{
    /// <summary>
    /// Rewrites <paramref name="root"/>, replacing the given nodes and tokens with the values
    /// produced by the corresponding callbacks.
    /// </summary>
    /// <param name="root">The tree to rewrite.</param>
    /// <param name="nodes">Nodes to replace; <see langword="null"/> means no nodes.</param>
    /// <param name="computeReplacementNode">
    /// Called with each matched node and that same node after its descendants have been rewritten;
    /// its result takes the node's place. When <see langword="null"/>, matched nodes are not replaced
    /// themselves (only their descendants may be).
    /// </param>
    /// <param name="tokens">Tokens to replace; <see langword="null"/> means no tokens.</param>
    /// <param name="computeReplacementToken">
    /// Called with each matched token (passed as both arguments); its result takes the token's place.
    /// </param>
    /// <returns>The rewritten tree, or <paramref name="root"/> itself when there is nothing to replace.</returns>
    internal static SyntaxNode Replace<TNode>(
        SyntaxNode root,
        IEnumerable<TNode>? nodes = null,
        Func<TNode, TNode, SyntaxNode>? computeReplacementNode = null,
        IEnumerable<SyntaxToken>? tokens = null,
        Func<SyntaxToken, SyntaxToken, SyntaxToken>? computeReplacementToken = null)
        where TNode : SyntaxNode
    {
        var replacer = new Replacer<TNode>(nodes, computeReplacementNode, tokens, computeReplacementToken);

        // Skip the walk entirely when there is nothing to look for.
        if (!replacer.HasWork)
        {
            return root;
        }

        return replacer.Visit(root);
    }

    /// <summary>
    /// Token overload of <see cref="Replace{TNode}(SyntaxNode, IEnumerable{TNode}?, Func{TNode, TNode, SyntaxNode}?, IEnumerable{SyntaxToken}?, Func{SyntaxToken, SyntaxToken, SyntaxToken}?)"/>.
    /// Only <paramref name="root"/> itself is checked against <paramref name="tokens"/>; the node
    /// arguments only contribute to deciding whether there is any work to do.
    /// </summary>
    /// <returns>The replacement token, or <paramref name="root"/> when it is not being replaced.</returns>
    internal static SyntaxToken Replace(
        SyntaxToken root,
        IEnumerable<SyntaxNode>? nodes = null,
        Func<SyntaxNode, SyntaxNode, SyntaxNode>? computeReplacementNode = null,
        IEnumerable<SyntaxToken>? tokens = null,
        Func<SyntaxToken, SyntaxToken, SyntaxToken>? computeReplacementToken = null)
    {
        var replacer = new Replacer<SyntaxNode>(nodes, computeReplacementNode, tokens, computeReplacementToken);

        if (!replacer.HasWork)
        {
            return root;
        }

        return replacer.VisitToken(root);
    }

    /// <summary>
    /// Rewriter that replaces a fixed set of nodes and tokens. Each item is removed from its set the
    /// first time it is visited, and the visitation criteria are then recomputed so that subtrees
    /// which cannot contain any remaining item are not walked.
    /// </summary>
    private sealed class Replacer<TNode> : SyntaxRewriter where TNode : SyntaxNode
    {
        // Shared empty sets used when the caller supplies no nodes/tokens. This class only
        // enumerates them, reads Count, or calls Remove (a no-op on an empty set), so they
        // never gain elements.
        private static readonly HashSet<SyntaxNode> s_noNodes = [];
        private static readonly HashSet<SyntaxToken> s_noTokens = [];

        private readonly Func<TNode, TNode, SyntaxNode>? _computeReplacementNode;
        private readonly Func<SyntaxToken, SyntaxToken, SyntaxToken>? _computeReplacementToken;

        // Items still waiting to be replaced.
        private readonly HashSet<SyntaxNode> _nodeSet;
        private readonly HashSet<SyntaxToken> _tokenSet;

        // Distinct spans of the pending items, plus the smallest span covering all of them.
        private readonly HashSet<TextSpan> _spanSet;
        private TextSpan _totalSpan;

        public Replacer(
            IEnumerable<TNode>? nodes = null,
            Func<TNode, TNode, SyntaxNode>? computeReplacementNode = null,
            IEnumerable<SyntaxToken>? tokens = null,
            Func<SyntaxToken, SyntaxToken, SyntaxToken>? computeReplacementToken = null)
        {
            _computeReplacementNode = computeReplacementNode;
            _computeReplacementToken = computeReplacementToken;

            _nodeSet = nodes != null ? [.. nodes] : s_noNodes;
            _tokenSet = tokens != null ? [.. tokens] : s_noTokens;

            _spanSet = [];
            CalculateVisitationCriteria();
        }

        /// <summary>
        /// <see langword="true"/> while at least one node or token is still pending replacement.
        /// </summary>
        public bool HasWork => _nodeSet.Count + _tokenSet.Count > 0;

        /// <summary>
        /// Rebuilds <see cref="_spanSet"/> and <see cref="_totalSpan"/> from the items still pending.
        /// </summary>
        private void CalculateVisitationCriteria()
        {
            _spanSet.Clear();

            foreach (var node in _nodeSet)
            {
                _spanSet.Add(node.Span);
            }

            foreach (var token in _tokenSet)
            {
                _spanSet.Add(token.Span);
            }

            // The covering span lets ShouldVisit reject most subtrees with a single comparison.
            // With no pending items it is the empty span at position 0.
            var start = int.MaxValue;
            var end = int.MinValue;

            foreach (var span in _spanSet)
            {
                start = Math.Min(start, span.Start);
                end = Math.Max(end, span.End);
            }

            _totalSpan = _spanSet.Count == 0 ? default : new TextSpan(start, end - start);
        }

        /// <summary>
        /// Returns <see langword="true"/> when <paramref name="span"/> intersects the span of at
        /// least one pending item, i.e. a replacement may exist at or below the element it covers.
        /// </summary>
        private bool ShouldVisit(TextSpan span)
        {
            // Quick rejection: outside the covering span there is nothing left to replace.
            if (!span.IntersectsWith(_totalSpan))
            {
                return false;
            }

            foreach (var pendingSpan in _spanSet)
            {
                if (span.IntersectsWith(pendingSpan))
                {
                    return true;
                }
            }

            return false;
        }

        [return: NotNullIfNotNull(nameof(node))]
        public override SyntaxNode? Visit(SyntaxNode? node)
        {
            if (node is null)
            {
                return null;
            }

            var isReplacedNode = _nodeSet.Remove(node);

            if (isReplacedNode)
            {
                // This node contributed to _spanSet. It is being processed now, so drop its span
                // before calling ShouldVisit; otherwise we would walk into the node even when no
                // other pending item lies inside it.
                CalculateVisitationCriteria();
            }

            var rewritten = ShouldVisit(node.Span) ? base.Visit(node) : node;

            if (isReplacedNode && _computeReplacementNode != null)
            {
                // The callback sees both the original node and the version with rewritten descendants.
                rewritten = _computeReplacementNode((TNode)node, (TNode)rewritten!);
            }

            return rewritten;
        }

        public override SyntaxToken VisitToken(SyntaxToken token)
        {
            if (!_tokenSet.Remove(token))
            {
                return token;
            }

            // The token is no longer pending, so its span must stop influencing which
            // subsequently visited nodes are walked into.
            CalculateVisitationCriteria();

            return _computeReplacementToken != null
                ? _computeReplacementToken(token, token)
                : token;
        }
    }

    /// <summary>
    /// Replaces <paramref name="originalNode"/> in the syntax list that contains it with
    /// <paramref name="newNodes"/>. If the node is never encountered, no list is edited.
    /// </summary>
    /// <param name="newNodes">Replacement nodes; each must be castable to the list's element type.</param>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="originalNode"/> was reached without being found as an element of a syntax list.
    /// </exception>
    internal static SyntaxNode ReplaceNodeInList(SyntaxNode root, SyntaxNode originalNode, IEnumerable<SyntaxNode> newNodes)
    {
        return new NodeListEditor(originalNode, newNodes, ListEditKind.Replace).Visit(root);
    }

    /// <summary>
    /// Inserts <paramref name="nodesToInsert"/> before or after <paramref name="nodeInList"/> in the
    /// syntax list that contains it.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="nodeInList"/> was reached without being found as an element of a syntax list.
    /// </exception>
    internal static SyntaxNode InsertNodeInList(SyntaxNode root, SyntaxNode nodeInList, IEnumerable<SyntaxNode> nodesToInsert, bool insertBefore)
    {
        return new NodeListEditor(nodeInList, nodesToInsert, insertBefore ? ListEditKind.InsertBefore : ListEditKind.InsertAfter).Visit(root);
    }

    /// <summary>
    /// Replaces <paramref name="tokenInList"/> in the token list that contains it with <paramref name="newTokens"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="tokenInList"/> was reached without being found as an element of a token list.
    /// </exception>
    internal static SyntaxNode ReplaceTokenInList(SyntaxNode root, SyntaxToken tokenInList, IEnumerable<SyntaxToken> newTokens)
    {
        return new TokenListEditor(tokenInList, newTokens, ListEditKind.Replace).Visit(root);
    }

    /// <summary>
    /// Inserts <paramref name="newTokens"/> before or after <paramref name="tokenInList"/> in the
    /// token list that contains it.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="tokenInList"/> was reached without being found as an element of a token list.
    /// </exception>
    internal static SyntaxNode InsertTokenInList(SyntaxNode root, SyntaxToken tokenInList, IEnumerable<SyntaxToken> newTokens, bool insertBefore)
    {
        return new TokenListEditor(tokenInList, newTokens, insertBefore ? ListEditKind.InsertBefore : ListEditKind.InsertAfter).Visit(root);
    }

    /// <summary>The kind of edit a list editor applies at the target element's position.</summary>
    private enum ListEditKind
    {
        InsertBefore,
        InsertAfter,
        Replace
    }

    /// <summary>
    /// Base rewriter for list edits: only walks into nodes whose span intersects the target
    /// element's span.
    /// </summary>
    private abstract class BaseListEditor : SyntaxRewriter
    {
        private readonly TextSpan _elementSpan;

        protected readonly ListEditKind EditKind;

        protected BaseListEditor(TextSpan elementSpan, ListEditKind editKind)
        {
            _elementSpan = elementSpan;
            EditKind = editKind;
        }

        // A node whose span intersects the target element's span may contain the list holding it,
        // so its children must be visited.
        private bool ShouldVisit(TextSpan span) => span.IntersectsWith(_elementSpan);

        [return: NotNullIfNotNull(nameof(node))]
        public override SyntaxNode? Visit(SyntaxNode? node)
        {
            if (node != null && ShouldVisit(node.Span))
            {
                return base.Visit(node);
            }

            return node;
        }
    }

    /// <summary>
    /// Applies a replace/insert edit at the position of a node inside the syntax list containing it.
    /// </summary>
    private sealed class NodeListEditor : BaseListEditor
    {
        private readonly SyntaxNode _originalNode;
        private readonly IEnumerable<SyntaxNode> _newNodes;

        public NodeListEditor(
            SyntaxNode originalNode,
            IEnumerable<SyntaxNode> replacementNodes,
            ListEditKind editKind)
            : base(originalNode.Span, editKind)
        {
            _originalNode = originalNode;
            _newNodes = replacementNodes;
        }

        [return: NotNullIfNotNull(nameof(node))]
        public override SyntaxNode? Visit(SyntaxNode? node)
        {
            // When the target sits in a list, VisitList performs the edit and returns without
            // visiting the list's elements. Reaching the target here therefore means it is not a
            // list element, and a list edit is impossible.
            if (node == _originalNode)
            {
                throw new InvalidOperationException("Expecting a list");
            }

            return base.Visit(node);
        }

        public override SyntaxList<TNode> VisitList<TNode>(SyntaxList<TNode> list)
        {
            if (_originalNode is TNode originalNode)
            {
                var index = list.IndexOf(originalNode);
                if (index >= 0)
                {
                    switch (EditKind)
                    {
                        case ListEditKind.Replace:
                            return list.ReplaceRange(originalNode, _newNodes.Cast<TNode>());

                        case ListEditKind.InsertAfter:
                            return list.InsertRange(index + 1, _newNodes.Cast<TNode>());

                        case ListEditKind.InsertBefore:
                            return list.InsertRange(index, _newNodes.Cast<TNode>());
                    }
                }
            }

            return base.VisitList(list);
        }
    }

    /// <summary>
    /// Applies a replace/insert edit at the position of a token inside the token list containing it.
    /// </summary>
    private sealed class TokenListEditor : BaseListEditor
    {
        private readonly SyntaxToken _originalToken;
        private readonly IEnumerable<SyntaxToken> _newTokens;

        public TokenListEditor(
            SyntaxToken originalToken,
            IEnumerable<SyntaxToken> newTokens,
            ListEditKind editKind)
            : base(originalToken.Span, editKind)
        {
            _originalToken = originalToken;
            _newTokens = newTokens;
        }

        public override SyntaxToken VisitToken(SyntaxToken token)
        {
            // When the target sits in a token list, VisitList performs the edit and returns without
            // visiting the list's tokens. Reaching it here means it is not a token-list element.
            if (token == _originalToken)
            {
                throw new InvalidOperationException("Expecting a list");
            }

            return base.VisitToken(token);
        }

        public override SyntaxTokenList VisitList(SyntaxTokenList list)
        {
            var index = list.IndexOf(_originalToken);
            if (index >= 0)
            {
                switch (EditKind)
                {
                    case ListEditKind.Replace:
                        return list.ReplaceRange(_originalToken, _newTokens);

                    case ListEditKind.InsertAfter:
                        return list.InsertRange(index + 1, _newTokens);

                    case ListEditKind.InsertBefore:
                        return list.InsertRange(index, _newTokens);
                }
            }

            return base.VisitList(list);
        }
    }
}