// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.Shared.Extensions;

internal static partial class SyntaxNodeExtensions
{
    /// <summary>
    /// Returns the parent of <paramref name="node"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The node has no parent.</exception>
    public static SyntaxNode GetRequiredParent(this SyntaxNode node)
        => node.Parent ?? throw new InvalidOperationException("Node's parent was null");

    /// <summary>
    /// Forwards to <c>SyntaxNodeOrTokenExtensions.DepthFirstTraversal</c> for <paramref name="node"/>.
    /// </summary>
    public static IEnumerable<SyntaxNodeOrToken> DepthFirstTraversal(this SyntaxNode node)
        => SyntaxNodeOrTokenExtensions.DepthFirstTraversal(node);

    /// <summary>
    /// Forwards to <c>SyntaxNodeOrTokenExtensions.DepthFirstTraversalNodes</c> for <paramref name="node"/>.
    /// </summary>
    public static IEnumerable<SyntaxNode> DepthFirstTraversalNodes(this SyntaxNode node)
        => SyntaxNodeOrTokenExtensions.DepthFirstTraversalNodes(node);

    /// <summary>
    /// Lazily enumerates the ancestors of <paramref name="node"/> (excluding the node itself), innermost first.
    /// The first step uses <see cref="SyntaxNode.Parent"/>; every later step ascends out of structured trivia
    /// (see <see cref="GetParent(SyntaxNode, bool)"/>).
    /// </summary>
    public static IEnumerable<SyntaxNode> GetAncestors(this SyntaxNode node)
    {
        var current = node.Parent;

        while (current != null)
        {
            yield return current;

            current = current.GetParent(ascendOutOfTrivia: true);
        }
    }

    /// <summary>
    /// Lazily enumerates the ancestors of <paramref name="node"/> (excluding the node itself) that are of type
    /// <typeparamref name="TNode"/>, innermost first.
    /// </summary>
    public static IEnumerable<TNode> GetAncestors<TNode>(this SyntaxNode node)
        where TNode : SyntaxNode
    {
        var current = node.Parent;
        while (current != null)
        {
            if (current is TNode tNode)
            {
                yield return tNode;
            }

            current = current.GetParent(ascendOutOfTrivia: true);
        }
    }

    /// <summary>
    /// Returns the innermost ancestor of <paramref name="node"/> (excluding the node itself) that is of type
    /// <typeparamref name="TNode"/>, or <see langword="null"/> if there is none.
    /// </summary>
    public static TNode? GetAncestor<TNode>(this SyntaxNode node)
        where TNode : SyntaxNode
    {
        var current = node.Parent;
        while (current != null)
        {
            if (current is TNode tNode)
            {
                return tNode;
            }

            current = current.GetParent(ascendOutOfTrivia: true);
        }

        return null;
    }

    /// <summary>
    /// Returns <paramref name="node"/> itself or its innermost ancestor of type <typeparamref name="TNode"/>.
    /// Returns <see langword="null"/> when <paramref name="node"/> is null or no such node exists.
    /// </summary>
    public static TNode? GetAncestorOrThis<TNode>(this SyntaxNode? node)
        where TNode : SyntaxNode
    {
        return node?.GetAncestorsOrThis<TNode>().FirstOrDefault();
    }

    /// <summary>
    /// Lazily enumerates <paramref name="node"/> and its ancestors that are of type <typeparamref name="TNode"/>,
    /// innermost first. Yields nothing when <paramref name="node"/> is null.
    /// </summary>
    public static IEnumerable<TNode> GetAncestorsOrThis<TNode>(this SyntaxNode? node)
        where TNode : SyntaxNode
    {
        var current = node;
        while (current != null)
        {
            if (current is TNode tNode)
            {
                yield return tNode;
            }

            current = current.GetParent(ascendOutOfTrivia: true);
        }
    }

    /// <summary>
    /// Returns true if <paramref name="node"/> has an ancestor (excluding itself) of type
    /// <typeparamref name="TNode"/>.
    /// </summary>
    public static bool HasAncestor<TNode>(this SyntaxNode node)
        where TNode : SyntaxNode
    {
        return node.GetAncestors<TNode>().Any();
    }

    /// <summary>
    /// Lazily walks the tree rooted at <paramref name="node"/> and yields every visited node of type
    /// <typeparamref name="TSyntaxNode"/>. A node is visited (and its children are queued) only when its
    /// full span is contained in <paramref name="searchSpan"/> and <paramref name="predicate"/> returns true
    /// for it; otherwise the whole subtree below it is skipped.
    /// </summary>
    public static IEnumerable<TSyntaxNode> Traverse<TSyntaxNode>(
        this SyntaxNode node, TextSpan searchSpan, Func<SyntaxNode, bool> predicate)
        where TSyntaxNode : SyntaxNode
    {
        Contract.ThrowIfNull(node);

        var nodes = new LinkedList<SyntaxNode>();
        nodes.AddFirst(node);

        while (nodes.Count > 0)
        {
            var currentNode = nodes.First!.Value;
            nodes.RemoveFirst();

            if (currentNode != null && searchSpan.Contains(currentNode.FullSpan) && predicate(currentNode))
            {
                if (currentNode is TSyntaxNode tSyntax)
                {
                    yield return tSyntax;
                }

                // Children go to the head of the list so they are processed before the remaining siblings.
                nodes.AddRangeAtHead(currentNode.ChildNodes());
            }
        }
    }

    /// <summary>
    /// Returns true if the direct parent of <paramref name="node"/> is of type <typeparamref name="T"/> and
    /// <paramref name="valueChecker"/> returns true for that parent. Returns false for a null node.
    /// </summary>
    public static bool CheckParent<T>([NotNullWhen(returnValue: true)] this SyntaxNode? node, Func<T, bool> valueChecker) where T : SyntaxNode
    {
        if (node?.Parent is not T parentNode)
        {
            return false;
        }

        return valueChecker(parentNode);
    }

