|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Threading;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
using TOutput = System.ValueTuple<System.Collections.Generic.IEnumerable<Microsoft.CodeAnalysis.GeneratedSourceText>, System.Collections.Generic.IEnumerable<Microsoft.CodeAnalysis.Diagnostic>>;
namespace Microsoft.CodeAnalysis
{
internal sealed class SourceOutputNode<TInput> : IIncrementalGeneratorOutputNode, IIncrementalGeneratorNode<TOutput>
{
    // Full type name of TOutput; only used as a label when logging tables.
    private static readonly string? s_tableType = typeof(TOutput).FullName;

    private readonly IIncrementalGeneratorNode<TInput> _source;
    private readonly Action<SourceProductionContext, TInput, CancellationToken> _action;
    private readonly string _sourceExtension;

    /// <summary>
    /// Creates an output node that runs <paramref name="action"/> over the entries of <paramref name="source"/>.
    /// </summary>
    /// <param name="source">Upstream node whose entries are fed to <paramref name="action"/>.</param>
    /// <param name="action">Callback invoked with a <see cref="SourceProductionContext"/>, an input item and the cancellation token.</param>
    /// <param name="outputKind">Either Source or Implementation; any other value trips a debug assertion.</param>
    /// <param name="sourceExtension">Extension handed to the <c>AdditionalSourcesCollection</c> built for each callback run.</param>
    public SourceOutputNode(IIncrementalGeneratorNode<TInput> source, Action<SourceProductionContext, TInput, CancellationToken> action, IncrementalGeneratorOutputKind outputKind, string sourceExtension)
    {
        Debug.Assert(outputKind == IncrementalGeneratorOutputKind.Source || outputKind == IncrementalGeneratorOutputKind.Implementation);

        _source = source;
        _action = action;
        Kind = outputKind;
        _sourceExtension = sourceExtension;
    }

    /// <summary>The output kind supplied at construction (Source or Implementation).</summary>
    public IncrementalGeneratorOutputKind Kind { get; }

    /// <summary>
    /// Builds this node's state table for the current run from the latest table of the upstream node.
    /// </summary>
    /// <remarks>
    /// When the upstream table is cached and a previous table exists, the previous table is returned,
    /// refreshed with step information if the driver tracks incremental steps. Otherwise each upstream
    /// entry is handled individually:
    /// removed entries are removed;
    /// cached entries reuse their prior outputs when the builder allows it;
    /// every other entry runs the callback and records the sources and diagnostics it produced.
    /// </remarks>
    public NodeStateTable<TOutput> UpdateStateTable(DriverStateTable.Builder graphState, NodeStateTable<TOutput>? previousTable, CancellationToken cancellationToken)
    {
        string stepName = Kind switch
        {
            IncrementalGeneratorOutputKind.Source => WellKnownGeneratorOutputs.SourceOutput,
            _ => WellKnownGeneratorOutputs.ImplementationSourceOutput,
        };

        var inputTable = graphState.GetLatestStateTableForNode(_source);

        // Fast path: nothing upstream changed, so the previous output table is still valid.
        if (inputTable.IsCached && previousTable is not null)
        {
            this.LogTables(stepName, s_tableType, previousTable, previousTable, inputTable);
            return graphState.DriverState.TrackIncrementalSteps
                ? previousTable.CreateCachedTableWithUpdatedSteps(inputTable, stepName, equalityComparer: null)
                : previousTable;
        }

        var builder = graphState.CreateTableBuilder(previousTable, stepName, equalityComparer: null);
        foreach (var entry in inputTable)
        {
            // Step inputs are only materialized when the builder records incremental steps.
            var inputs = builder.TrackIncrementalSteps ? ImmutableArray.Create((entry.Step!, entry.OutputIndex)) : default;

            if (entry.State == EntryState.Removed)
            {
                builder.TryRemoveEntries(TimeSpan.Zero, inputs);
                continue;
            }

            if (entry.State == EntryState.Cached && builder.TryUseCachedEntries(TimeSpan.Zero, inputs))
            {
                continue;
            }

            // Added or modified entries, and cached entries whose prior outputs could not be reused:
            // run the callback and capture what it reports.
            var generatedSources = new AdditionalSourcesCollection(_sourceExtension);
            var reportedDiagnostics = DiagnosticBag.GetInstance();
            var productionContext = new SourceProductionContext(generatedSources, reportedDiagnostics, graphState.Compilation, cancellationToken);
            try
            {
                var timer = SharedStopwatch.StartNew();
                _action(productionContext, entry.Item, cancellationToken);

                var output = (generatedSources.ToImmutable(), reportedDiagnostics.ToReadOnly());

                // Modified inputs first try to update their existing entry; anything else becomes a new entry.
                bool replacedExisting = entry.State == EntryState.Modified
                    && builder.TryModifyEntry(output, timer.Elapsed, inputs, entry.State);
                if (!replacedExisting)
                {
                    builder.AddEntry(output, EntryState.Added, timer.Elapsed, inputs, EntryState.Added);
                }
            }
            finally
            {
                // Both collections are pooled/builder objects and are released whether or not the callback throws.
                generatedSources.Free();
                reportedDiagnostics.Free();
            }
        }

        var result = builder.ToImmutableAndFree();
        this.LogTables(stepName, s_tableType, previousTable, result, inputTable);
        return result;
    }

    // Output nodes are terminal: none of the pipeline-composition operations are supported.
    IIncrementalGeneratorNode<TOutput> IIncrementalGeneratorNode<TOutput>.WithComparer(IEqualityComparer<TOutput> comparer) => throw ExceptionUtilities.Unreachable();

    public IIncrementalGeneratorNode<(IEnumerable<GeneratedSourceText>, IEnumerable<Diagnostic>)> WithTrackingName(string name) => throw ExceptionUtilities.Unreachable();

    void IIncrementalGeneratorNode<TOutput>.RegisterOutput(IIncrementalGeneratorOutputNode output) => throw ExceptionUtilities.Unreachable();

    /// <summary>
    /// Copies the sources and diagnostics of every non-removed entry in this node's latest table into
    /// <paramref name="context"/>, then records executed steps if the run-state builder is recording them.
    /// </summary>
    /// <exception cref="UserFunctionException">
    /// Thrown when adding a generated source to <c>context.Sources</c> raises an <see cref="ArgumentException"/>.
    /// </exception>
    public void AppendOutputs(IncrementalExecutionContext context, CancellationToken cancellationToken)
    {
        Debug.Assert(context.TableBuilder is not null);
        var ownTable = context.TableBuilder.GetLatestStateTableForNode(this);

        foreach (var ((generatedSources, reportedDiagnostics), state, _, _) in ownTable)
        {
            if (state == EntryState.Removed)
            {
                continue;
            }

            foreach (var sourceText in generatedSources)
            {
                try
                {
                    context.Sources.Add(sourceText.HintName, sourceText.Text);
                }
                catch (ArgumentException e)
                {
                    // NOTE(review): presumably this attributes bad hint names (e.g. duplicates) to the generator — confirm.
                    throw new UserFunctionException(e);
                }
            }

            context.Diagnostics.AddRange(reportedDiagnostics);
        }

        if (context.GeneratorRunStateBuilder.RecordingExecutedSteps)
        {
            context.GeneratorRunStateBuilder.RecordStepsFromOutputNodeUpdate(ownTable);
        }
    }
}
}
|