File: src\LegacySupport\TaskWaitAsync\TaskExtensions.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.Http.Diagnostics\Microsoft.Extensions.Http.Diagnostics.csproj (Microsoft.Extensions.Http.Diagnostics)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics.CodeAnalysis;
 
namespace System.Threading.Tasks;
 
/// <summary>
/// Provides a set of static methods for <see cref="Task"/>.
/// </summary>
[ExcludeFromCodeCoverage]
internal static class TaskExtensions
{
#pragma warning disable VSTHRD003 // Avoid awaiting foreign Tasks
    /// <summary>
    /// Gets a <see cref="Task{TResult}"/> that will complete when the <paramref name="task"/> completes or when the specified <paramref name="cancellationToken"/> has cancellation requested.
    /// </summary>
    /// <remarks>
    /// When <paramref name="task"/> has already completed, or <paramref name="cancellationToken"/> can never be canceled,
    /// <paramref name="task"/> itself is returned. When cancellation is observed first, the returned task ends in the
    /// canceled state and is associated with <paramref name="cancellationToken"/>. Cancellation only abandons the wait;
    /// this method does nothing to stop <paramref name="task"/> itself.
    /// </remarks>
    /// <typeparam name="TResult">The type of the task result.</typeparam>
    /// <param name="task">The task to wait on for completion.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for a cancellation request.</param>
    /// <returns>The <see cref="Task{TResult}"/> representing the asynchronous wait.</returns>
    public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, CancellationToken cancellationToken)
    {
        // Nothing to race against: either the outcome is already known or the token can never fire.
        if (task.IsCompleted || (!cancellationToken.CanBeCanceled))
        {
            return task;
        }

        // Cancellation was requested before the wait started; no registration is needed.
        if (cancellationToken.IsCancellationRequested)
        {
            return Task.FromCanceled<TResult>(cancellationToken);
        }

        return WaitTaskAsync(task, cancellationToken);
    }

    /// <summary>
    /// Races <paramref name="task"/> against a proxy task that is canceled when <paramref name="cancellationToken"/>
    /// fires, and propagates the outcome of whichever one finishes first.
    /// </summary>
    /// <typeparam name="TResult">The type of the task result.</typeparam>
    /// <param name="task">The task to wait on for completion.</param>
    /// <param name="cancellationToken">The token whose cancellation ends the wait.</param>
    /// <returns>A task that carries the result, exception, or cancellation of the winner of the race.</returns>
    private static async Task<TResult> WaitTaskAsync<TResult>(Task<TResult> task, CancellationToken cancellationToken)
    {
        var tcs = new TaskCompletionSource<TResult>(TaskCreationOptions.RunContinuationsAsynchronously);

        // The token is handed to the callback through the state object so that the lambda can stay static
        // (non-capturing) while still recording which token canceled the wait.
        var state = Tuple.Create(tcs, cancellationToken);

        // The registration is disposed as soon as the race is decided, so a long-lived token does not
        // keep the proxy alive after this method returns.
        using (cancellationToken.Register(
                static s =>
                {
                    var pair = (Tuple<TaskCompletionSource<TResult>, CancellationToken>)s!;

                    // Passing the token makes the resulting exception report it via
                    // OperationCanceledException.CancellationToken, consistent with the
                    // Task.FromCanceled fast path in WaitAsync.
                    _ = pair.Item1.TrySetCanceled(pair.Item2);
                },
                state,
                false))
        {
            var t = await Task.WhenAny(task, tcs.Task).ConfigureAwait(false);

            // Task.WhenAny only yields completed tasks, so GetResult() does not block here; it unwraps the
            // result or rethrows the failure/cancellation of the winner.
#pragma warning disable VSTHRD103 // Call async methods when in an async method
            return t.GetAwaiter().GetResult();
#pragma warning restore VSTHRD103 // Call async methods when in an async method
        }
    }
#pragma warning restore VSTHRD003 // Avoid awaiting foreign Tasks
}