    /// <summary>
    /// Returns true if the given node is a specific child node of its nearest ancestor of a certain type.
    /// </summary>
    /// <typeparam name="TParent">The type of the ancestor node to look for.</typeparam>
    /// <param name="node">The node that we are testing.</param>
    /// <param name="childGetter">A function that, when given the ancestor node, returns the child node we are interested in.</param>
    public static bool IsChildNode<TParent>(this SyntaxNode node, Func<TParent, SyntaxNode?> childGetter)
        where TParent : SyntaxNode
    {
        var ancestor = node.GetAncestor<TParent>();
        if (ancestor == null)
        {
            return false;
        }

        var ancestorNode = childGetter(ancestor);

        return node == ancestorNode;
    }

    /// <summary>
    /// Returns true if this node is found underneath the specified child in the given parent.
    /// </summary>
    public static bool IsFoundUnder<TParent>(this SyntaxNode node, Func<TParent, SyntaxNode?> childGetter)
        where TParent : SyntaxNode
    {
        var ancestor = node.GetAncestor<TParent>();
        if (ancestor == null)
        {
            return false;
        }

        var child = childGetter(ancestor);

        // See if node passes through child on the way up to ancestor.
        return node.GetAncestorsOrThis<SyntaxNode>().Contains(child);
    }

    /// <summary>
    /// Returns the innermost node that is an ancestor-or-self of both <paramref name="node1"/> and
    /// <paramref name="node2"/>. Both nodes are expected to belong to the same tree.
    /// </summary>
    public static SyntaxNode GetCommonRoot(this SyntaxNode node1, SyntaxNode node2)
    {
        Contract.ThrowIfTrue(node1.RawKind == 0 || node2.RawKind == 0);

        // find common starting node from two nodes. as long as two nodes belong to same tree, there must be at least
        // one common root (Ex, compilation unit)
        using var _ = PooledHashSet<SyntaxNode>.GetInstance(out var set);
        set.AddRange(node2.GetAncestorsOrThis<SyntaxNode>());

        foreach (var ancestor in node1.AncestorsAndSelf())
        {
            if (set.Contains(ancestor))
                return ancestor;
        }

        throw ExceptionUtilities.Unreachable();
    }

    /// <summary>
    /// The length of the node's <see cref="SyntaxNode.Span"/>.
    /// </summary>
    public static int Width(this SyntaxNode node)
        => node.Span.Length;

    /// <summary>
    /// The length of the node's <see cref="SyntaxNode.FullSpan"/>.
    /// </summary>
    public static int FullWidth(this SyntaxNode node)
        => node.FullSpan.Length;

    /// <summary>
    /// Finds the innermost node common to all <paramref name="nodes"/> and then returns the first
    /// ancestor-or-self of it that satisfies <paramref name="predicate"/>, or <see langword="null"/> if
    /// either step finds nothing.
    /// </summary>
    public static SyntaxNode? FindInnermostCommonNode(this IEnumerable<SyntaxNode> nodes, Func<SyntaxNode, bool> predicate)
        => nodes.FindInnermostCommonNode()?.FirstAncestorOrSelf(predicate);

    /// <summary>
    /// Returns the innermost node that is an ancestor-or-self of every node in <paramref name="nodes"/>.
    /// Returns <see langword="null"/> for an empty sequence, or when no common ancestor can be found.
    /// </summary>
    public static SyntaxNode? FindInnermostCommonNode(this IEnumerable<SyntaxNode> nodes)
    {
        // Two collections we use to make this operation as efficient as possible. One is a
        // stack of the current shared ancestor chain shared by all nodes so far. It starts
        // with the full ancestor chain of the first node, and can only get smaller over time.
        // It should be log(n) with the size of the tree as it's only storing a parent chain.
        //
        // The second is a set with the exact same contents as the array. It's used for O(1)
        // lookups if a node is in the ancestor chain or not.

        using var _1 = ArrayBuilder<SyntaxNode>.GetInstance(out var commonAncestorsStack);
        using var _2 = PooledHashSet<SyntaxNode>.GetInstance(out var commonAncestorsSet);

        var first = true;
        foreach (var node in nodes)
        {
            // If we're just starting, initialize the ancestors set/array with the ancestors of
            // this node.
            if (first)
            {
                first = false;
                foreach (var ancestor in node.ValueAncestorsAndSelf())
                {
                    commonAncestorsSet.Add(ancestor);
                    commonAncestorsStack.Add(ancestor);
                }

                // Reverse the ancestors stack so that we go downwards with CompilationUnit at
                // the start, and then go down to this starting node. This enables cheap
                // popping later on.
                commonAncestorsStack.ReverseContents();
                continue;
            }

            // On a subsequent node, walk its ancestors to find the first match
            var commonAncestor = FindCommonAncestor(node, commonAncestorsSet);
            if (commonAncestor == null)
            {
                // So this shouldn't happen as long as the nodes are from the same tree. And
                // the caller really shouldn't be calling from different trees. However, the
                // previous impl supported that, so continue to have this behavior.
                //
                // If this doesn't fire, that means that all callers seem sane. If it does
                // fire, we can relax this (but we should consider fixing the caller).
                Debug.Fail("Could not find common ancestor.");
                return null;
            }

            // Now remove everything in the ancestors array up to that common ancestor. This is
            // generally quite efficient. Either we settle on a common node quickly, and don't
            // need to do work here, or we keep tossing data from our common-ancestor scratch
            // pad, making further work faster.
            while (commonAncestorsStack.Count > 0 &&
                   commonAncestorsStack.Peek() != commonAncestor)
            {
                commonAncestorsSet.Remove(commonAncestorsStack.Peek());
                commonAncestorsStack.Pop();
            }

            if (commonAncestorsStack.Count == 0)
            {
                // So this shouldn't happen as long as the nodes are from the same tree. And
                // the caller really shouldn't be calling from different trees. However, the
                // previous impl supported that, so continue to have this behavior.
                //
                // If this doesn't fire, that means that all callers seem sane. If it does
                // fire, we can relax this (but we should consider fixing the caller).
                Debug.Fail("Could not find common ancestor.");
                return null;
            }
        }

        // The common ancestor is the one at the end of the ancestor stack. This could be empty
        // in the case where the caller passed in an empty enumerable of nodes.
        return commonAncestorsStack.Count == 0 ? null : commonAncestorsStack.Peek();

        // local functions
        static SyntaxNode? FindCommonAncestor(SyntaxNode node, HashSet<SyntaxNode> commonAncestorsSet)
        {
            foreach (var ancestor in node.ValueAncestorsAndSelf())
            {
                if (commonAncestorsSet.Contains(ancestor))
                    return ancestor;
            }

            return null;
        }
    }

