File: SourceGeneration\Nodes\CombineNode.cs
Web Access
Project: src\src\roslyn\src\Compilers\Core\Portable\Microsoft.CodeAnalysis.csproj (Microsoft.CodeAnalysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Threading;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis
{
    internal sealed class CombineNode<TInput1, TInput2> : IIncrementalGeneratorNode<(TInput1, TInput2)>
    {
        private static readonly string? s_tableType = typeof((TInput1, TInput2)).FullName;

        private readonly IIncrementalGeneratorNode<TInput1> _input1;
        private readonly IIncrementalGeneratorNode<TInput2> _input2;
        private readonly IEqualityComparer<(TInput1, TInput2)>? _comparer;
        private readonly string? _name;

        public CombineNode(IIncrementalGeneratorNode<TInput1> input1, IIncrementalGeneratorNode<TInput2> input2, IEqualityComparer<(TInput1, TInput2)>? comparer = null, string? name = null)
        {
            _input1 = input1;
            _input2 = input2;
            _comparer = comparer;
            _name = name;
        }

        public NodeStateTable<(TInput1, TInput2)> UpdateStateTable(DriverStateTable.Builder graphState, NodeStateTable<(TInput1, TInput2)>? previousTable, CancellationToken cancellationToken)
        {
            // get both input tables
            var input1Table = graphState.GetLatestStateTableForNode(_input1);
            var input2Table = graphState.GetLatestStateTableForNode(_input2);

            if (input1Table.IsCached && input2Table.IsCached && previousTable is not null)
            {
                this.LogTables(_name, s_tableType, previousTable, previousTable, input1Table, input2Table);
                if (graphState.DriverState.TrackIncrementalSteps)
                {
                    return RecordStepsForCachedTable(graphState, previousTable, input1Table, input2Table);
                }
                return previousTable;
            }

            var totalEntryItemCount = input1Table.GetTotalEntryItemCount();
            var tableBuilder = graphState.CreateTableBuilder(previousTable, _name, _comparer, totalEntryItemCount);
            // Semantics of a join:
            //
            // When input1[i] is cached:
            //  - cached if input2 is also cached
            //  - modified otherwise
            // State of input1[i] otherwise.

            // get the input2 item
            var isInput2Cached = input2Table.IsCached;
            (TInput2 input2, IncrementalGeneratorRunStep? input2Step) = input2Table.Single();

            // append the input2 item to each item in input1 
            foreach (var entry1 in input1Table)
            {
                var stopwatch = SharedStopwatch.StartNew();

                var stepInputs = tableBuilder.TrackIncrementalSteps ? ImmutableArray.Create((entry1.Step!, entry1.OutputIndex), (input2Step!, 0)) : default;

                var state = (entry1.State, isInput2Cached) switch
                {
                    (EntryState.Cached, true) => EntryState.Cached,
                    (EntryState.Cached, false) => EntryState.Modified,
                    _ => entry1.State
                };

                var entry = (entry1.Item, input2);
                if (state != EntryState.Modified || _comparer is null || !tableBuilder.TryModifyEntry(entry, stopwatch.Elapsed, stepInputs, state))
                {
                    tableBuilder.AddEntry(entry, state, stopwatch.Elapsed, stepInputs, state);
                }
            }

            Debug.Assert(tableBuilder.Count == totalEntryItemCount);

            var newTable = tableBuilder.ToImmutableAndFree();
            this.LogTables(_name, s_tableType, previousTable, newTable, input1Table, input2Table);
            return newTable;
        }

        private NodeStateTable<(TInput1, TInput2)> RecordStepsForCachedTable(DriverStateTable.Builder graphState, NodeStateTable<(TInput1, TInput2)> previousTable, NodeStateTable<TInput1> input1Table, NodeStateTable<TInput2> input2Table)
        {
            Debug.Assert(input1Table.HasTrackedSteps && input2Table.IsCached);
            var builder = graphState.CreateTableBuilder(previousTable, _name, _comparer);
            (_, IncrementalGeneratorRunStep? input2Step) = input2Table.Single();
            foreach (var entry in input1Table)
            {
                var stepInputs = ImmutableArray.Create((entry.Step!, entry.OutputIndex), (input2Step!, 0));

                bool usedCachedEntry = builder.TryUseCachedEntries(TimeSpan.Zero, stepInputs);
                Debug.Assert(usedCachedEntry);
            }
            return builder.ToImmutableAndFree();
        }

        public IIncrementalGeneratorNode<(TInput1, TInput2)> WithComparer(IEqualityComparer<(TInput1, TInput2)> comparer) => new CombineNode<TInput1, TInput2>(_input1, _input2, comparer, _name);

        public IIncrementalGeneratorNode<(TInput1, TInput2)> WithTrackingName(string name) => new CombineNode<TInput1, TInput2>(_input1, _input2, _comparer, name);

        public void RegisterOutput(IIncrementalGeneratorOutputNode output)
        {
            // We have to call register on both branches of the join, as they may chain up to different input nodes
            _input1.RegisterOutput(output);
            _input2.RegisterOutput(output);
        }

    }
}