File: System\Linq\Concat.cs
Web Access
Project: src\src\libraries\System.Linq.AsyncEnumerable\src\System.Linq.AsyncEnumerable.csproj (System.Linq.AsyncEnumerable)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading.Tasks;
 
namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>Concatenates two sequences.</summary>
        /// <typeparam name="TSource">The type of the elements of the input sequences.</typeparam>
        /// <param name="first">The first sequence to concatenate.</param>
        /// <param name="second">The sequence to concatenate to the first sequence.</param>
        /// <returns>An <see cref="IAsyncEnumerable{T}"/> that contains the concatenated elements of the two input sequences.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> is <see langword="null"/>.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="second"/> is <see langword="null"/>.</exception>
        public static IAsyncEnumerable<TSource> Concat<TSource>(
            this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
        {
            ArgumentNullException.ThrowIfNull(first);
            ArgumentNullException.ThrowIfNull(second);
 
            if (first.IsKnownEmpty())
            {
                return second;
            }
 
            if (second.IsKnownEmpty())
            {
                return first;
            }
 
            return first is ConcatAsyncIterator<TSource> firstConcat
                ? firstConcat.Concat(second)
                : new Concat2AsyncIterator<TSource>(first, second);
        }
 
        /// <summary>
        /// Represents the concatenation of two <see cref="IAsyncEnumerable{TSource}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private sealed class Concat2AsyncIterator<TSource> : ConcatAsyncIterator<TSource>
        {
            /// <summary>
            /// The first source to concatenate.
            /// </summary>
            internal readonly IAsyncEnumerable<TSource> _first;
 
            /// <summary>
            /// The second source to concatenate.
            /// </summary>
            internal readonly IAsyncEnumerable<TSource> _second;
 
            /// <summary>
            /// Initializes a new instance of the <see cref="Concat2AsyncIterator{TSource}"/> class.
            /// </summary>
            /// <param name="first">The first source to concatenate.</param>
            /// <param name="second">The second source to concatenate.</param>
            internal Concat2AsyncIterator(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
            {
                Debug.Assert(first is not null);
                Debug.Assert(second is not null);
 
                _first = first;
                _second = second;
            }
 
            private protected override AsyncIterator<TSource> Clone() => new Concat2AsyncIterator<TSource>(_first, _second);
 
            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
            {
                return new ConcatNAsyncIterator<TSource>(this, next, 2);
            }
 
            internal override IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index)
            {
                Debug.Assert(index >= 0);
 
                return index switch
                {
                    0 => _first,
                    1 => _second,
                    _ => null,
                };
            }
        }
 
        /// <summary>
        /// Represents the concatenation of three or more <see cref="IAsyncEnumerable{TSource}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        /// <remarks>
        /// To handle chains of >= 3 sources, we chain the <see cref="Concat"/> iterators together and allow
        /// <see cref="GetAsyncEnumerable"/> to fetch enumerables from the previous sources.  This means that rather
        /// than each MoveNextAsync and Current calls having to traverse all of the previous
        /// sources, we only have to traverse all of the previous sources once per chained enumerable.  An alternative
        /// would be to use an array to store all of the enumerables, but this has a much better memory profile and
        /// without much additional run-time cost.
        /// </remarks>
        private sealed class ConcatNAsyncIterator<TSource> : ConcatAsyncIterator<TSource>
        {
            /// <summary>
            /// The linked list of previous sources.
            /// </summary>
            private readonly ConcatAsyncIterator<TSource> _tail;
 
            /// <summary>
            /// The source associated with this iterator.
            /// </summary>
            private readonly IAsyncEnumerable<TSource> _head;
 
            /// <summary>
            /// The logical index associated with this iterator.
            /// </summary>
            private readonly int _headIndex;
 
            /// <summary>
            /// Initializes a new instance of the <see cref="ConcatNAsyncIterator{TSource}"/> class.
            /// </summary>
            /// <param name="tail">The linked list of previous sources.</param>
            /// <param name="head">The source associated with this iterator.</param>
            /// <param name="headIndex">The logical index associated with this iterator.</param>
            internal ConcatNAsyncIterator(ConcatAsyncIterator<TSource> tail, IAsyncEnumerable<TSource> head, int headIndex)
            {
                Debug.Assert(tail is not null);
                Debug.Assert(head is not null);
                Debug.Assert(headIndex >= 2);
 
                _tail = tail;
                _head = head;
                _headIndex = headIndex;
            }
 
            private ConcatNAsyncIterator<TSource>? PreviousN => _tail as ConcatNAsyncIterator<TSource>;
 
            private protected override AsyncIterator<TSource> Clone() => new ConcatNAsyncIterator<TSource>(_tail, _head, _headIndex);
 
            internal override ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next)
            {
                if (_headIndex == int.MaxValue - 2)
                {
                    // In the unlikely case of this many concatenations, if we produced a ConcatNAsyncIterator
                    // with int.MaxValue then state would overflow before it matched its index.
                    // So we use the naive approach of just having a left and right sequence.
                    return new Concat2AsyncIterator<TSource>(this, next);
                }
 
                return new ConcatNAsyncIterator<TSource>(this, next, _headIndex + 1);
            }
 
            internal override IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index)
            {
                Debug.Assert(index >= 0);
 
                if (index > _headIndex)
                {
                    return null;
                }
 
                ConcatNAsyncIterator<TSource>? node, previousN = this;
                do
                {
                    node = previousN;
                    if (index == node._headIndex)
                    {
                        return node._head;
                    }
                }
                while ((previousN = node.PreviousN) is not null);
 
                Debug.Assert(index == 0 || index == 1);
                Debug.Assert(node._tail is Concat2AsyncIterator<TSource>);
                return node._tail.GetAsyncEnumerable(index);
            }
        }
 
        /// <summary>
        /// Represents the concatenation of two or more <see cref="IAsyncEnumerable{TSource}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        private abstract class ConcatAsyncIterator<TSource> : AsyncIterator<TSource>
        {
            /// <summary>
            /// The enumerator of the current source, if <see cref="MoveNextAsync"/> has been called.
            /// </summary>
            private IAsyncEnumerator<TSource>? _enumerator;
 
            public override async ValueTask DisposeAsync()
            {
                if (_enumerator is not null)
                {
                    await _enumerator.DisposeAsync().ConfigureAwait(false);
                    _enumerator = null;
                }
 
                await base.DisposeAsync().ConfigureAwait(false);
            }
 
            /// <summary>
            /// Gets the enumerable at a logical index in this iterator.
            /// If the index is equal to the number of enumerables this iterator holds, <c>null</c> is returned.
            /// </summary>
            /// <param name="index">The logical index.</param>
            internal abstract IAsyncEnumerable<TSource>? GetAsyncEnumerable(int index);
 
            /// <summary>
            /// Creates a new iterator that concatenates this iterator with an enumerable.
            /// </summary>
            /// <param name="next">The next enumerable.</param>
            internal abstract ConcatAsyncIterator<TSource> Concat(IAsyncEnumerable<TSource> next);
 
            public override async ValueTask<bool> MoveNextAsync()
            {
                if (_state == 1)
                {
                    _enumerator = GetAsyncEnumerable(0)!.GetAsyncEnumerator(_cancellationToken);
                    _state = 2;
                }
 
                if (_state > 1)
                {
                    while (true)
                    {
                        Debug.Assert(_enumerator is not null);
                        if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
                        {
                            _current = _enumerator.Current;
                            return true;
                        }
 
                        IAsyncEnumerable<TSource>? next = GetAsyncEnumerable(_state++ - 1);
                        if (next is not null)
                        {
                            await _enumerator.DisposeAsync().ConfigureAwait(false);
                            _enumerator = next.GetAsyncEnumerator(_cancellationToken);
                            continue;
                        }
 
                        await DisposeAsync().ConfigureAwait(false);
                        break;
                    }
                }
 
                return false;
            }
        }
    }
}