    /// <summary>
    /// Finds the innermost node common to all <paramref name="nodes"/> and returns the first ancestor-or-self
    /// of it that is of type <typeparamref name="TSyntaxNode"/>, or <see langword="null"/> if none exists.
    /// </summary>
    public static TSyntaxNode? FindInnermostCommonNode<TSyntaxNode>(this IEnumerable<SyntaxNode> nodes) where TSyntaxNode : SyntaxNode
        => (TSyntaxNode?)nodes.FindInnermostCommonNode(t => t is TSyntaxNode);

    /// <summary>
    /// Returns the smallest span that covers the <see cref="SyntaxNode.Span"/> of every node in
    /// <paramref name="nodes"/>. The sequence is enumerated exactly once.
    /// </summary>
    /// <param name="nodes">The nodes to cover. The contract checks below reject a null or empty sequence.</param>
    public static TextSpan GetContainedSpan(this IEnumerable<SyntaxNode> nodes)
    {
        Contract.ThrowIfNull(nodes);

        // Drive a single enumerator by hand instead of calling Any()/First()/foreach, which would each
        // restart the sequence (costly for lazy sequences, incorrect for non-repeatable ones).
        using var enumerator = nodes.GetEnumerator();
        Contract.ThrowIfFalse(enumerator.MoveNext());

        var fullSpan = enumerator.Current.Span;
        while (enumerator.MoveNext())
        {
            var node = enumerator.Current;
            fullSpan = TextSpan.FromBounds(
                Math.Min(fullSpan.Start, node.SpanStart),
                Math.Max(fullSpan.End, node.Span.End));
        }

        return fullSpan;
    }

    /// <summary>
    /// Asks the node's syntax tree whether the node's own span overlaps a hidden position.
    /// </summary>
    public static bool OverlapsHiddenPosition(this SyntaxNode node, CancellationToken cancellationToken)
        => node.OverlapsHiddenPosition(node.Span, cancellationToken);

    /// <summary>
    /// Asks the node's syntax tree whether <paramref name="span"/> overlaps a hidden position.
    /// </summary>
    public static bool OverlapsHiddenPosition(this SyntaxNode node, TextSpan span, CancellationToken cancellationToken)
        => node.SyntaxTree.OverlapsHiddenPosition(span, cancellationToken);

    /// <summary>
    /// Asks the declaration's syntax tree whether the span running from the end of
    /// <paramref name="startNode"/> to the start of <paramref name="endNode"/> overlaps a hidden position.
    /// </summary>
    public static bool OverlapsHiddenPosition(this SyntaxNode declaration, SyntaxNode startNode, SyntaxNode endNode, CancellationToken cancellationToken)
    {
        var start = startNode.Span.End;
        var end = endNode.SpanStart;

        var textSpan = TextSpan.FromBounds(start, end);
        return declaration.OverlapsHiddenPosition(textSpan, cancellationToken);
    }

    /// <summary>
    /// Returns the nodes of type <typeparamref name="T"/> under <paramref name="node"/> that carry
    /// <paramref name="syntaxAnnotation"/>. Annotated tokens are ignored.
    /// </summary>
    public static IEnumerable<T> GetAnnotatedNodes<T>(this SyntaxNode node, SyntaxAnnotation syntaxAnnotation) where T : SyntaxNode
        => node.GetAnnotatedNodesAndTokens(syntaxAnnotation).Select(n => n.AsNode()).OfType<T>();

    /// <summary>
    /// Creates a new tree of nodes from the existing tree with the specified old nodes replaced with a newly computed nodes.
    /// </summary>
    /// <param name="root">The root of the tree that contains all the specified nodes.</param>
    /// <param name="nodes">The nodes from the tree to be replaced.</param>
    /// <param name="computeReplacementAsync">A function that computes a replacement node for
    /// the argument nodes. The first argument is one of the original specified nodes. The second argument is
    /// the same node possibly rewritten with replaced descendants.</param>
    /// <param name="cancellationToken"></param>
    public static Task<TRootNode> ReplaceNodesAsync<TRootNode>(
        this TRootNode root,
        IEnumerable<SyntaxNode> nodes,
        Func<SyntaxNode, SyntaxNode, CancellationToken, Task<SyntaxNode>> computeReplacementAsync,
        CancellationToken cancellationToken) where TRootNode : SyntaxNode
    {
        return root.ReplaceSyntaxAsync(
            nodes: nodes, computeReplacementNodeAsync: computeReplacementAsync,
            tokens: null, computeReplacementTokenAsync: null,
            trivia: null, computeReplacementTriviaAsync: null,
            cancellationToken: cancellationToken);
    }

    /// <summary>
    /// Creates a new tree of tokens from the existing tree with the specified old tokens replaced with a newly computed tokens.
    /// </summary>
    /// <param name="root">The root of the tree that contains all the specified tokens.</param>
    /// <param name="tokens">The tokens from the tree to be replaced.</param>
    /// <param name="computeReplacementAsync">A function that computes a replacement token for
    /// the argument tokens. The first argument is one of the originally specified tokens. The second argument is
    /// the same token possibly rewritten with replaced trivia.</param>
    /// <param name="cancellationToken"></param>
    public static Task<TRootNode> ReplaceTokensAsync<TRootNode>(
        this TRootNode root,
        IEnumerable<SyntaxToken> tokens,
        Func<SyntaxToken, SyntaxToken, CancellationToken, Task<SyntaxToken>> computeReplacementAsync,
        CancellationToken cancellationToken) where TRootNode : SyntaxNode
    {
        return root.ReplaceSyntaxAsync(
            nodes: null, computeReplacementNodeAsync: null,
            tokens: tokens, computeReplacementTokenAsync: computeReplacementAsync,
            trivia: null, computeReplacementTriviaAsync: null,
            cancellationToken: cancellationToken);
    }

