File: Utilities\TaskExtensions.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI.Evaluation\Microsoft.Extensions.AI.Evaluation.csproj (Microsoft.Extensions.AI.Evaluation)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
 
namespace Microsoft.Extensions.AI.Evaluation.Utilities;
 
internal static class TaskExtensions
{
    /// <summary>
    /// Invokes each of the supplied <paramref name="functions"/> and streams back the results of the tasks they
    /// return.
    /// </summary>
    /// <typeparam name="T">The type of result produced by each function.</typeparam>
    /// <param name="functions">
    /// The functions to invoke. Each function is passed <paramref name="cancellationToken"/>.
    /// </param>
    /// <param name="preserveOrder">
    /// <see langword="true"/> to yield results in the order of <paramref name="functions"/>;
    /// <see langword="false"/> (the default) to yield results in the order in which the tasks complete.
    /// </param>
    /// <param name="cancellationToken">
    /// A token that is forwarded to every function and checked while the results are streamed.
    /// </param>
    /// <returns>An asynchronous sequence of the results produced by the functions.</returns>
    /// <remarks>
    /// The functions are invoked when enumeration of the returned sequence begins. All of them are invoked before
    /// the first result is awaited, so they run concurrently regardless of <paramref name="preserveOrder"/>.
    /// </remarks>
    internal static IAsyncEnumerable<T> ExecuteConcurrentlyAndStreamResultsAsync<T>(
        this IEnumerable<Func<CancellationToken, Task<T>>> functions,
        bool preserveOrder = false,
        CancellationToken cancellationToken = default)
    {
        // Intentionally lazy: StreamResultsAsync materializes this sequence (invoking every function) as soon as
        // enumeration starts.
        IEnumerable<Task<T>> concurrentTasks = functions.Select(f => f(cancellationToken));
        return concurrentTasks.StreamResultsAsync(preserveOrder, cancellationToken);
    }

    /// <summary>
    /// Invokes each of the supplied <paramref name="functions"/> and streams back the results of the
    /// <see cref="ValueTask{TResult}"/>s they return.
    /// </summary>
    /// <typeparam name="T">The type of result produced by each function.</typeparam>
    /// <param name="functions">
    /// The functions to invoke. Each function is passed <paramref name="cancellationToken"/>.
    /// </param>
    /// <param name="preserveOrder">
    /// <see langword="true"/> to yield results in the order of <paramref name="functions"/>;
    /// <see langword="false"/> (the default) to yield results in the order in which the operations complete.
    /// </param>
    /// <param name="cancellationToken">
    /// A token that is forwarded to every function and checked while the results are streamed.
    /// </param>
    /// <returns>An asynchronous sequence of the results produced by the functions.</returns>
    /// <remarks>
    /// The functions are invoked when enumeration of the returned sequence begins. All of them are invoked before
    /// the first result is awaited, so they run concurrently regardless of <paramref name="preserveOrder"/>.
    /// </remarks>
    internal static IAsyncEnumerable<T> ExecuteConcurrentlyAndStreamResultsAsync<T>(
        this IEnumerable<Func<CancellationToken, ValueTask<T>>> functions,
        bool preserveOrder = false,
        CancellationToken cancellationToken = default)
    {
        // Intentionally lazy: StreamResultsAsync materializes this sequence (invoking every function) as soon as
        // enumeration starts.
        IEnumerable<ValueTask<T>> concurrentTasks = functions.Select(f => f(cancellationToken));
        return concurrentTasks.StreamResultsAsync(preserveOrder, cancellationToken);
    }

