|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Syntax;
namespace Microsoft.CodeAnalysis.CSharp
{
/// <summary>
/// Represents a <see cref="CSharpSyntaxVisitor{TResult}"/> which descends an entire <see cref="CSharpSyntaxNode"/> graph and
/// may replace or remove visited SyntaxNodes in depth-first order.
/// </summary>
/// <remarks>
/// Every list-visiting method in this part returns the original list instance when nothing changed. A new list is
/// built only once the first element (or separator) has been replaced or removed.
/// </remarks>
public abstract partial class CSharpSyntaxRewriter : CSharpSyntaxVisitor<SyntaxNode?>
{
    private readonly bool _visitIntoStructuredTrivia;

    /// <summary>
    /// Initializes a new rewriter.
    /// </summary>
    /// <param name="visitIntoStructuredTrivia">
    /// The value returned by the base implementation of <see cref="VisitIntoStructuredTrivia"/>. When
    /// <see langword="true"/>, <see cref="VisitTrivia(SyntaxTrivia)"/> visits the structure of structured trivia.
    /// </param>
    public CSharpSyntaxRewriter(bool visitIntoStructuredTrivia = false)
    {
        _visitIntoStructuredTrivia = visitIntoStructuredTrivia;
    }

    /// <summary>
    /// Gets whether <see cref="VisitTrivia(SyntaxTrivia)"/> descends into the structure of structured trivia.
    /// The base implementation returns the value passed to the constructor; derived types may override it.
    /// </summary>
    public virtual bool VisitIntoStructuredTrivia
    {
        get { return _visitIntoStructuredTrivia; }
    }

    // Number of Visit calls for non-null nodes currently in progress. It is passed to
    // StackGuard.EnsureSufficientExecutionStack before each dispatch (presumably so very deep trees fail
    // gracefully instead of overflowing the stack).
    // NOTE(review): the counter is not restored if Accept throws, so a rewriter reused after an exception starts
    // from an inflated depth. Presumably this only affects when StackGuard starts probing; confirm.
    private int _recursionDepth;

    /// <summary>
    /// Visits <paramref name="node"/> by calling its <c>Accept</c> method with this rewriter.
    /// </summary>
    /// <param name="node">The node to visit; may be <see langword="null"/>.</param>
    /// <returns>
    /// The result of <c>Accept</c>, or <see langword="null"/> when <paramref name="node"/> is
    /// <see langword="null"/>.
    /// </returns>
    [return: NotNullIfNotNull(nameof(node))]
    public override SyntaxNode? Visit(SyntaxNode? node)
    {
        if (node != null)
        {
            _recursionDepth++;
            StackGuard.EnsureSufficientExecutionStack(_recursionDepth);

            var result = ((CSharpSyntaxNode)node).Accept(this);

            _recursionDepth--;
            // The NotNullIfNotNull contract is asserted with '!' rather than checked:
            // https://github.com/dotnet/roslyn/issues/47682
            return result!;
        }
        else
        {
            return null;
        }
    }