    /// <summary>
    /// Creates a new tree from the existing tree with the specified old trivia replaced with newly computed trivia.
    /// </summary>
    /// <param name="root">The root of the tree that contains all the specified trivia.</param>
    /// <param name="trivia">The trivia from the tree to be replaced.</param>
    /// <param name="computeReplacementAsync">A function that computes a replacement for each of the
    /// specified trivia.</param>
    /// <param name="cancellationToken"></param>
    public static Task<TRoot> ReplaceTriviaAsync<TRoot>(
        this TRoot root,
        IEnumerable<SyntaxTrivia> trivia,
        Func<SyntaxTrivia, SyntaxTrivia, CancellationToken, Task<SyntaxTrivia>> computeReplacementAsync,
        CancellationToken cancellationToken) where TRoot : SyntaxNode
    {
        return root.ReplaceSyntaxAsync(
            nodes: null, computeReplacementNodeAsync: null,
            tokens: null, computeReplacementTokenAsync: null,
            trivia: trivia, computeReplacementTriviaAsync: computeReplacementAsync,
            cancellationToken: cancellationToken);
    }

    /// <summary>
    /// Replaces the given nodes, tokens and trivia in <paramref name="root"/> with asynchronously computed
    /// replacements and returns the resulting root.
    /// </summary>
    /// <remarks>
    /// Work is done in batches. In each batch, replacements are computed only for items whose full span does
    /// not intersect the previously visited span (in end-offset order), so that nested items are not processed
    /// together. Items that were not replaced in a batch are tagged with a retry annotation (which remembers
    /// the original item) and picked up again in the next batch, until nothing remains.
    /// A compute callback must be supplied for every item collection that is non-null.
    /// </remarks>
    public static async Task<TRoot> ReplaceSyntaxAsync<TRoot>(
        this TRoot root,
        IEnumerable<SyntaxNode>? nodes,
        Func<SyntaxNode, SyntaxNode, CancellationToken, Task<SyntaxNode>>? computeReplacementNodeAsync,
        IEnumerable<SyntaxToken>? tokens,
        Func<SyntaxToken, SyntaxToken, CancellationToken, Task<SyntaxToken>>? computeReplacementTokenAsync,
        IEnumerable<SyntaxTrivia>? trivia,
        Func<SyntaxTrivia, SyntaxTrivia, CancellationToken, Task<SyntaxTrivia>>? computeReplacementTriviaAsync,
        CancellationToken cancellationToken)
        where TRoot : SyntaxNode
    {
        // index all nodes, tokens and trivia by the full spans they cover
        var nodesToReplace = nodes != null ? nodes.ToDictionary(n => n.FullSpan) : [];
        var tokensToReplace = tokens != null ? tokens.ToDictionary(t => t.FullSpan) : [];
        var triviaToReplace = trivia != null ? trivia.ToDictionary(t => t.FullSpan) : [];

        var nodeReplacements = new Dictionary<SyntaxNode, SyntaxNode>();
        var tokenReplacements = new Dictionary<SyntaxToken, SyntaxToken>();
        var triviaReplacements = new Dictionary<SyntaxTrivia, SyntaxTrivia>();

        var retryAnnotations = new AnnotationTable<object>("RetryReplace");

        var spans = new List<TextSpan>(nodesToReplace.Count + tokensToReplace.Count + triviaToReplace.Count);
        spans.AddRange(nodesToReplace.Keys);
        spans.AddRange(tokensToReplace.Keys);
        spans.AddRange(triviaToReplace.Keys);

        while (spans.Count > 0)
        {
            // sort the spans of the items to be replaced so we can tell if any overlap
            spans.Sort((x, y) =>
            {
                // order by end offset, and then by length
                var d = x.End - y.End;

                if (d == 0)
                {
                    d = x.Length - y.Length;
                }

                return d;
            });

            // compute replacements for all nodes that will go in the same batch
            // only spans that do not overlap go in the same batch.
            TextSpan previous = default;
            foreach (var span in spans)
            {
                // only add to replacement map if we don't intersect with the previous node. This taken with the sort order
                // should ensure that parent nodes are not processed in the same batch as child nodes.
                if (previous == default || !previous.IntersectsWith(span))
                {
                    if (nodesToReplace.TryGetValue(span, out var currentNode))
                    {
                        // On a retry the annotation carries the caller's original node; otherwise this is the first attempt.
                        var original = (SyntaxNode?)retryAnnotations.GetAnnotations(currentNode).SingleOrDefault() ?? currentNode;
                        var newNode = await computeReplacementNodeAsync!(original, currentNode, cancellationToken).ConfigureAwait(false);
                        nodeReplacements[currentNode] = newNode;
                    }
                    else if (tokensToReplace.TryGetValue(span, out var currentToken))
                    {
                        var original = (SyntaxToken?)retryAnnotations.GetAnnotations(currentToken).SingleOrDefault() ?? currentToken;
                        var newToken = await computeReplacementTokenAsync!(original, currentToken, cancellationToken).ConfigureAwait(false);
                        tokenReplacements[currentToken] = newToken;
                    }
                    else if (triviaToReplace.TryGetValue(span, out var currentTrivia))
                    {
                        var original = (SyntaxTrivia?)retryAnnotations.GetAnnotations(currentTrivia).SingleOrDefault() ?? currentTrivia;
                        var newTrivia = await computeReplacementTriviaAsync!(original, currentTrivia, cancellationToken).ConfigureAwait(false);
                        triviaReplacements[currentTrivia] = newTrivia;
                    }
                }

                previous = span;
            }

            var retryNodes = false;
            var retryTokens = false;
            var retryTrivia = false;

            // replace nodes in batch
            // submit all nodes so we can annotate the ones we don't replace
            root = root.ReplaceSyntax(
                nodes: nodesToReplace.Values,
                computeReplacementNode: (original, rewritten) =>
                {
                    if (rewritten != original || !nodeReplacements.TryGetValue(original, out var replaced))
                    {
                        // the subtree did change, or we didn't have a replacement for it in this batch
                        // so we need to add an annotation so we can find this node again for the next batch.
                        replaced = retryAnnotations.WithAdditionalAnnotations(rewritten, original);
                        retryNodes = true;
                    }

                    return replaced;
                },
                tokens: tokensToReplace.Values,
                computeReplacementToken: (original, rewritten) =>
                {
                    if (rewritten != original || !tokenReplacements.TryGetValue(original, out var replaced))
                    {
                        // the token's trivia did change, or we didn't have a replacement for it in this batch
                        // so we need to add an annotation so we can find this token again for the next batch.
                        replaced = retryAnnotations.WithAdditionalAnnotations(rewritten, original);
                        retryTokens = true;
                    }

                    return replaced;
                },
                trivia: triviaToReplace.Values,
                computeReplacementTrivia: (original, rewritten) =>
                {
                    if (!triviaReplacements.TryGetValue(original, out var replaced))
                    {
                        // we didn't have a replacement for this trivia in this batch
                        // so we need to add an annotation so we can find it again for the next batch.
                        replaced = retryAnnotations.WithAdditionalAnnotations(rewritten, original);
                        retryTrivia = true;
                    }

                    return replaced;
                });

            nodesToReplace.Clear();
            tokensToReplace.Clear();
            triviaToReplace.Clear();
            spans.Clear();

            // prepare next batch out of all remaining annotated nodes
            if (retryNodes)
            {
                nodesToReplace = retryAnnotations.GetAnnotatedNodes(root).ToDictionary(n => n.FullSpan);
                spans.AddRange(nodesToReplace.Keys);
            }

            if (retryTokens)
            {
                tokensToReplace = retryAnnotations.GetAnnotatedTokens(root).ToDictionary(t => t.FullSpan);
                spans.AddRange(tokensToReplace.Keys);
            }

            if (retryTrivia)
            {
                triviaToReplace = retryAnnotations.GetAnnotatedTrivia(root).ToDictionary(t => t.FullSpan);
                spans.AddRange(triviaToReplace.Keys);
            }
        }

        return root;
    }