    /// <summary>
    /// Streams the results of the supplied <paramref name="concurrentTasks"/>, either in their original order or in
    /// completion order.
    /// </summary>
    /// <typeparam name="T">The type of result produced by each task.</typeparam>
    /// <param name="concurrentTasks">The tasks whose results should be streamed.</param>
    /// <param name="preserveOrder">
    /// <see langword="true"/> to yield results in the order of <paramref name="concurrentTasks"/>;
    /// <see langword="false"/> (the default) to yield results in the order in which the tasks complete.
    /// </param>
    /// <param name="cancellationToken">A token that is checked while the results are streamed.</param>
    /// <returns>An asynchronous sequence of the task results.</returns>
    /// <remarks>
    /// <para>
    /// <paramref name="concurrentTasks"/> is fully enumerated when enumeration of the returned sequence begins.
    /// A lazily-produced sequence (for example, a <c>Select</c> that starts a task per element) therefore has all of
    /// its tasks started before the first result is awaited.
    /// </para>
    /// <para>
    /// Ideally, the <see cref="CancellationToken"/> passed via <paramref name="cancellationToken"/> should also cancel
    /// the tasks supplied via <paramref name="concurrentTasks"/>. In the ordered path, and in the unordered path on
    /// targets earlier than .NET 9, the token is only checked before each result is awaited. It does not interrupt an
    /// await that is already in progress.
    /// </para>
    /// </remarks>
    internal static async IAsyncEnumerable<T> StreamResultsAsync<T>(
        this IEnumerable<Task<T>> concurrentTasks,
        bool preserveOrder = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Materialize up front. Pulling elements one at a time from a lazy projection while awaiting each one would
        // start every task only after its predecessor finished, serializing work that is meant to run concurrently.
        Task<T>[] tasks = concurrentTasks.ToArray();

        if (preserveOrder)
        {
            foreach (Task<T> task in tasks)
            {
                cancellationToken.ThrowIfCancellationRequested();

                yield return await task.ConfigureAwait(false);
            }
        }
        else
        {
#if NET9_0_OR_GREATER
            await foreach (Task<T> task in
                Task.WhenEach(tasks).WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                yield return await task.ConfigureAwait(false);
            }
#else
            // A List (not a HashSet) keeps duplicate entries, so a task supplied twice is yielded twice. This matches
            // the ordered path on every target framework.
            var remaining = new List<Task<T>>(tasks);

            while (remaining.Count is not 0)
            {
                cancellationToken.ThrowIfCancellationRequested();

                Task<T> task = await Task.WhenAny(remaining).ConfigureAwait(false);
                _ = remaining.Remove(task);
                yield return await task.ConfigureAwait(false);
            }
#endif
        }
    }

    /// <summary>
    /// Streams the results of the supplied <paramref name="concurrentTasks"/>, either in their original order or in
    /// completion order.
    /// </summary>
    /// <typeparam name="T">The type of result produced by each operation.</typeparam>
    /// <param name="concurrentTasks">The operations whose results should be streamed. Each is awaited exactly once.</param>
    /// <param name="preserveOrder">
    /// <see langword="true"/> to yield results in the order of <paramref name="concurrentTasks"/>;
    /// <see langword="false"/> (the default) to yield results in the order in which the operations complete.
    /// </param>
    /// <param name="cancellationToken">A token that is checked while the results are streamed.</param>
    /// <returns>An asynchronous sequence of the operation results.</returns>
    /// <remarks>
    /// <paramref name="concurrentTasks"/> is fully enumerated when enumeration of the returned sequence begins.
    /// A lazily-produced sequence therefore has all of its operations started before the first result is awaited.
    /// </remarks>
    internal static async IAsyncEnumerable<T> StreamResultsAsync<T>(
        this IEnumerable<ValueTask<T>> concurrentTasks,
        bool preserveOrder = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (preserveOrder)
        {
            // Materialize first so that every operation has been started before the first one is awaited.
            ValueTask<T>[] tasks = concurrentTasks.ToArray();

            foreach (ValueTask<T> task in tasks)
            {
                cancellationToken.ThrowIfCancellationRequested();

                yield return await task.ConfigureAwait(false);
            }
        }
        else
        {
            // Completion-order streaming needs real Tasks. Each ValueTask is converted exactly once, and the Task
            // overload materializes the projected sequence before waiting on any of them.
            IAsyncEnumerable<T> results =
                StreamResultsAsync(concurrentTasks.Select(t => t.AsTask()), preserveOrder, cancellationToken);

            await foreach (T result in results.ConfigureAwait(false))
            {
                yield return result;
            }
        }
    }
}