    /// <summary>
    /// Visits the leading and trailing trivia of <paramref name="token"/> through
    /// <see cref="VisitList(SyntaxTriviaList)"/>.
    /// </summary>
    /// <param name="token">The token whose trivia is visited.</param>
    /// <returns>
    /// <paramref name="token"/> itself when it has no underlying node or when neither trivia list changed.
    /// Otherwise, a copy of the token with the changed leading and/or trailing trivia substituted.
    /// </returns>
    public virtual SyntaxToken VisitToken(SyntaxToken token)
    {
        // PERF: This is a hot method, so it has been written to minimize the following:
        // 1. Virtual method calls
        // 2. Copying of structs
        // 3. Repeated null checks

        // PERF: Avoid testing node for null more than once
        var node = token.Node;
        if (node == null)
        {
            return token;
        }

        // PERF: Make one virtual method call each to get the leading and trailing trivia
        var leadingTrivia = node.GetLeadingTriviaCore();
        var trailingTrivia = node.GetTrailingTriviaCore();

        // Trivia is either null or a non-empty list (there's no such thing as an empty green list)
        Debug.Assert(leadingTrivia == null || !leadingTrivia.IsList || leadingTrivia.SlotCount > 0);
        Debug.Assert(trailingTrivia == null || !trailingTrivia.IsList || trailingTrivia.SlotCount > 0);

        if (leadingTrivia != null)
        {
            // PERF: Expand token.LeadingTrivia when node is not null.
            var leading = this.VisitList(new SyntaxTriviaList(token, leadingTrivia));

            if (trailingTrivia != null)
            {
                // Both leading and trailing trivia

                // PERF: Expand token.TrailingTrivia when node is not null and leadingTrivia is not null.
                // Also avoid node.Width because it makes a virtual call to GetText. Instead use node.FullWidth - trailingTrivia.FullWidth.
                // The trailing list's index starts after the leading trivia: SlotCount pieces for a list, else 1.
                var index = leadingTrivia.IsList ? leadingTrivia.SlotCount : 1;
                var trailing = this.VisitList(new SyntaxTriviaList(token, trailingTrivia, token.Position + node.FullWidth - trailingTrivia.FullWidth, index));

                // A change is detected by comparing the visited list's underlying node against the original.
                if (leading.Node != leadingTrivia)
                {
                    token = token.WithLeadingTrivia(leading);
                }

                return trailing.Node != trailingTrivia ? token.WithTrailingTrivia(trailing) : token;
            }
            else
            {
                // Leading trivia only
                return leading.Node != leadingTrivia ? token.WithLeadingTrivia(leading) : token;
            }
        }
        else if (trailingTrivia != null)
        {
            // Trailing trivia only
            // PERF: Expand token.TrailingTrivia when node is not null and leading is null.
            // Also avoid node.Width because it makes a virtual call to GetText. Instead use node.FullWidth - trailingTrivia.FullWidth.
            var trailing = this.VisitList(new SyntaxTriviaList(token, trailingTrivia, token.Position + node.FullWidth - trailingTrivia.FullWidth, index: 0));
            return trailing.Node != trailingTrivia ? token.WithTrailingTrivia(trailing) : token;
        }
        else
        {
            // No trivia
            return token;
        }
    }

    /// <summary>
    /// Visits a single trivia. When <see cref="VisitIntoStructuredTrivia"/> is <see langword="true"/> and the
    /// trivia has structure, the structure node is visited.
    /// </summary>
    /// <param name="trivia">The trivia to visit.</param>
    /// <returns>
    /// <list type="bullet">
    /// <item>New trivia wrapping the replacement when the structure was replaced by another node.</item>
    /// <item><see langword="default"/> when the structure was visited to <see langword="null"/>.
    /// <see cref="VisitList(SyntaxTriviaList)"/> drops trivia whose kind is <c>None</c>.</item>
    /// <item>Otherwise, <paramref name="trivia"/> unchanged.</item>
    /// </list>
    /// </returns>
    public virtual SyntaxTrivia VisitTrivia(SyntaxTrivia trivia)
    {
        if (this.VisitIntoStructuredTrivia && trivia.HasStructure)
        {
            var structure = (CSharpSyntaxNode)trivia.GetStructure()!;
            var newStructure = (StructuredTriviaSyntax?)this.Visit(structure);
            if (newStructure != structure)
            {
                if (newStructure != null)
                {
                    return SyntaxFactory.Trivia(newStructure);
                }
                else
                {
                    return default;
                }
            }
        }

        return trivia;
    }