    /// <summary>
    /// Cached delegate for <see cref="FindSkippedTokenForward"/>.
    /// </summary>
    private static readonly Func<SyntaxTriviaList, int, SyntaxToken> s_findSkippedTokenForward = FindSkippedTokenForward;

    /// <summary>
    /// Looks through a trivia list, front to back, for the first non-empty skipped token that ends at or after
    /// the given position. Returns <see langword="default"/> if there is none.
    /// </summary>
    private static SyntaxToken FindSkippedTokenForward(SyntaxTriviaList triviaList, int position)
    {
        foreach (var trivia in triviaList)
        {
            if (trivia.HasStructure)
            {
                if (trivia.GetStructure() is ISkippedTokensTriviaSyntax skippedTokensTrivia)
                {
                    foreach (var token in skippedTokensTrivia.Tokens)
                    {
                        if (token.Span.Length > 0 && position <= token.Span.End)
                        {
                            return token;
                        }
                    }
                }
            }
        }

        return default;
    }

    /// <summary>
    /// Cached delegate for <see cref="FindSkippedTokenBackward"/>.
    /// </summary>
    private static readonly Func<SyntaxTriviaList, int, SyntaxToken> s_findSkippedTokenBackward = FindSkippedTokenBackward;

    /// <summary>
    /// Looks through a trivia list, back to front, for a non-empty skipped token that starts at or before the
    /// given position. Returns <see langword="default"/> if there is none.
    /// </summary>
    private static SyntaxToken FindSkippedTokenBackward(SyntaxTriviaList triviaList, int position)
    {
        foreach (var trivia in triviaList.Reverse())
        {
            if (trivia.HasStructure)
            {
                if (trivia.GetStructure() is ISkippedTokensTriviaSyntax skippedTokensTrivia)
                {
                    foreach (var token in skippedTokensTrivia.Tokens)
                    {
                        if (token.Span.Length > 0 && token.SpanStart <= position)
                        {
                            return token;
                        }
                    }
                }
            }
        }

        return default;
    }

    /// <summary>
    /// Picks the token from which the left/right token searches start. Normally this is
    /// <c>root.FindToken(position, ...)</c>; when <paramref name="position"/> is at or past the end of a
    /// compilation unit, it is instead the token preceding the root's last token.
    /// </summary>
    private static SyntaxToken GetInitialToken(
        SyntaxNode root,
        int position,
        bool includeSkipped = false,
        bool includeDirectives = false,
        bool includeDocumentationComments = false)
    {
        return (position < root.FullSpan.End || root is not ICompilationUnitSyntax)
            ? root.FindToken(position, includeSkipped || includeDirectives || includeDocumentationComments)
            : root.GetLastToken(includeZeroWidth: true, includeSkipped: true, includeDirectives: true, includeDocumentationComments: true)
                  .GetPreviousToken(includeZeroWidth: false, includeSkipped: includeSkipped, includeDirectives: includeDirectives, includeDocumentationComments: includeDocumentationComments);
    }

