// File: SourceGeneration\Nodes\InputNode.cs
// Web Access
// Project: src\roslyn\src\Compilers\Core\Portable\Microsoft.CodeAnalysis.csproj (Microsoft.CodeAnalysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Threading;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis
{
    /// <summary>
    /// Input nodes are the 'root' nodes in the graph, and get their values from the inputs of the driver state table
    /// </summary>
    /// <remarks>
    /// All fields are readonly; <see cref="WithComparer"/>, <see cref="WithTrackingName"/> and
    /// <see cref="WithRegisterOutput"/> each return a new node that reuses this node's input callback
    /// and hash set pool.
    /// </remarks>
    /// <typeparam name="T">The type of the input</typeparam>
    internal sealed class InputNode<T> : IIncrementalGeneratorNode<T>
    {
        /// <summary>Full name of <typeparamref name="T"/>, passed along when logging tables.</summary>
        private static readonly string? s_tableType = typeof(T).FullName;

        /// <summary>Callback that produces the current input items from the driver state table builder.</summary>
        private readonly Func<DriverStateTable.Builder, ImmutableArray<T>> _getInput;

        /// <summary>Callback invoked by <see cref="RegisterOutput"/>; throws "unreachable" unless one was supplied.</summary>
        private readonly Action<IIncrementalGeneratorOutputNode> _registerOutput;

        /// <summary>Comparer handed to the table builder; <see langword="null"/> unless set via <see cref="WithComparer"/>.</summary>
        private readonly IEqualityComparer<T>? _comparer;

        /// <summary>
        /// Pool of hash sets bound to a custom input comparer, or <see langword="null"/> when no custom
        /// comparer was supplied (in which case <c>PooledHashSet&lt;T&gt;.GetInstance()</c> is used instead).
        /// </summary>
        private readonly ObjectPool<PooledHashSet<T>>? _hashSetPool;

        /// <summary>Optional name passed to the table builders created by this node.</summary>
        private readonly string? _name;

        /// <summary>
        /// Creates an input node.
        /// </summary>
        /// <param name="getInput">Callback that returns the current input items.</param>
        /// <param name="inputComparer">
        /// Optional comparer used for the hash sets that match new input items against previous ones.
        /// It is not used as the table comparer; that is only set through <see cref="WithComparer"/>.
        /// </param>
        public InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, IEqualityComparer<T>? inputComparer = null)
            : this(getInput, registerOutput: null, CreateHashSetPool(inputComparer), comparer: null)
        {
        }

        /// <summary>
        /// Full constructor used by the public constructor and by the <c>With*</c> methods.
        /// </summary>
        private InputNode(
            Func<DriverStateTable.Builder, ImmutableArray<T>> getInput,
            Action<IIncrementalGeneratorOutputNode>? registerOutput,
            ObjectPool<PooledHashSet<T>>? hashSetPool,
            IEqualityComparer<T>? comparer = null,
            string? name = null)
        {
            _getInput = getInput;
            _comparer = comparer;
            // Registering an output on a node that was never given a registration callback is treated as unreachable.
            _registerOutput = registerOutput ?? (o => throw ExceptionUtilities.Unreachable());
            _hashSetPool = hashSetPool;
            _name = name;
        }

        /// <summary>
        /// Returns a hash set pool bound to <paramref name="inputComparer"/>, or <see langword="null"/> when the
        /// comparer is absent or is the default comparer for <typeparamref name="T"/>.
        /// </summary>
        private static ObjectPool<PooledHashSet<T>>? CreateHashSetPool(IEqualityComparer<T>? inputComparer)
            => inputComparer is null || inputComparer == EqualityComparer<T>.Default
                ? null
                : PooledHashSet<T>.CreatePool(inputComparer);

        /// <summary>
        /// Builds the state table for this node from the current input items.
        /// </summary>
        /// <remarks>
        /// Without a previous table every input item is recorded as added. With a previous table, each
        /// previous item that is still present reuses its cached entry; a previous item that is gone is
        /// either overwritten in place with a new item (when the item count is unchanged) or removed.
        /// Any new items left over afterwards are appended as added. Input items are expected to be
        /// distinct under the hash set's comparer; this is checked with a debug assertion only.
        /// </remarks>
        /// <param name="graphState">Driver state table builder used to fetch inputs and create the table builder.</param>
        /// <param name="previousTable">The table from the previous run, or <see langword="null"/> if there is none.</param>
        /// <param name="cancellationToken">Not observed by this method.</param>
        /// <returns>The new state table.</returns>
        public NodeStateTable<T> UpdateStateTable(DriverStateTable.Builder graphState, NodeStateTable<T>? previousTable, CancellationToken cancellationToken)
        {
            // Only the time spent fetching the inputs is measured; it is attached to every entry below.
            var stopwatch = SharedStopwatch.StartNew();
            var inputItems = _getInput(graphState);
            TimeSpan elapsedTime = stopwatch.Elapsed;

            // create a mutable hashset of the new items we can check against
            var itemsSet = getPooledHashSet(inputItems.Length);
            NodeStateTable<T>.Builder? tableBuilder = null;

            try
            {
                foreach (var item in inputItems)
                {
                    var added = itemsSet.Add(item);
                    Debug.Assert(added);
                }

                tableBuilder = graphState.CreateTableBuilder(previousTable, _name, _comparer);

                // An InputNode never has input steps, but we track the difference between "no inputs" (empty collection) and "no step information" (default value)
                var noInputStepsStepInfo = tableBuilder.TrackIncrementalSteps ? ImmutableArray<(IncrementalGeneratorRunStep, int)>.Empty : default;

                if (previousTable is not null)
                {
                    // When the item count is unchanged, instead of Remove+Add we can modify the removed item
                    // with a newly added item instead. This keeps the table the same size and avoids unnecessary
                    // downstream invalidation. The list of replacement items is computed lazily on first need
                    // so that pure reorders and identical inputs skip the work entirely.
                    var canReplace = inputItems.Length == previousTable.Count;
                    ImmutableArray<T> replacements = default;
                    int replacementIndex = 0;

                    foreach (var (oldItem, _, _, _) in previousTable)
                    {
                        if (itemsSet.Remove(oldItem))
                        {
                            // we're iterating the table, so know that it has entries
                            var usedCache = tableBuilder.TryUseCachedEntries(elapsedTime, noInputStepsStepInfo);
                            Debug.Assert(usedCache);
                        }
                        else if (canReplace)
                        {
                            if (replacements.IsDefault)
                                replacements = getNewInputItems(inputItems, previousTable);

                            // The old item was removed. Use the next pre-computed replacement to modify in place.
                            var replacementItem = replacements[replacementIndex++];
                            // Take the replacement out of the set so it is not added a second time below.
                            var removed = itemsSet.Remove(replacementItem);
                            Debug.Assert(removed);

                            var modified = tableBuilder.TryModifyEntry(replacementItem, elapsedTime, noInputStepsStepInfo, EntryState.Modified);
                            Debug.Assert(modified);
                        }
                        else
                        {
                            var removed = tableBuilder.TryRemoveEntries(elapsedTime, noInputStepsStepInfo);
                            Debug.Assert(removed);
                        }
                    }
                }

                // When the count is unchanged, every new item was consumed as either a cache hit or a
                // replacement above, so itemsSet is empty and this loop is a no-op. Otherwise, any items
                // remaining in itemsSet are genuinely new and need to be added.
                Debug.Assert(previousTable is null || inputItems.Length != previousTable.Count || itemsSet.Count == 0);
                foreach (var newItem in itemsSet)
                {
                    tableBuilder.AddEntry(newItem, EntryState.Added, elapsedTime, noInputStepsStepInfo, EntryState.Added);
                }

                var newTable = tableBuilder.ToImmutableAndFree();
                // The builder has been freed; clear the local so the finally block does not free it again.
                tableBuilder = null;
                this.LogTables(previousTable, newTable, inputItems);

                return newTable;
            }
            finally
            {
                tableBuilder?.Free();
                itemsSet.Free();
            }

            // Rents a hash set from the custom-comparer pool when there is one, otherwise from the
            // default PooledHashSet<T> pool. The caller is responsible for calling Free().
            PooledHashSet<T> getPooledHashSet(int capacity)
            {
                var set = _hashSetPool?.Allocate() ?? PooledHashSet<T>.GetInstance();
#if NET
                set.EnsureCapacity(capacity);
#endif
                return set;
            }

            // Builds an array of new items (present in inputItems but not in the previous
            // table) in input order. These are used to populate a modified replacement slot
            // rather than a pair of add/remove entries.
            ImmutableArray<T> getNewInputItems(ImmutableArray<T> inputs, NodeStateTable<T> previous)
            {
                var previousItemsSet = getPooledHashSet(previous.Count);
                ArrayBuilder<T>? builder = null;

                // Mirror the enclosing method: pooled objects are released in a finally block so they
                // are returned even if enumeration or a comparer throws.
                try
                {
                    foreach (var (item, _, _, _) in previous)
                    {
                        previousItemsSet.Add(item);
                    }

                    builder = ArrayBuilder<T>.GetInstance();
                    foreach (var item in inputs)
                    {
                        if (!previousItemsSet.Contains(item))
                        {
                            builder.Add(item);
                        }
                    }

                    var newItems = builder.ToImmutableAndFree();
                    // Already freed by ToImmutableAndFree; clear the local so it is not freed twice.
                    builder = null;
                    return newItems;
                }
                finally
                {
                    builder?.Free();
                    previousItemsSet.Free();
                }
            }
        }

        /// <summary>Returns a copy of this node that passes <paramref name="comparer"/> to its table builder.</summary>
        public IIncrementalGeneratorNode<T> WithComparer(IEqualityComparer<T> comparer) => new InputNode<T>(_getInput, _registerOutput, _hashSetPool, comparer, _name);

        /// <summary>Returns a copy of this node that uses <paramref name="name"/> as its name.</summary>
        public IIncrementalGeneratorNode<T> WithTrackingName(string name) => new InputNode<T>(_getInput, _registerOutput, _hashSetPool, _comparer, name);

        /// <summary>Returns a copy of this node whose <see cref="RegisterOutput"/> forwards to <paramref name="registerOutput"/>.</summary>
        public InputNode<T> WithRegisterOutput(Action<IIncrementalGeneratorOutputNode> registerOutput) => new InputNode<T>(_getInput, registerOutput, _hashSetPool, _comparer, _name);

        /// <summary>Forwards <paramref name="output"/> to the registration callback held by this node.</summary>
        public void RegisterOutput(IIncrementalGeneratorOutputNode output) => _registerOutput(output);

        /// <summary>
        /// When event-source logging is enabled, wraps <paramref name="inputs"/> in a temporary table and
        /// logs it together with the previous and new tables. Does nothing when logging is disabled.
        /// </summary>
        private void LogTables(NodeStateTable<T>? previousTable, NodeStateTable<T> newTable, ImmutableArray<T> inputs)
        {
            if (!CodeAnalysisEventSource.Log.IsEnabled())
            {
                // don't bother building the dummy table if we're not going to log anyway
                return;
            }

            var tableBuilder = NodeStateTable<T>.Empty.ToBuilder(_name, stepTrackingEnabled: false, tableCapacity: inputs.Length);
            foreach (var input in inputs)
            {
                tableBuilder.AddEntry(input, EntryState.Added, TimeSpan.Zero, stepInputs: default, EntryState.Added);
            }
            var inputTable = tableBuilder.ToImmutableAndFree();

            // NOTE(review): this five-argument LogTables is declared outside this class (presumably an
            // extension method on the node interface) - confirm against its definition.
            this.LogTables(_name, s_tableType, previousTable, newTable, inputTable);
        }
    }
}