File: src\SignalR\common\Shared\ClientResultsManager.cs
Web Access
Project: src\src\SignalR\server\Core\src\Microsoft.AspNetCore.SignalR.Core.csproj (Microsoft.AspNetCore.SignalR.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.AspNetCore.SignalR.Protocol;
 
namespace Microsoft.AspNetCore.SignalR.Internal;
 
// Common type used by our HubLifetimeManager implementations to manage client results.
// Handles cancellation, cleanup, and completion, so any bugs or improvements can be made in a single place.
//
// Each pending invocation is keyed by invocation ID and stores:
//   - the expected result type (surfaced through IInvocationBinder.GetReturnType when parsing the client's CompletionMessage),
//   - the ID of the connection the invocation was sent to,
//   - the TaskCompletionSource backing the caller's Task (stored as object so results of any T share one dictionary),
//   - a delegate that completes that TaskCompletionSource from a CompletionMessage.
internal sealed class ClientResultsManager : IInvocationBinder
{
    private const string IdCollisionError = "ID collision occurred when using client results. This is likely a bug in SignalR.";

    private readonly ConcurrentDictionary<string, (Type Type, string ConnectionId, object Tcs, Action<object, CompletionMessage> Complete)> _pendingInvocations = new();

    /// <summary>
    /// Registers a pending client result for <paramref name="invocationId"/> sent to <paramref name="connectionId"/>.
    /// </summary>
    /// <returns>
    /// A task that completes with the client's result, faults with a <see cref="HubException"/> when the client reports
    /// an error (or the server cancels via <paramref name="cancellationToken"/>), or faults immediately if
    /// <paramref name="invocationId"/> is already pending.
    /// </returns>
    public Task<T> AddInvocation<T>(string connectionId, string invocationId, CancellationToken cancellationToken)
    {
        var tcs = new TaskCompletionSourceWithCancellation<T>(this, connectionId, invocationId, cancellationToken);
        var result = _pendingInvocations.TryAdd(invocationId, (typeof(T), connectionId, tcs, static (state, completionMessage) =>
        {
            var tcs = (TaskCompletionSourceWithCancellation<T>)state;
            if (completionMessage.HasResult)
            {
                tcs.SetResult((T)completionMessage.Result!);
            }
            else
            {
                tcs.SetException(new HubException(completionMessage.Error));
            }
        }
        ));
        Debug.Assert(result);

        if (!result)
        {
            // Mirrors the non-generic overload: fail the caller instead of returning a task that never completes.
            // Cancellation is deliberately NOT registered, because the dictionary entry under this ID belongs to a
            // different invocation and canceling would complete (or throw on) that other invocation.
            tcs.SetException(new HubException(IdCollisionError));
            return tcs.Task;
        }

        // Must run after the TryAdd above: an already-canceled token invokes the callback synchronously,
        // and that callback completes the invocation by looking it up in _pendingInvocations.
        tcs.RegisterCancellation();

        return tcs.Task;
    }

    /// <summary>
    /// Registers a pending client result whose completion is handled by the caller-supplied <paramref name="invocationInfo"/>.
    /// If <paramref name="invocationId"/> is already pending, <paramref name="invocationInfo"/> is completed with an error instead.
    /// </summary>
    public void AddInvocation(string invocationId, (Type Type, string ConnectionId, object Tcs, Action<object, CompletionMessage> Complete) invocationInfo)
    {
        var result = _pendingInvocations.TryAdd(invocationId, invocationInfo);
        Debug.Assert(result);
        // Should have a 50% chance of happening once every 2.71 quintillion invocations (see UUID in Wikipedia)
        if (!result)
        {
            invocationInfo.Complete(invocationInfo.Tcs, CompletionMessage.WithError(invocationId, IdCollisionError));
        }
    }

    /// <summary>
    /// Completes the pending invocation identified by <see cref="CompletionMessage.InvocationId"/> with <paramref name="message"/>.
    /// Does nothing if the invocation is no longer pending.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The invocation is pending but was sent to a connection other than <paramref name="connectionId"/>.
    /// </exception>
    public void TryCompleteResult(string connectionId, CompletionMessage message)
    {
        var invocationId = message.InvocationId!;
        if (!_pendingInvocations.TryGetValue(invocationId, out var item))
        {
            // connection was disconnected or someone else completed the invocation
            return;
        }

        if (item.ConnectionId != connectionId)
        {
            throw new InvalidOperationException($"Connection ID '{connectionId}' is not valid for invocation ID '{message.InvocationId}'.");
        }

        // TryRemove(KeyValuePair) atomically removes the entry only if it is still the exact one validated above, so
        // exactly one racing caller (e.g. a client result vs. a server-side cancellation) gets to invoke Complete.
        // Do not use the IDictionary Remove(key, out value) extension here: it is a non-atomic TryGetValue + Remove
        // that returns true after any successful lookup, letting two callers both complete the same task.
        // If this returns false, the connection disconnected right after the TryGetValue above or someone else
        // completed the invocation (likely a bad client); both cases are ignored.
        if (_pendingInvocations.TryRemove(KeyValuePair.Create(invocationId, item)))
        {
            item.Complete(item.Tcs, message);
        }
    }

    /// <summary>
    /// Removes the pending invocation without completing it; the caller becomes responsible for completion.
    /// </summary>
    /// <returns>The removed entry, or <see langword="null"/> if <paramref name="invocationId"/> was not pending.</returns>
    public (Type Type, string ConnectionId, object Tcs, Action<object, CompletionMessage> Completion)? RemoveInvocation(string invocationId)
    {
        // Return null rather than a default-valued tuple: a default tuple would convert to a non-null nullable,
        // and callers using ?. / null checks would then invoke a null Completion delegate.
        if (_pendingInvocations.TryRemove(invocationId, out var item))
        {
            return item;
        }
        return null;
    }

    /// <summary>Gets the expected result type of a pending invocation, if there is one.</summary>
    public bool TryGetType(string invocationId, [NotNullWhen(true)] out Type? type)
    {
        if (_pendingInvocations.TryGetValue(invocationId, out var item))
        {
            type = item.Type;
            return true;
        }
        type = null;
        return false;
    }

    /// <summary>
    /// <see cref="IInvocationBinder"/> hook used while parsing a client's completion: the expected result type of the invocation.
    /// </summary>
    /// <exception cref="InvalidOperationException"><paramref name="invocationId"/> is not a pending client result.</exception>
    public Type GetReturnType(string invocationId)
    {
        if (TryGetType(invocationId, out var type))
        {
            return type;
        }
        throw new InvalidOperationException($"Invocation ID '{invocationId}' is not associated with a pending client result.");
    }

    // Unused, here to honor the IInvocationBinder interface but should never be called
    public IReadOnlyList<Type> GetParameterTypes(string methodName)
    {
        throw new NotImplementedException();
    }

    // Unused, here to honor the IInvocationBinder interface but should never be called
    public Type GetStreamItemType(string streamId)
    {
        throw new NotImplementedException();
    }

    // Custom TCS type to avoid the extra allocation that would be introduced if we managed the cancellation separately
    // Also makes it easier to keep track of the CancellationTokenRegistration for disposal
    internal sealed class TaskCompletionSourceWithCancellation<T> : TaskCompletionSource<T>
    {
        private readonly ClientResultsManager _clientResultsManager;
        private readonly string _connectionId;
        private readonly string _invocationId;
        private readonly CancellationToken _token;

        private CancellationTokenRegistration _tokenRegistration;

        public TaskCompletionSourceWithCancellation(ClientResultsManager clientResultsManager, string connectionId, string invocationId,
            CancellationToken cancellationToken)
            : base(TaskCreationOptions.RunContinuationsAsynchronously)
        {
            _clientResultsManager = clientResultsManager;
            _connectionId = connectionId;
            _invocationId = invocationId;
            _token = cancellationToken;
        }

        // Needs to be called after adding the completion to the dictionary: if the token is already canceled the
        // callback runs synchronously, and it can only cancel the invocation if the dictionary entry already exists.
        public void RegisterCancellation()
        {
            if (_token.CanBeCanceled)
            {
                _tokenRegistration = _token.UnsafeRegister(static o =>
                {
                    var tcs = (TaskCompletionSourceWithCancellation<T>)o!;
                    tcs.SetCanceled();
                }, this);
            }
        }

        // Cancellation is routed through the manager so the dictionary entry is removed and the task faults with a
        // HubException, exactly as if the client had returned an error.
        public new void SetCanceled()
        {
            // TODO: RedisHubLifetimeManager will want to notify the other server (if there is one) about the cancellation
            // so it can clean up state and potentially forward that info to the connection
            _clientResultsManager.TryCompleteResult(_connectionId, CompletionMessage.WithError(_invocationId, "Invocation canceled by the server."));
        }

        public new void SetResult(T result)
        {
            _tokenRegistration.Dispose();
            base.SetResult(result);
        }

        public new void SetException(Exception exception)
        {
            _tokenRegistration.Dispose();
            base.SetException(exception);
        }

#pragma warning disable IDE0060 // Remove unused parameter
        // Just making sure we don't accidentally call one of these without knowing
        public static new void SetCanceled(CancellationToken cancellationToken) => Debug.Assert(false);
        public static new void SetException(IEnumerable<Exception> exceptions) => Debug.Assert(false);
        public static new bool TrySetCanceled()
        {
            Debug.Assert(false);
            return false;
        }
        public static new bool TrySetCanceled(CancellationToken cancellationToken)
        {
            Debug.Assert(false);
            return false;
        }
        public static new bool TrySetException(IEnumerable<Exception> exceptions)
        {
            Debug.Assert(false);
            return false;
        }
        public static new bool TrySetException(Exception exception)
        {
            Debug.Assert(false);
            return false;
        }
        public static new bool TrySetResult(T result)
        {
            Debug.Assert(false);
            return false;
        }
#pragma warning restore IDE0060 // Remove unused parameter
    }
}