File: SourceGeneration\Nodes\InputNode.cs
Web Access
Project: src\src\Compilers\Core\Portable\Microsoft.CodeAnalysis.csproj (Microsoft.CodeAnalysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Threading;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis
{
    /// <summary>
    /// Input nodes are the 'root' nodes in the graph, and get their values from the inputs of the driver state table.
    /// </summary>
    /// <remarks>
    /// On each update the node diffs the freshly produced input items against the previous table. Items that are still
    /// present reuse their cached entry. Items that disappeared are removed, except when the item count is unchanged:
    /// then they are positionally replaced by a not-yet-matched new item. Any remaining new items are appended as added
    /// entries.
    /// </remarks>
    /// <typeparam name="T">The type of the input</typeparam>
    internal sealed class InputNode<T> : IIncrementalGeneratorNode<T>
    {
        // Full name of T; passed along when logging tables in LogTables.
        private static readonly string? s_tableType = typeof(T).FullName;

        // Produces the current set of input items from the driver state.
        private readonly Func<DriverStateTable.Builder, ImmutableArray<T>> _getInput;

        // Invoked by RegisterOutput; when no callback was supplied, calling it throws.
        private readonly Action<IIncrementalGeneratorOutputNode> _registerOutput;

        // Comparer used to match the new input items against the items of the previous table.
        private readonly IEqualityComparer<T> _inputComparer;

        // Comparer handed to the table builder and used when an entry is modified in place.
        private readonly IEqualityComparer<T> _comparer;

        // Optional tracking name for the produced table (see WithTrackingName).
        private readonly string? _name;

        /// <summary>
        /// Creates an input node that reads its items with <paramref name="getInput"/>.
        /// </summary>
        /// <param name="getInput">Produces the input items from the driver state.</param>
        /// <param name="inputComparer">Comparer used to match new items against previous ones; defaults to <see cref="EqualityComparer{T}.Default"/>.</param>
        public InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, IEqualityComparer<T>? inputComparer = null)
            : this(getInput, registerOutput: null, inputComparer: inputComparer, comparer: null)
        {
        }

        private InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, Action<IIncrementalGeneratorOutputNode>? registerOutput, IEqualityComparer<T>? inputComparer = null, IEqualityComparer<T>? comparer = null, string? name = null)
        {
            _getInput = getInput;
            _comparer = comparer ?? EqualityComparer<T>.Default;
            _inputComparer = inputComparer ?? EqualityComparer<T>.Default;
            _registerOutput = registerOutput ?? (o => throw ExceptionUtilities.Unreachable());
            _name = name;
        }

        /// <summary>
        /// Builds this node's state table from the current inputs, reusing entries of <paramref name="previousTable"/> where possible.
        /// </summary>
        /// <param name="graphState">The driver state builder the inputs are read from and the new table is created with.</param>
        /// <param name="previousTable">The table produced by the previous update, or <see langword="null"/> if there is none.</param>
        /// <param name="cancellationToken">Not observed by this node.</param>
        /// <returns>The newly built table.</returns>
        public NodeStateTable<T> UpdateStateTable(DriverStateTable.Builder graphState, NodeStateTable<T>? previousTable, CancellationToken cancellationToken)
        {
            var stopwatch = SharedStopwatch.StartNew();
            var inputItems = _getInput(graphState);
            TimeSpan elapsedTime = stopwatch.Elapsed;

            // Create a mutable hashset of the new items we can check against. Items are removed from it as they get matched,
            // so whatever is left at the end is genuinely new.
            var itemsSet = (_inputComparer == EqualityComparer<T>.Default) ? PooledHashSet<T>.GetInstance() : new HashSet<T>(_inputComparer);

#if NET
            itemsSet.EnsureCapacity(inputItems.Length);
#endif

            foreach (var item in inputItems)
            {
                var added = itemsSet.Add(item);
                Debug.Assert(added);
            }

            var tableBuilder = graphState.CreateTableBuilder(previousTable, _name, _comparer);

            // We always have no inputs steps into an InputNode, but we track the difference between "no inputs" (empty collection) and "no step information" (default value)
            var noInputStepsStepInfo = tableBuilder.TrackIncrementalSteps ? ImmutableArray<(IncrementalGeneratorRunStep, int)>.Empty : default;

            if (previousTable is not null)
            {
                // For each item in the previous table, check if it's still in the new items.
                int itemIndex = 0;
                foreach (var (oldItem, _, _, _) in previousTable)
                {
                    if (itemsSet.Remove(oldItem))
                    {
                        // we're iterating the table, so know that it has entries
                        var usedCache = tableBuilder.TryUseCachedEntries(elapsedTime, noInputStepsStepInfo);
                        Debug.Assert(usedCache);
                    }
                    else if (inputItems.Length == previousTable.Count && itemsSet.Remove(inputItems[itemIndex]))
                    {
                        // When the number of items matches the previous iteration, we use a heuristic to mark the input as modified.
                        // This allows us to correctly 'replace' items even when they aren't actually the same. In the case that the
                        // item really isn't modified, but a new item, we still function correctly as we mostly treat them the same,
                        // but will perform an extra comparison that is omitted in the pure 'added' case.
                        //
                        // The positional candidate is only used if it has not been matched yet. If the inputs were reordered,
                        // inputItems[itemIndex] may already have been consumed as a cached entry. Reusing it here would emit that
                        // item twice, while the still-unmatched new item would additionally be appended below.
                        var modified = tableBuilder.TryModifyEntry(inputItems[itemIndex], _comparer, elapsedTime, noInputStepsStepInfo, EntryState.Modified);
                        Debug.Assert(modified);
                    }
                    else
                    {
                        var removed = tableBuilder.TryRemoveEntries(elapsedTime, noInputStepsStepInfo);
                        Debug.Assert(removed);
                    }
                    itemIndex++;
                }
            }

            // any remaining new items are added
            foreach (var newItem in itemsSet)
            {
                tableBuilder.AddEntry(newItem, EntryState.Added, elapsedTime, noInputStepsStepInfo, EntryState.Added);
            }

            var newTable = tableBuilder.ToImmutableAndFree();
            this.LogTables(previousTable, newTable, inputItems);

            // Only the pooled instance goes back to the pool; a custom-comparer HashSet is simply dropped.
            (itemsSet as PooledHashSet<T>)?.Free();

            return newTable;
        }

        /// <summary>Returns a copy of this node that uses <paramref name="comparer"/> for its table entries.</summary>
        public IIncrementalGeneratorNode<T> WithComparer(IEqualityComparer<T> comparer) => new InputNode<T>(_getInput, _registerOutput, _inputComparer, comparer, _name);

        /// <summary>Returns a copy of this node whose table is tracked under <paramref name="name"/>.</summary>
        public IIncrementalGeneratorNode<T> WithTrackingName(string name) => new InputNode<T>(_getInput, _registerOutput, _inputComparer, _comparer, name);

        /// <summary>Returns a copy of this node that forwards <see cref="RegisterOutput"/> calls to <paramref name="registerOutput"/>.</summary>
        public InputNode<T> WithRegisterOutput(Action<IIncrementalGeneratorOutputNode> registerOutput) => new InputNode<T>(_getInput, registerOutput, _inputComparer, _comparer, _name);

        /// <summary>Forwards <paramref name="output"/> to the registered output callback.</summary>
        public void RegisterOutput(IIncrementalGeneratorOutputNode output) => _registerOutput(output);

        /// <summary>
        /// Logs the previous and new tables together with a table built from the raw inputs. This only happens when the
        /// code-analysis event source is enabled.
        /// </summary>
        private void LogTables(NodeStateTable<T>? previousTable, NodeStateTable<T> newTable, ImmutableArray<T> inputs)
        {
            if (!CodeAnalysisEventSource.Log.IsEnabled())
            {
                // don't bother building the dummy table if we're not going to log anyway
                return;
            }

            var tableBuilder = NodeStateTable<T>.Empty.ToBuilder(_name, stepTrackingEnabled: false, tableCapacity: inputs.Length);
            foreach (var input in inputs)
            {
                tableBuilder.AddEntry(input, EntryState.Added, TimeSpan.Zero, stepInputs: default, EntryState.Added);
            }
            var inputTable = tableBuilder.ToImmutableAndFree();

            this.LogTables(_name, s_tableType, previousTable, newTable, inputTable);
        }
    }
}