    /// <summary>
    /// If the position is inside of token, return that token; otherwise, return the token to the right.
    /// </summary>
    public static SyntaxToken FindTokenOnRightOfPosition(
        this SyntaxNode root,
        int position,
        bool includeSkipped = false,
        bool includeDirectives = false,
        bool includeDocumentationComments = false)
    {
        // When skipped tokens are not requested the lookup is a no-op that always yields a default token.
        var findSkippedToken = includeSkipped ? s_findSkippedTokenForward : ((l, p) => default);

        var token = GetInitialToken(root, position, includeSkipped, includeDirectives, includeDocumentationComments);

        if (position < token.SpanStart)
        {
            var skippedToken = findSkippedToken(token.LeadingTrivia, position);
            token = skippedToken.RawKind != 0 ? skippedToken : token;
        }
        else if (token.Span.End <= position)
        {
            do
            {
                var skippedToken = findSkippedToken(token.TrailingTrivia, position);
                token = skippedToken.RawKind != 0
                    ? skippedToken
                    : token.GetNextToken(includeZeroWidth: false, includeSkipped: includeSkipped, includeDirectives: includeDirectives, includeDocumentationComments: includeDocumentationComments);
            }
            while (token.RawKind != 0 && token.Span.End <= position && token.Span.End <= root.FullSpan.End);
        }

        if (token.Span.Length == 0)
        {
            token = token.GetNextToken();
        }

        return token;
    }

    /// <summary>
    /// If the position is inside of token, return that token; otherwise, return the token to the left.
    /// </summary>
    public static SyntaxToken FindTokenOnLeftOfPosition(
        this SyntaxNode root,
        int position,
        bool includeSkipped = false,
        bool includeDirectives = false,
        bool includeDocumentationComments = false)
    {
        // When skipped tokens are not requested the lookup is a no-op that always yields a default token.
        var findSkippedToken = includeSkipped ? s_findSkippedTokenBackward : ((l, p) => default);

        var token = GetInitialToken(root, position, includeSkipped, includeDirectives, includeDocumentationComments);

        if (position <= token.SpanStart)
        {
            do
            {
                var skippedToken = findSkippedToken(token.LeadingTrivia, position);
                token = skippedToken.RawKind != 0
                    ? skippedToken
                    : token.GetPreviousToken(includeZeroWidth: false, includeSkipped: includeSkipped, includeDirectives: includeDirectives, includeDocumentationComments: includeDocumentationComments);
            }
            while (position <= token.SpanStart && root.FullSpan.Start < token.SpanStart);
        }
        else if (token.Span.End < position)
        {
            var skippedToken = findSkippedToken(token.TrailingTrivia, position);
            token = skippedToken.RawKind != 0 ? skippedToken : token;
        }

        if (token.Span.Length == 0)
        {
            token = token.GetPreviousToken();
        }

        return token;
    }

    /// <summary>
    /// Returns <paramref name="node"/> with <paramref name="trivia"/> placed before its existing leading
    /// trivia. Returns the node unchanged when no trivia is supplied.
    /// </summary>
    public static T WithPrependedLeadingTrivia<T>(
        this T node,
        params SyntaxTrivia[] trivia) where T : SyntaxNode
    {
        if (trivia.Length == 0)
        {
            return node;
        }

        return node.WithPrependedLeadingTrivia((IEnumerable<SyntaxTrivia>)trivia);
    }

    /// <summary>
    /// Returns <paramref name="node"/> with <paramref name="trivia"/> placed before its existing leading
    /// trivia. Returns the node unchanged when the list is empty.
    /// </summary>
    public static T WithPrependedLeadingTrivia<T>(
        this T node,
        SyntaxTriviaList trivia) where T : SyntaxNode
    {
        if (trivia.Count == 0)
        {
            return node;
        }

        return node.WithLeadingTrivia(trivia.Concat(node.GetLeadingTrivia()));
    }

    /// <summary>
    /// Returns <paramref name="node"/> with <paramref name="trivia"/> placed before its existing leading trivia.
    /// </summary>
    public static T WithPrependedLeadingTrivia<T>(
        this T node,
        IEnumerable<SyntaxTrivia> trivia) where T : SyntaxNode
    {
        var list = new SyntaxTriviaList();
        list = list.AddRange(trivia);

        return node.WithPrependedLeadingTrivia(list);
    }

    /// <summary>
    /// Returns <paramref name="node"/> with <paramref name="trivia"/> placed after its existing trailing
    /// trivia. Returns the node unchanged when no trivia is supplied.
    /// </summary>
    public static T WithAppendedTrailingTrivia<T>(
        this T node,
        params SyntaxTrivia[] trivia) where T : SyntaxNode
    {
        if (trivia.Length == 0)
        {
            return node;
        }

        return node.WithAppendedTrailingTrivia((IEnumerable<SyntaxTrivia>)trivia);
    }

    /// <summary>
    /// Returns <paramref name="node"/> with <paramref name="trivia"/> placed after its existing trailing
    /// trivia. Returns the node unchanged when the list is empty.
    /// </summary>
    public static T WithAppendedTrailingTrivia<T>(
        this T node,
        SyntaxTriviaList trivia) where T : SyntaxNode
    {
        if (trivia.Count == 0)
        {
            return node;
        }

        return node.WithTrailingTrivia(node.GetTrailingTrivia().Concat(trivia));
    }

    /// <summary>
    /// Returns <paramref name="node"/> with <paramref name="trivia"/> placed after its existing trailing trivia.
    /// </summary>
    public static T WithAppendedTrailingTrivia<T>(
        this T node,
        IEnumerable<SyntaxTrivia> trivia) where T : SyntaxNode
    {
        var list = new SyntaxTriviaList();
        list = list.AddRange(trivia);

        return node.WithAppendedTrailingTrivia(list);
    }

    /// <summary>
    /// Returns <paramref name="node"/> with both its leading and trailing trivia replaced.
    /// </summary>
    public static T With<T>(
        this T node,
        IEnumerable<SyntaxTrivia> leadingTrivia,
        IEnumerable<SyntaxTrivia> trailingTrivia) where T : SyntaxNode
    {
        return node.WithLeadingTrivia(leadingTrivia).WithTrailingTrivia(trailingTrivia);
    }

    /// <summary>
    /// Creates a new token with the leading trivia removed.
    /// </summary>
    public static SyntaxToken WithoutLeadingTrivia(this SyntaxToken token)
    {
        return token.WithLeadingTrivia(default(SyntaxTriviaList));
    }

    /// <summary>
    /// Creates a new token with the trailing trivia removed.
    /// </summary>
    public static SyntaxToken WithoutTrailingTrivia(this SyntaxToken token)
    {
        return token.WithTrailingTrivia(default(SyntaxTriviaList));
    }