    /// <summary>
    /// Visits each element of <paramref name="list"/> through <see cref="VisitListElement{TNode}(TNode)"/>.
    /// </summary>
    /// <typeparam name="TNode">The element node type.</typeparam>
    /// <param name="list">The list to visit.</param>
    /// <returns>
    /// <paramref name="list"/> itself when every element visited to the same instance. Otherwise, a new list of
    /// the visited elements, omitting those visited to <see langword="null"/> or to a node of kind <c>None</c>.
    /// </returns>
    public virtual SyntaxList<TNode> VisitList<TNode>(SyntaxList<TNode> list) where TNode : SyntaxNode
    {
        // Stays default (IsNull) until the first element changes. At that point it is seeded with the
        // unchanged elements before index i.
        SyntaxListBuilder<TNode> alternate = default;
        for (int i = 0, n = list.Count; i < n; i++)
        {
            var item = list[i];
            var visited = this.VisitListElement(item);
            if (item != visited && alternate.IsNull)
            {
                alternate = new SyntaxListBuilder<TNode>(n);
                alternate.AddRange(list, 0, i);
            }

            if (!alternate.IsNull && visited != null && !visited.IsKind(SyntaxKind.None))
            {
                alternate.Add(visited);
            }
        }

        if (!alternate.IsNull)
        {
            return alternate.ToList();
        }

        return list;
    }

    /// <summary>
    /// Visits one element of a node list by calling <see cref="Visit(SyntaxNode)"/>.
    /// </summary>
    /// <typeparam name="TNode">The element node type.</typeparam>
    /// <param name="node">The element to visit; may be <see langword="null"/>.</param>
    /// <returns>
    /// The visited node cast to <typeparamref name="TNode"/>, or <see langword="null"/> to remove the element.
    /// </returns>
    public virtual TNode? VisitListElement<TNode>(TNode? node) where TNode : SyntaxNode
    {
        return (TNode?)this.Visit(node);
    }

    /// <summary>
    /// Visits each element of a separated list through <see cref="VisitListElement{TNode}(TNode)"/> and each
    /// separator through <see cref="VisitListSeparator(SyntaxToken)"/>.
    /// </summary>
    /// <typeparam name="TNode">The element node type.</typeparam>
    /// <param name="list">The separated list to visit.</param>
    /// <returns>
    /// <paramref name="list"/> itself when no element or separator changed. Otherwise, a new separated list of
    /// the visited items.
    /// </returns>
    /// <remarks>
    /// Only an element with no following separator (the one at index <c>SeparatorCount</c>, if present) may be
    /// removed by visiting it to <see langword="null"/>. The separator before it is kept.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// An element that is followed by a separator was visited to <see langword="null"/>, or the separator after a
    /// kept element was visited to a token whose <c>RawKind</c> is 0.
    /// </exception>
    public virtual SeparatedSyntaxList<TNode> VisitList<TNode>(SeparatedSyntaxList<TNode> list) where TNode : SyntaxNode
    {
        var count = list.Count;
        var sepCount = list.SeparatorCount;

        // Stays default (IsNull) until the first element or separator changes. At that point
        // AddRange(list, i) seeds it with the unchanged items before index i (assumed to include their separators).
        SeparatedSyntaxListBuilder<TNode> alternate = default;

        // Elements that are followed by a separator.
        int i = 0;
        for (; i < sepCount; i++)
        {
            var node = list[i];
            var visitedNode = this.VisitListElement(node);

            var separator = list.GetSeparator(i);
            var visitedSeparator = this.VisitListSeparator(separator);

            if (alternate.IsNull)
            {
                if (node != visitedNode || separator != visitedSeparator)
                {
                    alternate = new SeparatedSyntaxListBuilder<TNode>(count);
                    alternate.AddRange(list, i);
                }
            }

            if (!alternate.IsNull)
            {
                if (visitedNode == null)
                {
                    // Removing an element that is followed by a separator is not supported.
                    throw new InvalidOperationException(CodeAnalysisResources.ElementIsExpected);
                }

                alternate.Add(visitedNode);

                if (visitedSeparator.RawKind == 0)
                {
                    throw new InvalidOperationException(CodeAnalysisResources.SeparatorIsExpected);
                }

                alternate.AddSeparator(visitedSeparator);
            }
        }

        // The element at index sepCount, if any: it has no following separator.
        if (i < count)
        {
            var node = list[i];
            var visitedNode = this.VisitListElement(node);

            if (alternate.IsNull)
            {
                if (node != visitedNode)
                {
                    alternate = new SeparatedSyntaxListBuilder<TNode>(count);
                    alternate.AddRange(list, i);
                }
            }

            // Unlike the elements above, this one may be removed by visiting it to null.
            if (!alternate.IsNull && visitedNode != null)
            {
                alternate.Add(visitedNode);
            }
        }

        if (!alternate.IsNull)
        {
            return alternate.ToList();
        }

        return list;
    }

