File: SourceGeneration\Nodes\CombineNode.cs
Web Access
Project: src\src\Compilers\Core\Portable\Microsoft.CodeAnalysis.csproj (Microsoft.CodeAnalysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Threading;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis
{
    /// <summary>
    /// Incremental generator node that pairs every item produced by <c>input1</c> with the single
    /// item produced by <c>input2</c>, yielding <c>(TInput1, TInput2)</c> tuples.
    /// </summary>
    internal sealed class CombineNode<TInput1, TInput2> : IIncrementalGeneratorNode<(TInput1, TInput2)>
    {
        // Full type name of the produced tuple; handed to LogTables for diagnostics.
        private static readonly string? s_tableType = typeof((TInput1, TInput2)).FullName;

        private readonly IIncrementalGeneratorNode<TInput1> _input1;
        private readonly IIncrementalGeneratorNode<TInput2> _input2;
        private readonly IEqualityComparer<(TInput1, TInput2)>? _comparer;
        private readonly string? _name;

        /// <summary>Creates a node combining the outputs of two upstream nodes.</summary>
        /// <param name="input1">Node whose items are each paired with the item from <paramref name="input2"/>.</param>
        /// <param name="input2">Node whose table is read with <c>Single()</c>, i.e. expected to hold one item.</param>
        /// <param name="comparer">Optional comparer forwarded to the table builder and used when modifying entries in place.</param>
        /// <param name="name">Optional tracking name for this node's output table.</param>
        public CombineNode(IIncrementalGeneratorNode<TInput1> input1, IIncrementalGeneratorNode<TInput2> input2, IEqualityComparer<(TInput1, TInput2)>? comparer = null, string? name = null)
        {
            (_input1, _input2, _comparer, _name) = (input1, input2, comparer, name);
        }

        /// <summary>
        /// Produces this node's output table from the latest state tables of both inputs.
        /// </summary>
        /// <param name="graphState">Builder used to fetch input tables and create the output table.</param>
        /// <param name="previousTable">The table produced on the previous run, if any.</param>
        /// <param name="cancellationToken">Not observed by this implementation.</param>
        /// <returns>
        /// <paramref name="previousTable"/> (or a step-tracking copy of it) when both inputs are cached
        /// and a previous table exists; otherwise a freshly built table.
        /// </returns>
        public NodeStateTable<(TInput1, TInput2)> UpdateStateTable(DriverStateTable.Builder graphState, NodeStateTable<(TInput1, TInput2)>? previousTable, CancellationToken cancellationToken)
        {
            var leftTable = graphState.GetLatestStateTableForNode(_input1);
            var rightTable = graphState.GetLatestStateTableForNode(_input2);

            // Nothing upstream changed and a prior result exists: reuse it wholesale.
            if (leftTable.IsCached && rightTable.IsCached && previousTable is not null)
            {
                this.LogTables(_name, s_tableType, previousTable, previousTable, leftTable, rightTable);
                return graphState.DriverState.TrackIncrementalSteps
                    ? RecordStepsForCachedTable(graphState, previousTable, leftTable, rightTable)
                    : previousTable;
            }

            var expectedCount = leftTable.GetTotalEntryItemCount();
            var builder = graphState.CreateTableBuilder(previousTable, _name, _comparer, expectedCount);

            // The right-hand input contributes one item that is appended to every left-hand item.
            bool rightIsCached = rightTable.IsCached;
            var (rightItem, rightStep) = rightTable.Single();

            foreach (var leftEntry in leftTable)
            {
                var timer = SharedStopwatch.StartNew();

                var stepInputs = builder.TrackIncrementalSteps
                    ? ImmutableArray.Create((leftEntry.Step!, leftEntry.OutputIndex), (rightStep!, 0))
                    : default;

                var state = CombineStates(leftEntry.State, rightIsCached);
                var combined = (leftEntry.Item, rightItem);

                // Only a Modified entry with a comparer is offered to TryModifyEntry; anything else,
                // or a failed in-place modification, is appended as a new entry.
                bool modifiedInPlace = state == EntryState.Modified
                    && _comparer is not null
                    && builder.TryModifyEntry(combined, _comparer, timer.Elapsed, stepInputs, state);

                if (!modifiedInPlace)
                {
                    builder.AddEntry(combined, state, timer.Elapsed, stepInputs, state);
                }
            }

            Debug.Assert(builder.Count == expectedCount);

            var result = builder.ToImmutableAndFree();
            this.LogTables(_name, s_tableType, previousTable, result, leftTable, rightTable);
            return result;

            // A cached left entry stays cached only if the right input is cached too; otherwise it
            // becomes Modified. Any non-cached left state is passed through unchanged.
            static EntryState CombineStates(EntryState leftState, bool rightCached)
            {
                if (leftState != EntryState.Cached)
                {
                    return leftState;
                }

                return rightCached ? EntryState.Cached : EntryState.Modified;
            }
        }

        /// <summary>
        /// Rebuilds the cached <paramref name="previousTable"/> through a table builder so that
        /// per-entry run steps are recorded against the current input steps.
        /// </summary>
        private NodeStateTable<(TInput1, TInput2)> RecordStepsForCachedTable(DriverStateTable.Builder graphState, NodeStateTable<(TInput1, TInput2)> previousTable, NodeStateTable<TInput1> input1Table, NodeStateTable<TInput2> input2Table)
        {
            Debug.Assert(input1Table.HasTrackedSteps && input2Table.IsCached);

            var cachedBuilder = graphState.CreateTableBuilder(previousTable, _name, _comparer);
            var (_, rightStep) = input2Table.Single();

            foreach (var leftEntry in input1Table)
            {
                var stepInputs = ImmutableArray.Create((leftEntry.Step!, leftEntry.OutputIndex), (rightStep!, 0));

                // Kept outside Debug.Assert so the call still runs in release builds.
                bool reused = cachedBuilder.TryUseCachedEntries(TimeSpan.Zero, stepInputs);
                Debug.Assert(reused);
            }

            return cachedBuilder.ToImmutableAndFree();
        }

        /// <summary>Returns a copy of this node that uses <paramref name="comparer"/>.</summary>
        public IIncrementalGeneratorNode<(TInput1, TInput2)> WithComparer(IEqualityComparer<(TInput1, TInput2)> comparer)
            => new CombineNode<TInput1, TInput2>(_input1, _input2, comparer: comparer, name: _name);

        /// <summary>Returns a copy of this node that uses <paramref name="name"/> as its tracking name.</summary>
        public IIncrementalGeneratorNode<(TInput1, TInput2)> WithTrackingName(string name)
            => new CombineNode<TInput1, TInput2>(_input1, _input2, comparer: _comparer, name: name);

        /// <summary>Registers <paramref name="output"/> with both upstream inputs.</summary>
        public void RegisterOutput(IIncrementalGeneratorOutputNode output)
        {
            // The two inputs may lead back to different source nodes, so both must see the output.
            _input1.RegisterOutput(output);
            _input2.RegisterOutput(output);
        }
    }
}