|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.Shared.Extensions;
internal static partial class SyntaxNodeExtensions
{
public static SyntaxNode GetRequiredParent(this SyntaxNode node)
=> node.Parent ?? throw new InvalidOperationException("Node's parent was null");
public static IEnumerable<SyntaxNodeOrToken> DepthFirstTraversal(this SyntaxNode node)
=> SyntaxNodeOrTokenExtensions.DepthFirstTraversal(node);
public static IEnumerable<SyntaxNode> DepthFirstTraversalNodes(this SyntaxNode node)
=> SyntaxNodeOrTokenExtensions.DepthFirstTraversalNodes(node);
public static IEnumerable<SyntaxNode> GetAncestors(this SyntaxNode node)
{
var current = node.Parent;
while (current != null)
{
yield return current;
current = current.GetParent(ascendOutOfTrivia: true);
}
}
public static IEnumerable<TNode> GetAncestors<TNode>(this SyntaxNode node)
where TNode : SyntaxNode
{
var current = node.Parent;
while (current != null)
{
if (current is TNode tNode)
{
yield return tNode;
}
current = current.GetParent(ascendOutOfTrivia: true);
}
}
public static TNode? GetAncestor<TNode>(this SyntaxNode node)
where TNode : SyntaxNode
{
var current = node.Parent;
while (current != null)
{
if (current is TNode tNode)
{
return tNode;
}
current = current.GetParent(ascendOutOfTrivia: true);
}
return null;
}
public static TNode? GetAncestorOrThis<TNode>(this SyntaxNode? node)
where TNode : SyntaxNode
{
return node?.GetAncestorsOrThis<TNode>().FirstOrDefault();
}
public static IEnumerable<TNode> GetAncestorsOrThis<TNode>(this SyntaxNode? node)
where TNode : SyntaxNode
{
var current = node;
while (current != null)
{
if (current is TNode tNode)
{
yield return tNode;
}
current = current.GetParent(ascendOutOfTrivia: true);
}
}
public static bool HasAncestor<TNode>(this SyntaxNode node)
where TNode : SyntaxNode
{
return node.GetAncestors<TNode>().Any();
}
public static IEnumerable<TSyntaxNode> Traverse<TSyntaxNode>(
this SyntaxNode node, TextSpan searchSpan, Func<SyntaxNode, bool> predicate)
where TSyntaxNode : SyntaxNode
{
Contract.ThrowIfNull(node);
var nodes = new LinkedList<SyntaxNode>();
nodes.AddFirst(node);
while (nodes.Count > 0)
{
var currentNode = nodes.First!.Value;
nodes.RemoveFirst();
if (currentNode != null && searchSpan.Contains(currentNode.FullSpan) && predicate(currentNode))
{
if (currentNode is TSyntaxNode tSyntax)
{
yield return tSyntax;
}
nodes.AddRangeAtHead(currentNode.ChildNodes());
}
}
}
public static bool CheckParent<T>([NotNullWhen(returnValue: true)] this SyntaxNode? node, Func<T, bool> valueChecker) where T : SyntaxNode
{
if (node?.Parent is not T parentNode)
{
return false;
}
return valueChecker(parentNode);
}
/// <summary>
/// Returns true if is a given token is a child token of a certain type of parent node.
/// </summary>
/// <typeparam name="TParent">The type of the parent node.</typeparam>
/// <param name="node">The node that we are testing.</param>
/// <param name="childGetter">A function that, when given the parent node, returns the child token we are interested in.</param>
public static bool IsChildNode<TParent>(this SyntaxNode node, Func<TParent, SyntaxNode?> childGetter)
where TParent : SyntaxNode
{
var ancestor = node.GetAncestor<TParent>();
if (ancestor == null)
{
return false;
}
var ancestorNode = childGetter(ancestor);
return node == ancestorNode;
}
/// <summary>
/// Returns true if this node is found underneath the specified child in the given parent.
/// </summary>
public static bool IsFoundUnder<TParent>(this SyntaxNode node, Func<TParent, SyntaxNode?> childGetter)
where TParent : SyntaxNode
{
var ancestor = node.GetAncestor<TParent>();
if (ancestor == null)
{
return false;
}
var child = childGetter(ancestor);
// See if node passes through child on the way up to ancestor.
return node.GetAncestorsOrThis<SyntaxNode>().Contains(child);
}
public static SyntaxNode GetCommonRoot(this SyntaxNode node1, SyntaxNode node2)
{
Contract.ThrowIfTrue(node1.RawKind == 0 || node2.RawKind == 0);
// find common starting node from two nodes. as long as two nodes belong to same tree, there must be at least
// one common root (Ex, compilation unit)
using var _ = PooledHashSet<SyntaxNode>.GetInstance(out var set);
set.AddRange(node2.GetAncestorsOrThis<SyntaxNode>());
foreach (var ancestor in node1.AncestorsAndSelf())
{
if (set.Contains(ancestor))
return ancestor;
}
throw ExceptionUtilities.Unreachable();
}
public static int Width(this SyntaxNode node)
=> node.Span.Length;
public static int FullWidth(this SyntaxNode node)
=> node.FullSpan.Length;
public static SyntaxNode? FindInnermostCommonNode(this IEnumerable<SyntaxNode> nodes, Func<SyntaxNode, bool> predicate)
=> nodes.FindInnermostCommonNode()?.FirstAncestorOrSelf(predicate);
public static SyntaxNode? FindInnermostCommonNode(this IEnumerable<SyntaxNode> nodes)
{
// Two collections we use to make this operation as efficient as possible. One is a
// stack of the current shared ancestor chain shared by all nodes so far. It starts
// with the full ancestor chain of the first node, and can only get smaller over time.
// It should be log(n) with the size of the tree as it's only storing a parent chain.
//
// The second is a set with the exact same contents as the array. It's used for O(1)
// lookups if a node is in the ancestor chain or not.
using var _1 = ArrayBuilder<SyntaxNode>.GetInstance(out var commonAncestorsStack);
using var _2 = PooledHashSet<SyntaxNode>.GetInstance(out var commonAncestorsSet);
var first = true;
foreach (var node in nodes)
{
// If we're just starting, initialize the ancestors set/array with the ancestors of
// this node.
if (first)
{
first = false;
foreach (var ancestor in node.ValueAncestorsAndSelf())
{
commonAncestorsSet.Add(ancestor);
commonAncestorsStack.Add(ancestor);
}
// Reverse the ancestors stack so that we go downwards with CompilationUnit at
// the start, and then go down to this starting node. This enables cheap
// popping later on.
commonAncestorsStack.ReverseContents();
continue;
}
// On a subsequent node, walk its ancestors to find the first match
var commonAncestor = FindCommonAncestor(node, commonAncestorsSet);
if (commonAncestor == null)
{
// So this shouldn't happen as long as the nodes are from the same tree. And
// the caller really shouldn't be calling from different trees. However, the
// previous impl supported that, so continue to have this behavior.
//
// If this doesn't fire, that means that all callers seem sane. If it does
// fire, we can relax this (but we should consider fixing the caller).
Debug.Fail("Could not find common ancestor.");
return null;
}
// Now remove everything in the ancestors array up to that common ancestor. This is
// generally quite efficient. Either we settle on a common node quickly. and don't
// need to do work here, or we keep tossing data from our common-ancestor scratch
// pad, making further work faster.
while (commonAncestorsStack.Count > 0 &&
commonAncestorsStack.Peek() != commonAncestor)
{
commonAncestorsSet.Remove(commonAncestorsStack.Peek());
commonAncestorsStack.Pop();
}
if (commonAncestorsStack.Count == 0)
{
// So this shouldn't happen as long as the nodes are from the same tree. And
// the caller really shouldn't be calling from different trees. However, the
// previous impl supported that, so continue to have this behavior.
//
// If this doesn't fire, that means that all callers seem sane. If it does
// fire, we can relax this (but we should consider fixing the caller).
Debug.Fail("Could not find common ancestor.");
return null;
}
}
// The common ancestor is the one at the end of the ancestor stack. This could be empty
// in the case where the caller passed in an empty enumerable of nodes.
return commonAncestorsStack.Count == 0 ? null : commonAncestorsStack.Peek();
// local functions
static SyntaxNode? FindCommonAncestor(SyntaxNode node, HashSet<SyntaxNode> commonAncestorsSet)
{
foreach (var ancestor in node.ValueAncestorsAndSelf())
{
if (commonAncestorsSet.Contains(ancestor))
return ancestor;
}
return null;
}
}
public static TSyntaxNode? FindInnermostCommonNode<TSyntaxNode>(this IEnumerable<SyntaxNode> nodes) where TSyntaxNode : SyntaxNode
=> (TSyntaxNode?)nodes.FindInnermostCommonNode(t => t is TSyntaxNode);
public static TextSpan GetContainedSpan(this IEnumerable<SyntaxNode> nodes)
{
Contract.ThrowIfNull(nodes);
Contract.ThrowIfFalse(nodes.Any());
var fullSpan = nodes.First().Span;
foreach (var node in nodes)
{
fullSpan = TextSpan.FromBounds(
Math.Min(fullSpan.Start, node.SpanStart),
Math.Max(fullSpan.End, node.Span.End));
}
return fullSpan;
}
public static bool OverlapsHiddenPosition(this SyntaxNode node, CancellationToken cancellationToken)
=> node.OverlapsHiddenPosition(node.Span, cancellationToken);
public static bool OverlapsHiddenPosition(this SyntaxNode node, TextSpan span, CancellationToken cancellationToken)
=> node.SyntaxTree.OverlapsHiddenPosition(span, cancellationToken);
public static bool OverlapsHiddenPosition(this SyntaxNode declaration, SyntaxNode startNode, SyntaxNode endNode, CancellationToken cancellationToken)
{
var start = startNode.Span.End;
var end = endNode.SpanStart;
var textSpan = TextSpan.FromBounds(start, end);
return declaration.OverlapsHiddenPosition(textSpan, cancellationToken);
}
public static IEnumerable<T> GetAnnotatedNodes<T>(this SyntaxNode node, SyntaxAnnotation syntaxAnnotation) where T : SyntaxNode
=> node.GetAnnotatedNodesAndTokens(syntaxAnnotation).Select(n => n.AsNode()).OfType<T>();
/// <summary>
/// Creates a new tree of nodes from the existing tree with the specified old nodes replaced with a newly computed nodes.
/// </summary>
/// <param name="root">The root of the tree that contains all the specified nodes.</param>
/// <param name="nodes">The nodes from the tree to be replaced.</param>
/// <param name="computeReplacementAsync">A function that computes a replacement node for
/// the argument nodes. The first argument is one of the original specified nodes. The second argument is
/// the same node possibly rewritten with replaced descendants.</param>
/// <param name="cancellationToken"></param>
public static Task<TRootNode> ReplaceNodesAsync<TRootNode>(
this TRootNode root,
IEnumerable<SyntaxNode> nodes,
Func<SyntaxNode, SyntaxNode, CancellationToken, Task<SyntaxNode>> computeReplacementAsync,
CancellationToken cancellationToken) where TRootNode : SyntaxNode
{
return root.ReplaceSyntaxAsync(
nodes: nodes, computeReplacementNodeAsync: computeReplacementAsync,
tokens: null, computeReplacementTokenAsync: null,
trivia: null, computeReplacementTriviaAsync: null,
cancellationToken: cancellationToken);
}
/// <summary>
/// Creates a new tree of tokens from the existing tree with the specified old tokens replaced with a newly computed tokens.
/// </summary>
/// <param name="root">The root of the tree that contains all the specified tokens.</param>
/// <param name="tokens">The tokens from the tree to be replaced.</param>
/// <param name="computeReplacementAsync">A function that computes a replacement token for
/// the argument tokens. The first argument is one of the originally specified tokens. The second argument is
/// the same token possibly rewritten with replaced trivia.</param>
/// <param name="cancellationToken"></param>
public static Task<TRootNode> ReplaceTokensAsync<TRootNode>(
this TRootNode root,
IEnumerable<SyntaxToken> tokens,
Func<SyntaxToken, SyntaxToken, CancellationToken, Task<SyntaxToken>> computeReplacementAsync,
CancellationToken cancellationToken) where TRootNode : SyntaxNode
{
return root.ReplaceSyntaxAsync(
nodes: null, computeReplacementNodeAsync: null,
tokens: tokens, computeReplacementTokenAsync: computeReplacementAsync,
trivia: null, computeReplacementTriviaAsync: null,
cancellationToken: cancellationToken);
}
public static Task<TRoot> ReplaceTriviaAsync<TRoot>(
this TRoot root,
IEnumerable<SyntaxTrivia> trivia,
Func<SyntaxTrivia, SyntaxTrivia, CancellationToken, Task<SyntaxTrivia>> computeReplacementAsync,
CancellationToken cancellationToken) where TRoot : SyntaxNode
{
return root.ReplaceSyntaxAsync(
nodes: null, computeReplacementNodeAsync: null,
tokens: null, computeReplacementTokenAsync: null,
trivia: trivia, computeReplacementTriviaAsync: computeReplacementAsync,
cancellationToken: cancellationToken);
}
public static async Task<TRoot> ReplaceSyntaxAsync<TRoot>(
this TRoot root,
IEnumerable<SyntaxNode>? nodes,
Func<SyntaxNode, SyntaxNode, CancellationToken, Task<SyntaxNode>>? computeReplacementNodeAsync,
IEnumerable<SyntaxToken>? tokens,
Func<SyntaxToken, SyntaxToken, CancellationToken, Task<SyntaxToken>>? computeReplacementTokenAsync,
IEnumerable<SyntaxTrivia>? trivia,
Func<SyntaxTrivia, SyntaxTrivia, CancellationToken, Task<SyntaxTrivia>>? computeReplacementTriviaAsync,
CancellationToken cancellationToken)
where TRoot : SyntaxNode
{
// index all nodes, tokens and trivia by the full spans they cover
var nodesToReplace = nodes != null ? nodes.ToDictionary(n => n.FullSpan) : [];
var tokensToReplace = tokens != null ? tokens.ToDictionary(t => t.FullSpan) : [];
var triviaToReplace = trivia != null ? trivia.ToDictionary(t => t.FullSpan) : [];
var nodeReplacements = new Dictionary<SyntaxNode, SyntaxNode>();
var tokenReplacements = new Dictionary<SyntaxToken, SyntaxToken>();
var triviaReplacements = new Dictionary<SyntaxTrivia, SyntaxTrivia>();
var retryAnnotations = new AnnotationTable<object>("RetryReplace");
var spans = new List<TextSpan>(nodesToReplace.Count + tokensToReplace.Count + triviaToReplace.Count);
spans.AddRange(nodesToReplace.Keys);
spans.AddRange(tokensToReplace.Keys);
spans.AddRange(triviaToReplace.Keys);
while (spans.Count > 0)
{
// sort the spans of the items to be replaced so we can tell if any overlap
spans.Sort((x, y) =>
{
// order by end offset, and then by length
var d = x.End - y.End;
if (d == 0)
{
d = x.Length - y.Length;
}
return d;
});
// compute replacements for all nodes that will go in the same batch
// only spans that do not overlap go in the same batch.
TextSpan previous = default;
foreach (var span in spans)
{
// only add to replacement map if we don't intersect with the previous node. This taken with the sort order
// should ensure that parent nodes are not processed in the same batch as child nodes.
if (previous == default || !previous.IntersectsWith(span))
{
if (nodesToReplace.TryGetValue(span, out var currentNode))
{
var original = (SyntaxNode?)retryAnnotations.GetAnnotations(currentNode).SingleOrDefault() ?? currentNode;
var newNode = await computeReplacementNodeAsync!(original, currentNode, cancellationToken).ConfigureAwait(false);
nodeReplacements[currentNode] = newNode;
}
else if (tokensToReplace.TryGetValue(span, out var currentToken))
{
var original = (SyntaxToken?)retryAnnotations.GetAnnotations(currentToken).SingleOrDefault() ?? currentToken;
var newToken = await computeReplacementTokenAsync!(original, currentToken, cancellationToken).ConfigureAwait(false);
tokenReplacements[currentToken] = newToken;
}
else if (triviaToReplace.TryGetValue(span, out var currentTrivia))
{
var original = (SyntaxTrivia?)retryAnnotations.GetAnnotations(currentTrivia).SingleOrDefault() ?? currentTrivia;
var newTrivia = await computeReplacementTriviaAsync!(original, currentTrivia, cancellationToken).ConfigureAwait(false);
triviaReplacements[currentTrivia] = newTrivia;
}
}
previous = span;
}
var retryNodes = false;
var retryTokens = false;
var retryTrivia = false;
// replace nodes in batch
// submit all nodes so we can annotate the ones we don't replace
root = root.ReplaceSyntax(
nodes: nodesToReplace.Values,
computeReplacementNode: (original, rewritten) =>
{
if (rewritten != original || !nodeReplacements.TryGetValue(original, out var replaced))
{
// the subtree did change, or we didn't have a replacement for it in this batch
// so we need to add an annotation so we can find this node again for the next batch.
replaced = retryAnnotations.WithAdditionalAnnotations(rewritten, original);
retryNodes = true;
}
return replaced;
},
tokens: tokensToReplace.Values,
computeReplacementToken: (original, rewritten) =>
{
if (rewritten != original || !tokenReplacements.TryGetValue(original, out var replaced))
{
// the subtree did change, or we didn't have a replacement for it in this batch
// so we need to add an annotation so we can find this node again for the next batch.
replaced = retryAnnotations.WithAdditionalAnnotations(rewritten, original);
retryTokens = true;
}
return replaced;
},
trivia: triviaToReplace.Values,
computeReplacementTrivia: (original, rewritten) =>
{
if (!triviaReplacements.TryGetValue(original, out var replaced))
{
// the subtree did change, or we didn't have a replacement for it in this batch
// so we need to add an annotation so we can find this node again for the next batch.
replaced = retryAnnotations.WithAdditionalAnnotations(rewritten, original);
retryTrivia = true;
}
return replaced;
});
nodesToReplace.Clear();
tokensToReplace.Clear();
triviaToReplace.Clear();
spans.Clear();
// prepare next batch out of all remaining annotated nodes
if (retryNodes)
{
nodesToReplace = retryAnnotations.GetAnnotatedNodes(root).ToDictionary(n => n.FullSpan);
spans.AddRange(nodesToReplace.Keys);
}
if (retryTokens)
{
tokensToReplace = retryAnnotations.GetAnnotatedTokens(root).ToDictionary(t => t.FullSpan);
spans.AddRange(tokensToReplace.Keys);
}
if (retryTrivia)
{
triviaToReplace = retryAnnotations.GetAnnotatedTrivia(root).ToDictionary(t => t.FullSpan);
spans.AddRange(triviaToReplace.Keys);
}
}
return root;
}
/// <summary>
/// Look inside a trivia list for a skipped token that contains the given position.
/// </summary>
private static readonly Func<SyntaxTriviaList, int, SyntaxToken> s_findSkippedTokenForward = FindSkippedTokenForward;
/// <summary>
/// Look inside a trivia list for a skipped token that contains the given position.
/// </summary>
private static SyntaxToken FindSkippedTokenForward(SyntaxTriviaList triviaList, int position)
{
foreach (var trivia in triviaList)
{
if (trivia.HasStructure)
{
if (trivia.GetStructure() is ISkippedTokensTriviaSyntax skippedTokensTrivia)
{
foreach (var token in skippedTokensTrivia.Tokens)
{
if (token.Span.Length > 0 && position <= token.Span.End)
{
return token;
}
}
}
}
}
return default;
}
/// <summary>
/// Look inside a trivia list for a skipped token that contains the given position.
/// </summary>
private static readonly Func<SyntaxTriviaList, int, SyntaxToken> s_findSkippedTokenBackward = FindSkippedTokenBackward;
/// <summary>
/// Look inside a trivia list for a skipped token that contains the given position.
/// </summary>
private static SyntaxToken FindSkippedTokenBackward(SyntaxTriviaList triviaList, int position)
{
foreach (var trivia in triviaList.Reverse())
{
if (trivia.HasStructure)
{
if (trivia.GetStructure() is ISkippedTokensTriviaSyntax skippedTokensTrivia)
{
foreach (var token in skippedTokensTrivia.Tokens)
{
if (token.Span.Length > 0 && token.SpanStart <= position)
{
return token;
}
}
}
}
}
return default;
}
private static SyntaxToken GetInitialToken(
SyntaxNode root,
int position,
bool includeSkipped = false,
bool includeDirectives = false,
bool includeDocumentationComments = false)
{
return (position < root.FullSpan.End || !(root is ICompilationUnitSyntax))
? root.FindToken(position, includeSkipped || includeDirectives || includeDocumentationComments)
: root.GetLastToken(includeZeroWidth: true, includeSkipped: true, includeDirectives: true, includeDocumentationComments: true)
.GetPreviousToken(includeZeroWidth: false, includeSkipped: includeSkipped, includeDirectives: includeDirectives, includeDocumentationComments: includeDocumentationComments);
}
/// <summary>
/// If the position is inside of token, return that token; otherwise, return the token to the right.
/// </summary>
public static SyntaxToken FindTokenOnRightOfPosition(
this SyntaxNode root,
int position,
bool includeSkipped = false,
bool includeDirectives = false,
bool includeDocumentationComments = false)
{
var findSkippedToken = includeSkipped ? s_findSkippedTokenForward : ((l, p) => default);
var token = GetInitialToken(root, position, includeSkipped, includeDirectives, includeDocumentationComments);
if (position < token.SpanStart)
{
var skippedToken = findSkippedToken(token.LeadingTrivia, position);
token = skippedToken.RawKind != 0 ? skippedToken : token;
}
else if (token.Span.End <= position)
{
do
{
var skippedToken = findSkippedToken(token.TrailingTrivia, position);
token = skippedToken.RawKind != 0
? skippedToken
: token.GetNextToken(includeZeroWidth: false, includeSkipped: includeSkipped, includeDirectives: includeDirectives, includeDocumentationComments: includeDocumentationComments);
}
while (token.RawKind != 0 && token.Span.End <= position && token.Span.End <= root.FullSpan.End);
}
if (token.Span.Length == 0)
{
token = token.GetNextToken();
}
return token;
}
/// <summary>
/// If the position is inside of token, return that token; otherwise, return the token to the left.
/// </summary>
public static SyntaxToken FindTokenOnLeftOfPosition(
this SyntaxNode root,
int position,
bool includeSkipped = false,
bool includeDirectives = false,
bool includeDocumentationComments = false)
{
var findSkippedToken = includeSkipped ? s_findSkippedTokenBackward : ((l, p) => default);
var token = GetInitialToken(root, position, includeSkipped, includeDirectives, includeDocumentationComments);
if (position <= token.SpanStart)
{
do
{
var skippedToken = findSkippedToken(token.LeadingTrivia, position);
token = skippedToken.RawKind != 0
? skippedToken
: token.GetPreviousToken(includeZeroWidth: false, includeSkipped: includeSkipped, includeDirectives: includeDirectives, includeDocumentationComments: includeDocumentationComments);
}
while (position <= token.SpanStart && root.FullSpan.Start < token.SpanStart);
}
else if (token.Span.End < position)
{
var skippedToken = findSkippedToken(token.TrailingTrivia, position);
token = skippedToken.RawKind != 0 ? skippedToken : token;
}
if (token.Span.Length == 0)
{
token = token.GetPreviousToken();
}
return token;
}
public static T WithPrependedLeadingTrivia<T>(
this T node,
params SyntaxTrivia[] trivia) where T : SyntaxNode
{
if (trivia.Length == 0)
{
return node;
}
return node.WithPrependedLeadingTrivia((IEnumerable<SyntaxTrivia>)trivia);
}
public static T WithPrependedLeadingTrivia<T>(
this T node,
SyntaxTriviaList trivia) where T : SyntaxNode
{
if (trivia.Count == 0)
{
return node;
}
return node.WithLeadingTrivia(trivia.Concat(node.GetLeadingTrivia()));
}
public static T WithPrependedLeadingTrivia<T>(
this T node,
IEnumerable<SyntaxTrivia> trivia) where T : SyntaxNode
{
var list = new SyntaxTriviaList();
list = list.AddRange(trivia);
return node.WithPrependedLeadingTrivia(list);
}
public static T WithAppendedTrailingTrivia<T>(
this T node,
params SyntaxTrivia[] trivia) where T : SyntaxNode
{
if (trivia.Length == 0)
{
return node;
}
return node.WithAppendedTrailingTrivia((IEnumerable<SyntaxTrivia>)trivia);
}
public static T WithAppendedTrailingTrivia<T>(
this T node,
SyntaxTriviaList trivia) where T : SyntaxNode
{
if (trivia.Count == 0)
{
return node;
}
return node.WithTrailingTrivia(node.GetTrailingTrivia().Concat(trivia));
}
public static T WithAppendedTrailingTrivia<T>(
this T node,
IEnumerable<SyntaxTrivia> trivia) where T : SyntaxNode
{
var list = new SyntaxTriviaList();
list = list.AddRange(trivia);
return node.WithAppendedTrailingTrivia(list);
}
public static T With<T>(
this T node,
IEnumerable<SyntaxTrivia> leadingTrivia,
IEnumerable<SyntaxTrivia> trailingTrivia) where T : SyntaxNode
{
return node.WithLeadingTrivia(leadingTrivia).WithTrailingTrivia(trailingTrivia);
}
/// <summary>
/// Creates a new token with the leading trivia removed.
/// </summary>
public static SyntaxToken WithoutLeadingTrivia(this SyntaxToken token)
{
return token.WithLeadingTrivia(default(SyntaxTriviaList));
}
/// <summary>
/// Creates a new token with the trailing trivia removed.
/// </summary>
public static SyntaxToken WithoutTrailingTrivia(this SyntaxToken token)
{
return token.WithTrailingTrivia(default(SyntaxTriviaList));
}
/// <summary>
/// Finds the node within the given <paramref name="root"/> corresponding to the given <paramref name="span"/>.
/// If the <paramref name="span"/> is <see langword="null"/>, then returns the given <paramref name="root"/> node.
/// </summary>
public static SyntaxNode FindNode(this SyntaxNode root, TextSpan? span, bool findInTrivia, bool getInnermostNodeForTie)
{
return span.HasValue
? root.FindNode(span.Value, findInTrivia, getInnermostNodeForTie)
: root;
}
// Copy of the same function in SyntaxNode.cs
public static SyntaxNode? GetParent(this SyntaxNode node, bool ascendOutOfTrivia)
{
var parent = node.Parent;
if (parent == null && ascendOutOfTrivia)
{
if (node is IStructuredTriviaSyntax structuredTrivia)
{
parent = structuredTrivia.ParentTrivia.Token.Parent;
}
}
return parent;
}
public static TNode? FirstAncestorOrSelfUntil<TNode>(this SyntaxNode? node, Func<SyntaxNode, bool> predicate)
where TNode : SyntaxNode
{
for (var current = node; current != null; current = current.GetParent(ascendOutOfTrivia: true))
{
if (current is TNode tnode)
{
return tnode;
}
if (predicate(current))
{
break;
}
}
return null;
}
public static DirectiveInfo<TDirectiveTriviaSyntax> GetDirectiveInfoForRoot<TDirectiveTriviaSyntax>(
SyntaxNode root,
ISyntaxKinds syntaxKinds,
CancellationToken cancellationToken)
where TDirectiveTriviaSyntax : SyntaxNode
{
return DirectiveTriviaUtilities<TDirectiveTriviaSyntax>.GetDirectiveInfoForRoot(root, syntaxKinds, cancellationToken);
}
private static class DirectiveTriviaUtilities<TDirectiveTriviaSyntax>
where TDirectiveTriviaSyntax : SyntaxNode
{
private sealed class DirectiveSyntaxEqualityComparer : IEqualityComparer<TDirectiveTriviaSyntax>
{
public static readonly DirectiveSyntaxEqualityComparer Instance = new();
private DirectiveSyntaxEqualityComparer()
{
}
public bool Equals(TDirectiveTriviaSyntax? x, TDirectiveTriviaSyntax? y)
=> x?.SpanStart == y?.SpanStart;
public int GetHashCode(TDirectiveTriviaSyntax obj)
=> obj.SpanStart;
}
private static readonly ObjectPool<Stack<TDirectiveTriviaSyntax>> s_stackPool = new(() => new());
public static DirectiveInfo<TDirectiveTriviaSyntax> GetDirectiveInfoForRoot(
SyntaxNode root,
ISyntaxKinds syntaxKinds,
CancellationToken cancellationToken)
{
var directiveMap = new Dictionary<TDirectiveTriviaSyntax, TDirectiveTriviaSyntax?>(
DirectiveSyntaxEqualityComparer.Instance);
var conditionalMap = new Dictionary<TDirectiveTriviaSyntax, ImmutableArray<TDirectiveTriviaSyntax>>(
DirectiveSyntaxEqualityComparer.Instance);
using var _1 = s_stackPool.GetPooledObject(out var regionStack);
using var _2 = s_stackPool.GetPooledObject(out var ifStack);
foreach (var token in root.DescendantTokens(descendIntoChildren: static node => node.ContainsDirectives))
{
cancellationToken.ThrowIfCancellationRequested();
if (!token.ContainsDirectives)
continue;
foreach (var trivia in token.LeadingTrivia)
{
if (trivia.RawKind == syntaxKinds.RegionDirectiveTrivia)
{
regionStack.Push((TDirectiveTriviaSyntax)trivia.GetStructure()!);
}
else if (trivia.RawKind == syntaxKinds.IfDirectiveTrivia ||
trivia.RawKind == syntaxKinds.ElifDirectiveTrivia ||
trivia.RawKind == syntaxKinds.ElseDirectiveTrivia)
{
ifStack.Push((TDirectiveTriviaSyntax)trivia.GetStructure()!);
}
else if (trivia.RawKind == syntaxKinds.EndRegionDirectiveTrivia)
{
if (regionStack.Count > 0)
{
var directive = (TDirectiveTriviaSyntax)trivia.GetStructure()!;
var previousDirective = regionStack.Pop();
directiveMap.Add(directive, previousDirective);
directiveMap.Add(previousDirective, directive);
}
}
else if (trivia.RawKind == syntaxKinds.EndIfDirectiveTrivia)
{
if (ifStack.Count > 0)
FinishIf((TDirectiveTriviaSyntax)trivia.GetStructure()!);
}
}
}
while (regionStack.Count > 0)
directiveMap.Add(regionStack.Pop(), null);
while (ifStack.Count > 0)
FinishIf(directive: null);
return new DirectiveInfo<TDirectiveTriviaSyntax>(directiveMap, conditionalMap);
void FinishIf(TDirectiveTriviaSyntax? directive)
{
using var _ = ArrayBuilder<TDirectiveTriviaSyntax>.GetInstance(out var condDirectivesBuilder);
if (directive != null)
condDirectivesBuilder.Add(directive);
while (ifStack.TryPop(out var poppedDirective))
{
condDirectivesBuilder.Add(poppedDirective);
if (poppedDirective.RawKind == syntaxKinds.IfDirectiveTrivia)
break;
}
condDirectivesBuilder.Sort(static (n1, n2) => n1.SpanStart.CompareTo(n2.SpanStart));
var condDirectives = condDirectivesBuilder.ToImmutableAndClear();
foreach (var cond in condDirectives)
conditionalMap.Add(cond, condDirectives);
// #If should be the first one in sorted order
var ifDirective = condDirectives.First();
if (directive != null)
{
directiveMap.Add(directive, ifDirective);
directiveMap.Add(ifDirective, directive);
}
}
}
}
/// <summary>
/// Gets a list of ancestor nodes (including this node)
/// </summary>
public static ValueAncestorsAndSelfEnumerable ValueAncestorsAndSelf(this SyntaxNode syntaxNode, bool ascendOutOfTrivia = true)
=> new(syntaxNode, ascendOutOfTrivia);
public readonly struct ValueAncestorsAndSelfEnumerable(SyntaxNode syntaxNode, bool ascendOutOfTrivia)
{
public Enumerator GetEnumerator()
=> new(syntaxNode, ascendOutOfTrivia);
public struct Enumerator(SyntaxNode syntaxNode, bool ascendOutOfTrivia)
{
public SyntaxNode Current { get; private set; } = null!;
public bool MoveNext()
{
Current = Current == null ? syntaxNode : GetParent(Current, ascendOutOfTrivia)!;
return Current != null;
}
}
}
}
|