    /// <summary>
    /// Visits one separator of a separated list by calling <see cref="VisitToken(SyntaxToken)"/>.
    /// </summary>
    /// <param name="separator">The separator token to visit.</param>
    /// <returns>The visited separator token.</returns>
    public virtual SyntaxToken VisitListSeparator(SyntaxToken separator)
    {
        return this.VisitToken(separator);
    }

    /// <summary>
    /// Visits each token of <paramref name="list"/> through <see cref="VisitToken(SyntaxToken)"/>.
    /// </summary>
    /// <param name="list">The token list to visit.</param>
    /// <returns>
    /// <paramref name="list"/> itself when every token visited to an equal token. Otherwise, a new list of the
    /// visited tokens, omitting those whose kind is <c>None</c>.
    /// </returns>
    public virtual SyntaxTokenList VisitList(SyntaxTokenList list)
    {
        // Created only when the first token changes, seeded with the tokens before it.
        SyntaxTokenListBuilder? alternate = null;
        var count = list.Count;
        var index = -1;

        foreach (var item in list)
        {
            index++;
            var visited = this.VisitToken(item);
            if (item != visited && alternate == null)
            {
                alternate = new SyntaxTokenListBuilder(count);
                alternate.Add(list, 0, index);
            }

            if (alternate != null && visited.Kind() != SyntaxKind.None) //skip the null check since SyntaxToken is a value type
            {
                alternate.Add(visited);
            }
        }

        if (alternate != null)
        {
            return alternate.ToList();
        }

        return list;
    }

    /// <summary>
    /// Visits each trivia of <paramref name="list"/> through <see cref="VisitListElement(SyntaxTrivia)"/>.
    /// </summary>
    /// <param name="list">The trivia list to visit.</param>
    /// <returns>
    /// <paramref name="list"/> itself when it is empty or every trivia visited to an equal trivia. Otherwise, a new
    /// list of the visited trivia, omitting those whose kind is <c>None</c>.
    /// </returns>
    public virtual SyntaxTriviaList VisitList(SyntaxTriviaList list)
    {
        var count = list.Count;
        if (count != 0)
        {
            // Created only when the first trivia changes, seeded with the trivia before it.
            SyntaxTriviaListBuilder? alternate = null;
            var index = -1;

            foreach (var item in list)
            {
                index++;
                var visited = this.VisitListElement(item);

                //skip the null check since SyntaxTrivia is a value type
                if (visited != item && alternate == null)
                {
                    alternate = new SyntaxTriviaListBuilder(count);
                    alternate.Add(list, 0, index);
                }

                if (alternate != null && visited.Kind() != SyntaxKind.None)
                {
                    alternate.Add(visited);
                }
            }

            if (alternate != null)
            {
                return alternate.ToList();
            }
        }

        return list;
    }

    /// <summary>
    /// Visits one element of a trivia list by calling <see cref="VisitTrivia(SyntaxTrivia)"/>.
    /// </summary>
    /// <param name="element">The trivia to visit.</param>
    /// <returns>
    /// The visited trivia. A result whose kind is <c>None</c> is dropped by <see cref="VisitList(SyntaxTriviaList)"/>.
    /// </returns>
    public virtual SyntaxTrivia VisitListElement(SyntaxTrivia element)
    {
        return this.VisitTrivia(element);
    }
}
}
|