    /// <summary>
    /// Finds the node within the given <paramref name="root"/> corresponding to the given <paramref name="span"/>.
    /// If the <paramref name="span"/> is <see langword="null"/>, then returns the given <paramref name="root"/> node.
    /// </summary>
    public static SyntaxNode FindNode(this SyntaxNode root, TextSpan? span, bool findInTrivia, bool getInnermostNodeForTie)
    {
        return span.HasValue
            ? root.FindNode(span.Value, findInTrivia, getInnermostNodeForTie)
            : root;
    }

    /// <summary>
    /// Returns the parent of <paramref name="node"/>. When the node has no parent, is structured trivia, and
    /// <paramref name="ascendOutOfTrivia"/> is true, returns the parent of the token that owns the trivia.
    /// </summary>
    // Copy of the same function in SyntaxNode.cs
    public static SyntaxNode? GetParent(this SyntaxNode node, bool ascendOutOfTrivia)
    {
        var parent = node.Parent;
        if (parent == null && ascendOutOfTrivia)
        {
            if (node is IStructuredTriviaSyntax structuredTrivia)
            {
                parent = structuredTrivia.ParentTrivia.Token.Parent;
            }
        }

        return parent;
    }

    /// <summary>
    /// Walks from <paramref name="node"/> upwards and returns the first node of type
    /// <typeparamref name="TNode"/>. The walk stops, returning <see langword="null"/>, at the first node
    /// that is not a <typeparamref name="TNode"/> and for which <paramref name="predicate"/> returns true.
    /// </summary>
    public static TNode? FirstAncestorOrSelfUntil<TNode>(this SyntaxNode? node, Func<SyntaxNode, bool> predicate)
        where TNode : SyntaxNode
    {
        for (var current = node; current != null; current = current.GetParent(ascendOutOfTrivia: true))
        {
            if (current is TNode tnode)
            {
                return tnode;
            }

            if (predicate(current))
            {
                break;
            }
        }

        return null;
    }

    /// <summary>
    /// Computes the directive pairing information for <paramref name="root"/>. See
    /// <c>DirectiveTriviaUtilities&lt;TDirectiveTriviaSyntax&gt;.GetDirectiveInfoForRoot</c>.
    /// </summary>
    public static DirectiveInfo<TDirectiveTriviaSyntax> GetDirectiveInfoForRoot<TDirectiveTriviaSyntax>(
        SyntaxNode root,
        ISyntaxKinds syntaxKinds,
        CancellationToken cancellationToken)
        where TDirectiveTriviaSyntax : SyntaxNode
    {
        return DirectiveTriviaUtilities<TDirectiveTriviaSyntax>.GetDirectiveInfoForRoot(root, syntaxKinds, cancellationToken);
    }

    private static class DirectiveTriviaUtilities<TDirectiveTriviaSyntax> where TDirectiveTriviaSyntax : SyntaxNode
    {
        /// <summary>
        /// Compares directives purely by their <see cref="SyntaxNode.SpanStart"/>.
        /// </summary>
        private sealed class DirectiveSyntaxEqualityComparer : IEqualityComparer<TDirectiveTriviaSyntax>
        {
            public static readonly DirectiveSyntaxEqualityComparer Instance = new();

            private DirectiveSyntaxEqualityComparer()
            {
            }

            public bool Equals(TDirectiveTriviaSyntax? x, TDirectiveTriviaSyntax? y)
                => x?.SpanStart == y?.SpanStart;

            public int GetHashCode(TDirectiveTriviaSyntax obj)
                => obj.SpanStart;
        }

        private static readonly ObjectPool<Stack<TDirectiveTriviaSyntax>> s_stackPool = new(() => new());

        /// <summary>
        /// Walks every directive-containing token under <paramref name="root"/> and builds two maps:
        /// one pairing each region/if-style directive with its matching start or end directive (or
        /// <see langword="null"/> for an unterminated region), and one mapping each conditional directive to
        /// the full, position-sorted group of conditional directives it belongs to.
        /// </summary>
        public static DirectiveInfo<TDirectiveTriviaSyntax> GetDirectiveInfoForRoot(
            SyntaxNode root,
            ISyntaxKinds syntaxKinds,
            CancellationToken cancellationToken)
        {
            var directiveMap = new Dictionary<TDirectiveTriviaSyntax, TDirectiveTriviaSyntax?>(
                DirectiveSyntaxEqualityComparer.Instance);
            var conditionalMap = new Dictionary<TDirectiveTriviaSyntax, ImmutableArray<TDirectiveTriviaSyntax>>(
                DirectiveSyntaxEqualityComparer.Instance);

            using var _1 = s_stackPool.GetPooledObject(out var regionStack);
            using var _2 = s_stackPool.GetPooledObject(out var ifStack);

            foreach (var token in root.DescendantTokens(descendIntoChildren: static node => node.ContainsDirectives))
            {
                cancellationToken.ThrowIfCancellationRequested();

                if (!token.ContainsDirectives)
                    continue;

                foreach (var trivia in token.LeadingTrivia)
                {
                    if (trivia.RawKind == syntaxKinds.RegionDirectiveTrivia)
                    {
                        regionStack.Push((TDirectiveTriviaSyntax)trivia.GetStructure()!);
                    }
                    else if (trivia.RawKind == syntaxKinds.IfDirectiveTrivia ||
                             trivia.RawKind == syntaxKinds.ElifDirectiveTrivia ||
                             trivia.RawKind == syntaxKinds.ElseDirectiveTrivia)
                    {
                        ifStack.Push((TDirectiveTriviaSyntax)trivia.GetStructure()!);
                    }
                    else if (trivia.RawKind == syntaxKinds.EndRegionDirectiveTrivia)
                    {
                        // An end-region with no open region on the stack is ignored.
                        if (regionStack.Count > 0)
                        {
                            var directive = (TDirectiveTriviaSyntax)trivia.GetStructure()!;
                            var previousDirective = regionStack.Pop();

                            directiveMap.Add(directive, previousDirective);
                            directiveMap.Add(previousDirective, directive);
                        }
                    }
                    else if (trivia.RawKind == syntaxKinds.EndIfDirectiveTrivia)
                    {
                        // An end-if with no open conditional on the stack is ignored.
                        if (ifStack.Count > 0)
                            FinishIf((TDirectiveTriviaSyntax)trivia.GetStructure()!);
                    }
                }
            }

            // Anything still on the stacks was never terminated.
            while (regionStack.Count > 0)
                directiveMap.Add(regionStack.Pop(), null);

            while (ifStack.Count > 0)
                FinishIf(directive: null);

            return new DirectiveInfo<TDirectiveTriviaSyntax>(directiveMap, conditionalMap);

            // Pops one conditional group (back to and including its #if) off the if-stack and records it.
            // 'directive' is the closing end-if, or null when the group is unterminated.
            void FinishIf(TDirectiveTriviaSyntax? directive)
            {
                using var _ = ArrayBuilder<TDirectiveTriviaSyntax>.GetInstance(out var condDirectivesBuilder);
                if (directive != null)
                    condDirectivesBuilder.Add(directive);

                while (ifStack.TryPop(out var poppedDirective))
                {
                    condDirectivesBuilder.Add(poppedDirective);
                    if (poppedDirective.RawKind == syntaxKinds.IfDirectiveTrivia)
                        break;
                }

                condDirectivesBuilder.Sort(static (n1, n2) => n1.SpanStart.CompareTo(n2.SpanStart));
                var condDirectives = condDirectivesBuilder.ToImmutableAndClear();

                foreach (var cond in condDirectives)
                    conditionalMap.Add(cond, condDirectives);

                // #If should be the first one in sorted order
                var ifDirective = condDirectives.First();
                if (directive != null)
                {
                    directiveMap.Add(directive, ifDirective);
                    directiveMap.Add(ifDirective, directive);
                }
            }
        }
    }

