|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.PooledObjects;
namespace Microsoft.CodeAnalysis.Shared.Extensions;
internal static class IAsyncEnumerableExtensions
{
internal static class AsyncEnumerable<T>
{
public static readonly IAsyncEnumerable<T> Empty = GetEmptyAsync();
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
private static async IAsyncEnumerable<T> GetEmptyAsync()
{
yield break;
}
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
internal static async IAsyncEnumerable<T> SingletonAsync<T>(T value)
{
yield return value;
}
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
public static async Task<ImmutableArray<T>> ToImmutableArrayAsync<T>(this IAsyncEnumerable<T> values, CancellationToken cancellationToken)
{
using var _ = ArrayBuilder<T>.GetInstance(out var result);
await foreach (var value in values.WithCancellation(cancellationToken).ConfigureAwait(false))
result.Add(value);
return result.ToImmutableAndClear();
}
/// <summary>
/// Takes an array of <see cref="IAsyncEnumerable{T}"/>s and produces a single resultant <see
/// cref="IAsyncEnumerable{T}"/> with all their values merged together. Absolutely no ordering guarantee is
/// provided. It will be expected that the individual values from distinct enumerables will be interleaved
/// together.
/// </summary>
/// <remarks>This helper is useful when doign parallel processing of work where each job returns an <see
/// cref="IAsyncEnumerable{T}"/>, but one final stream is desired as the result.</remarks>
public static IAsyncEnumerable<T> MergeAsync<T>(this ImmutableArray<IAsyncEnumerable<T>> streams, CancellationToken cancellationToken)
{
// Code provided by Stephen Toub, but heavily modified after that.
// 1024 chosen as a way to ensure we don't necessarily create a huge unbounded channel, while also making it
// so that we're unlikely to throttle on any stream unless there is truly a huge amount of results in it.
var channel = Channel.CreateBounded<T>(1024);
var tasks = new Task[streams.Length];
for (var i = 0; i < streams.Length; i++)
tasks[i] = ProcessAsync(streams[i], channel.Writer, cancellationToken);
// Complete the channel writer with the result of all the tasks. If nothing failed, t.Exception will be
// null and this will complete successfully. If anything failed, the exception will propagate out.
//
// Note: passing CancellationToken.None here is intentional/correct. We must complete all the channels to
// allow reading to complete as well.
Task.WhenAll(tasks).CompletesChannel(channel);
return channel.Reader.ReadAllAsync(cancellationToken);
static async Task ProcessAsync(IAsyncEnumerable<T> stream, ChannelWriter<T> writer, CancellationToken cancellationToken)
{
await foreach (var value in stream)
await writer.WriteAsync(value, cancellationToken).ConfigureAwait(false);
}
}
public static async IAsyncEnumerable<T> ReadAllAsync<T>(
this ChannelReader<T> reader, [EnumeratorCancellation] CancellationToken cancellationToken)
{
while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
while (reader.TryRead(out var item))
yield return item;
}
}
/// <summary>
/// Runs after task completes in any fashion (success, cancellation, faulting) and ensures the channel writer is
/// always completed. If the task faults then the exception from that task will be used to complete the channel
/// </summary>
public static void CompletesChannel<T>(this Task task, Channel<T> channel)
{
// Note: using `Complete(task.Exception)` is always fine. Exception is only produced in the case of
// faulting. it is null otherwise.
task.ContinueWith(
static (task, channel) => ((Channel<T>)channel!).Writer.Complete(task.Exception),
channel, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default);
}
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
#pragma warning disable VSTHRD200 // Use "Async" suffix for async methods
public static async IAsyncEnumerable<TSource> AsAsyncEnumerable<TSource>(this IEnumerable<TSource> source)
#pragma warning restore VSTHRD200 // Use "Async" suffix for async methods
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
{
foreach (var item in source)
yield return item;
}
}
|