File: Infrastructure\TaskGenericsUtil.cs
Web Access
Project: src\src\JSInterop\Microsoft.JSInterop\src\Microsoft.JSInterop.csproj (Microsoft.JSInterop)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Concurrent;
using System.Globalization;
 
namespace Microsoft.JSInterop.Infrastructure;
 
/// <summary>
/// Helpers for working with <see cref="Task"/>, <see cref="Task{TResult}"/> and
/// <see cref="TaskCompletionSource{TResult}"/> instances whose result type is only
/// known at runtime (the instances are passed around as <see cref="object"/>/<see cref="Task"/>).
/// </summary>
internal static class TaskGenericsUtil
{
    // Caches keyed by the concrete runtime type so that the reflection-based construction
    // of the strongly-typed accessor is not repeated on every call.
    private static readonly ConcurrentDictionary<Type, ITaskResultGetter> _cachedResultGetters = new();

    private static readonly ConcurrentDictionary<Type, ITcsResultSetter> _cachedResultSetters = new();

    /// <summary>
    /// Completes the supplied <see cref="TaskCompletionSource{TResult}"/> with <paramref name="result"/>,
    /// converting the value to the completion source's result type when it is not already of that type.
    /// </summary>
    /// <param name="taskCompletionSource">A <see cref="TaskCompletionSource{TResult}"/> (or derived) instance.</param>
    /// <param name="result">The value to complete the task with.</param>
    public static void SetTaskCompletionSourceResult(object taskCompletionSource, object? result)
        => CreateResultSetter(taskCompletionSource).SetResult(taskCompletionSource, result);

    /// <summary>
    /// Faults the supplied <see cref="TaskCompletionSource{TResult}"/> with <paramref name="exception"/>.
    /// </summary>
    /// <param name="taskCompletionSource">A <see cref="TaskCompletionSource{TResult}"/> (or derived) instance.</param>
    /// <param name="exception">The exception to store on the task.</param>
    public static void SetTaskCompletionSourceException(object taskCompletionSource, Exception exception)
        => CreateResultSetter(taskCompletionSource).SetException(taskCompletionSource, exception);

    /// <summary>
    /// Returns the <c>TResult</c> type argument of the supplied <see cref="TaskCompletionSource{TResult}"/>.
    /// </summary>
    /// <param name="taskCompletionSource">A <see cref="TaskCompletionSource{TResult}"/> (or derived) instance.</param>
    public static Type GetTaskCompletionSourceResultType(object taskCompletionSource)
        => CreateResultSetter(taskCompletionSource).ResultType;

    /// <summary>
    /// Reads the result of <paramref name="task"/>. For a <see cref="Task{TResult}"/> this is its
    /// <see cref="Task{TResult}.Result"/> (boxed); for a plain <see cref="Task"/> this waits on the
    /// task and returns <see langword="null"/>.
    /// </summary>
    /// <param name="task">The task whose result should be read.</param>
    public static object? GetTaskResult(Task task)
    {
        var getter = _cachedResultGetters.GetOrAdd(task.GetType(), static taskInstanceType =>
        {
            var resultType = GetTaskResultType(taskInstanceType);
            if (resultType is null)
            {
                return new VoidTaskResultGetter();
            }

            return (ITaskResultGetter)Activator.CreateInstance(
                typeof(TaskResultGetter<>).MakeGenericType(resultType))!;
        });

        return getter.GetResult(task);
    }

    /// <summary>
    /// Finds the result type of a task type, or <see langword="null"/> for the non-generic <see cref="Task"/>.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="taskType"/> does not derive from <see cref="Task"/>.</exception>
    private static Type? GetTaskResultType(Type taskType)
    {
        // It might be something derived from Task or Task<T>, so we have to scan
        // up the inheritance hierarchy to find the Task or Task<T>
        for (Type? current = taskType; current is not null; current = current.BaseType)
        {
            if (current == typeof(Task))
            {
                return null;
            }

            if (current.IsGenericType && current.GetGenericTypeDefinition() == typeof(Task<>))
            {
                return current.GetGenericArguments()[0];
            }
        }

        // Report the type that was originally passed in, not the root of its hierarchy.
        throw new ArgumentException($"The type '{taskType.FullName}' is not inherited from '{typeof(Task).FullName}'.");
    }

    /// <summary>
    /// Finds the <c>TResult</c> of a completion source type by scanning up the inheritance
    /// hierarchy for <see cref="TaskCompletionSource{TResult}"/>.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="tcsType"/> does not derive from <see cref="TaskCompletionSource{TResult}"/>.
    /// </exception>
    private static Type GetTcsResultType(Type tcsType)
    {
        for (Type? current = tcsType; current is not null; current = current.BaseType)
        {
            if (current.IsGenericType && current.GetGenericTypeDefinition() == typeof(TaskCompletionSource<>))
            {
                return current.GetGenericArguments()[0];
            }
        }

        throw new ArgumentException($"The type '{tcsType.FullName}' is not inherited from '{typeof(TaskCompletionSource<>).FullName}'.");
    }

    /// <summary>Untyped facade over a <see cref="TaskCompletionSource{TResult}"/> of a specific result type.</summary>
    private interface ITcsResultSetter
    {
        Type ResultType { get; }
        void SetResult(object taskCompletionSource, object? result);
        void SetException(object taskCompletionSource, Exception exception);
    }

    /// <summary>Untyped facade for reading the result of a task of a specific result type.</summary>
    private interface ITaskResultGetter
    {
        object? GetResult(Task task);
    }

    private sealed class TaskResultGetter<T> : ITaskResultGetter
    {
        public object? GetResult(Task task) => ((Task<T>)task).Result!;
    }

    private sealed class VoidTaskResultGetter : ITaskResultGetter
    {
        public object? GetResult(Task task)
        {
            task.Wait(); // Throw if the task failed
            return null;
        }
    }

    private sealed class TcsResultSetter<T> : ITcsResultSetter
    {
        public Type ResultType => typeof(T);

        public void SetResult(object tcs, object? result)
        {
            var typedTcs = (TaskCompletionSource<T>)tcs;
            typedTcs.SetResult(ConvertResult(result));
        }

        public void SetException(object tcs, Exception exception)
        {
            var typedTcs = (TaskCompletionSource<T>)tcs;
            typedTcs.SetException(exception);
        }

        // Produces a T from an arbitrary value, attempting a conversion only when necessary.
        private static T ConvertResult(object? result)
        {
            if (result is T alreadyTyped)
            {
                return alreadyTyped;
            }

            // default(T) is null exactly when T is a reference type or Nullable<>, i.e. when
            // null is a legal value for T. Convert.ChangeType would reject null for Nullable<>.
            if (result is null && default(T) is null)
            {
                return default!;
            }

            // Convert.ChangeType cannot target Nullable<>; convert to the underlying type
            // instead and let the unboxing cast below wrap the value.
            var conversionType = Nullable.GetUnderlyingType(typeof(T)) ?? typeof(T);
            return (T)Convert.ChangeType(result, conversionType, CultureInfo.InvariantCulture)!;
        }
    }

    /// <summary>
    /// Gets (creating and caching on first use) the typed setter matching the runtime type of
    /// <paramref name="taskCompletionSource"/>.
    /// </summary>
    private static ITcsResultSetter CreateResultSetter(object taskCompletionSource)
    {
        return _cachedResultSetters.GetOrAdd(taskCompletionSource.GetType(), static tcsType =>
        {
            var resultType = GetTcsResultType(tcsType);
            return (ITcsResultSetter)Activator.CreateInstance(
                typeof(TcsResultSetter<>).MakeGenericType(resultType))!;
        });
    }
}