    /// <summary>
    /// Collects the leading trivia of the first non-missing token of <paramref name="node"/>, preceded by the
    /// leading and trailing trivia of any missing tokens that come before it.
    /// </summary>
    public static SyntaxTriviaList GetLeadingTriviaIncludingMissingTokens(this SyntaxNode node)
    {
        var triviaList = new SyntaxTriviaList();

        foreach (var token in node.DescendantTokens())
        {
            triviaList = triviaList.AddRange(token.LeadingTrivia);

            if (!token.IsMissing)
                break;

            triviaList = triviaList.AddRange(token.TrailingTrivia);
        }

        return triviaList;
    }

    /// <summary>
    /// Collects the trailing trivia of the last non-missing token of <paramref name="node"/>, followed by the
    /// leading and trailing trivia of any missing tokens that come after it. The walk never leaves the node's
    /// full span.
    /// </summary>
    public static SyntaxTriviaList GetTrailingTriviaIncludingMissingTokens(this SyntaxNode node)
    {
        var triviaList = new SyntaxTriviaList();
        var fullSpan = node.FullSpan;

        var current = node.GetLastToken(includeZeroWidth: true);
        while (fullSpan.Contains(current.Span))
        {
            // Walking backwards, so each piece is inserted at the front to keep source order.
            triviaList = triviaList.InsertRange(0, current.TrailingTrivia);

            if (!current.IsMissing)
                break;

            triviaList = triviaList.InsertRange(0, current.LeadingTrivia);
            current = current.GetPreviousToken(includeZeroWidth: true);
        }

        return triviaList;
    }

    /// <summary>
    /// Returns <paramref name="node"/> with its leading trivia replaced by
    /// <see cref="GetLeadingTriviaIncludingMissingTokens"/> of <paramref name="other"/>.
    /// </summary>
    public static SyntaxNode WithLeadingTriviaFromIncludingMissingTokens(this SyntaxNode node, SyntaxNode other)
    {
        return node.WithLeadingTrivia(other.GetLeadingTriviaIncludingMissingTokens());
    }

    /// <summary>
    /// Returns <paramref name="node"/> with its trailing trivia replaced by
    /// <see cref="GetTrailingTriviaIncludingMissingTokens"/> of <paramref name="other"/>.
    /// </summary>
    public static SyntaxNode WithTrailingTriviaFromIncludingMissingTokens(this SyntaxNode node, SyntaxNode other)
    {
        return node.WithTrailingTrivia(other.GetTrailingTriviaIncludingMissingTokens());
    }

    /// <summary>
    /// Returns <paramref name="node"/> with both its leading and trailing trivia taken from
    /// <paramref name="other"/>, including trivia attached to missing tokens.
    /// </summary>
    public static SyntaxNode WithTriviaFromIncludingMissingTokens(this SyntaxNode node, SyntaxNode other)
    {
        return node.WithLeadingTriviaFromIncludingMissingTokens(other)
                   .WithTrailingTriviaFromIncludingMissingTokens(other);
    }

    /// <summary>
    /// Gets a list of ancestor nodes (including this node)
    /// </summary>
    public static ValueAncestorsAndSelfEnumerable ValueAncestorsAndSelf(this SyntaxNode syntaxNode, bool ascendOutOfTrivia = true)
        => new(syntaxNode, ascendOutOfTrivia);

    /// <summary>
    /// Struct-based enumerable over a node and its ancestors, usable directly in <see langword="foreach"/>.
    /// </summary>
    public readonly struct ValueAncestorsAndSelfEnumerable(SyntaxNode syntaxNode, bool ascendOutOfTrivia)
    {
        public Enumerator GetEnumerator()
            => new(syntaxNode, ascendOutOfTrivia);

        public struct Enumerator(SyntaxNode syntaxNode, bool ascendOutOfTrivia)
        {
            public SyntaxNode Current { get; private set; } = null!;

            public bool MoveNext()
            {
                // A null Current means enumeration has not started yet, so begin with the node itself;
                // afterwards step to the parent until there is none.
                Current = Current == null ? syntaxNode : GetParent(Current, ascendOutOfTrivia)!;
                return Current != null;
            }
        }
    }
}