File: Internal\InvocationRequest.cs
Web Access
Project: src\src\SignalR\clients\csharp\Client.Core\src\Microsoft.AspNetCore.SignalR.Client.Core.csproj (Microsoft.AspNetCore.SignalR.Client.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Logging;
 
namespace Microsoft.AspNetCore.SignalR.Client.Internal;
 
/// <summary>
/// Represents a single pending hub invocation made by the client. It carries the invocation id,
/// the expected result type, the caller's cancellation token and an optional <see cref="Activity"/>,
/// and exposes the operations used to complete, fail, stream into, or cancel the invocation.
/// </summary>
/// <remarks>
/// Instances are created through <see cref="Invoke"/> (single result) or <see cref="Stream"/>
/// (streamed results). Disposing an instance cancels it if it has not been completed yet.
/// </remarks>
internal abstract partial class InvocationRequest : IDisposable
{
    private readonly CancellationTokenRegistration _cancellationTokenRegistration;

    // 0 until the first successful TryBeginStopActivity call, 1 afterwards.
    private int _isActivityStopping;

    protected ILogger Logger { get; }

    public Type ResultType { get; }
    public CancellationToken CancellationToken { get; }
    public string InvocationId { get; }
    public HubConnection HubConnection { get; }
    public Activity? Activity { get; }

    /// <summary>
    /// Initializes the shared invocation state and hooks <see cref="Cancel"/> up to
    /// <paramref name="cancellationToken"/>.
    /// </summary>
    /// <param name="cancellationToken">Token whose cancellation triggers <see cref="Cancel"/>.</param>
    /// <param name="resultType">Value exposed through <see cref="ResultType"/>.</param>
    /// <param name="invocationId">Value exposed through <see cref="InvocationId"/>.</param>
    /// <param name="logger">Logger used for all invocation trace/error events.</param>
    /// <param name="hubConnection">Value exposed through <see cref="HubConnection"/>.</param>
    /// <param name="activity">Optional activity that is stopped when the invocation finishes.</param>
    protected InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILogger logger, HubConnection hubConnection, Activity? activity)
    {
        InvocationId = invocationId;
        CancellationToken = cancellationToken;
        ResultType = resultType;
        Logger = logger;
        HubConnection = hubConnection;
        Activity = activity;

        // Register only after all state above is assigned. CancellationToken.Register runs the
        // callback synchronously when the token is already canceled, and Cancel() reads Activity
        // (through TryBeginStopActivity). Registering first would let Cancel() observe a
        // partially-initialized instance and skip stopping/tagging the activity.
        _cancellationTokenRegistration = cancellationToken.Register(self => ((InvocationRequest)self!).Cancel(), this);

        Log.InvocationCreated(Logger, InvocationId);
    }

    /// <summary>
    /// Claims the right to stop <see cref="Activity"/>.
    /// </summary>
    /// <returns>
    /// <see langword="true"/> only for the first caller, and only when <see cref="Activity"/> is not
    /// null; <see langword="false"/> for every later call or when there is no activity.
    /// </returns>
    [MemberNotNullWhen(true, nameof(Activity))]
    protected bool TryBeginStopActivity()
    {
        // The atomic exchange guarantees a single winner even if Complete/Fail/Cancel race.
        return Activity is not null && Interlocked.Exchange(ref _isActivityStopping, 1) == 0;
    }

    /// <summary>
    /// Creates a request for a non-streaming invocation.
    /// </summary>
    /// <param name="cancellationToken">Token that cancels the invocation.</param>
    /// <param name="resultType">Expected result type.</param>
    /// <param name="invocationId">Id of the invocation.</param>
    /// <param name="loggerFactory">Factory used to create the request's logger.</param>
    /// <param name="hubConnection">Connection the invocation belongs to.</param>
    /// <param name="activity">Optional activity tracking the invocation.</param>
    /// <param name="result">Task that is completed, faulted or canceled together with the request.</param>
    /// <returns>The created request.</returns>
    public static InvocationRequest Invoke(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, Activity? activity, out Task<object?> result)
    {
        var req = new NonStreaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection, activity);
        result = req.Result;
        return req;
    }

    /// <summary>
    /// Creates a request for a streaming invocation.
    /// </summary>
    /// <param name="cancellationToken">Token that cancels the invocation.</param>
    /// <param name="resultType">Expected item type.</param>
    /// <param name="invocationId">Id of the invocation.</param>
    /// <param name="loggerFactory">Factory used to create the request's logger.</param>
    /// <param name="hubConnection">Connection the invocation belongs to.</param>
    /// <param name="activity">Optional activity tracking the invocation.</param>
    /// <param name="result">Reader side of the channel that receives the streamed items.</param>
    /// <returns>The created request.</returns>
    public static InvocationRequest Stream(CancellationToken cancellationToken, Type resultType, string invocationId,
        ILoggerFactory loggerFactory, HubConnection hubConnection, Activity? activity, out ChannelReader<object?> result)
    {
        var req = new Streaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection, activity);
        result = req.Result;
        return req;
    }

    /// <summary>Ends the invocation with the given exception.</summary>
    public abstract void Fail(Exception exception);

    /// <summary>Ends the invocation based on a completion message.</summary>
    public abstract void Complete(CompletionMessage message);

    /// <summary>Delivers one stream item to the invocation.</summary>
    public abstract ValueTask<bool> StreamItem(object? item);

    /// <summary>Ends the invocation as canceled. Invoked by the token registration and by <see cref="Dispose"/>.</summary>
    protected abstract void Cancel();

    /// <summary>
    /// Logs the disposal, cancels the invocation and removes the cancellation token registration.
    /// </summary>
    public virtual void Dispose()
    {
        Log.InvocationDisposed(Logger, InvocationId);

        // Just in case it hasn't already been completed
        Cancel();

        _cancellationTokenRegistration.Dispose();
    }

    /// <summary>
    /// Request for a streaming invocation; items and the terminal state are pushed into an unbounded channel.
    /// </summary>
    private sealed class Streaming : InvocationRequest
    {
        private readonly Channel<object?> _channel = Channel.CreateUnbounded<object?>();

        public Streaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, Activity? activity)
            : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger(typeof(Streaming)), hubConnection, activity)
        {
        }

        /// <summary>Reader that observes streamed items and the completion/error of the invocation.</summary>
        public ChannelReader<object?> Result => _channel.Reader;

        /// <summary>
        /// Completes the stream. A completion carrying a result is treated as an error for a streamed
        /// invocation; a completion carrying an error is routed to <see cref="Fail"/>.
        /// </summary>
        public override void Complete(CompletionMessage completionMessage)
        {
            Log.InvocationCompleted(Logger, InvocationId);
            if (completionMessage.Result is not null)
            {
                // A result is not expected on a streamed invocation: surface it as a failure to the reader.
                Log.ReceivedUnexpectedComplete(Logger, InvocationId);

                if (TryBeginStopActivity())
                {
                    Activity.SetStatus(ActivityStatusCode.Error);
                    Activity.Stop();
                }

                _channel.Writer.TryComplete(new InvalidOperationException("Server provided a result in a completion response to a streamed invocation."));
                return;
            }

            if (!string.IsNullOrEmpty(completionMessage.Error))
            {
                Fail(new HubException(completionMessage.Error));
                return;
            }

            if (TryBeginStopActivity())
            {
                Activity.Stop();
            }

            _channel.Writer.TryComplete();
        }

        /// <summary>
        /// Marks the activity (if any) as failed and completes the channel with <paramref name="exception"/>.
        /// </summary>
        public override void Fail(Exception exception)
        {
            Log.InvocationFailed(Logger, InvocationId);

            if (TryBeginStopActivity())
            {
                Activity.SetStatus(ActivityStatusCode.Error);
                Activity.SetTag("error.type", exception.GetType().FullName);
                Activity.Stop();
            }

            _channel.Writer.TryComplete(exception);
        }

        /// <summary>
        /// Writes <paramref name="item"/> to the channel.
        /// </summary>
        /// <returns>
        /// <see langword="false"/> when the writer reports that no further writes are possible;
        /// otherwise <see langword="true"/>. An exception raised while writing is logged and the
        /// method still returns <see langword="true"/>.
        /// </returns>
        public override async ValueTask<bool> StreamItem(object? item)
        {
            try
            {
                while (!_channel.Writer.TryWrite(item))
                {
                    if (!await _channel.Writer.WaitToWriteAsync().ConfigureAwait(false))
                    {
                        return false;
                    }
                }
            }
            catch (Exception ex)
            {
                Log.ErrorWritingStreamItem(Logger, InvocationId, ex);
            }
            return true;
        }

        /// <summary>
        /// Marks the activity (if any) as canceled and completes the channel with an
        /// <see cref="OperationCanceledException"/>.
        /// </summary>
        protected override void Cancel()
        {
            if (TryBeginStopActivity())
            {
                Activity.SetStatus(ActivityStatusCode.Error);
                Activity.SetTag("error.type", typeof(OperationCanceledException).FullName);
                Activity.Stop();
            }

            _channel.Writer.TryComplete(new OperationCanceledException());
        }
    }

    /// <summary>
    /// Request for a non-streaming invocation; the outcome is reported through a single task.
    /// </summary>
    private sealed class NonStreaming : InvocationRequest
    {
        private readonly TaskCompletionSource<object?> _completionSource = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);

        public NonStreaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, Activity? activity)
            : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger(typeof(NonStreaming)), hubConnection, activity)
        {
        }

        /// <summary>Task that reflects the result, failure or cancellation of the invocation.</summary>
        public Task<object?> Result => _completionSource.Task;

        /// <summary>
        /// Sets the task result from <paramref name="completionMessage"/>, or routes to
        /// <see cref="Fail"/> when the message carries an error.
        /// </summary>
        public override void Complete(CompletionMessage completionMessage)
        {
            if (!string.IsNullOrEmpty(completionMessage.Error))
            {
                Fail(new HubException(completionMessage.Error));
                return;
            }

            Log.InvocationCompleted(Logger, InvocationId);

            if (TryBeginStopActivity())
            {
                Activity.Stop();
            }

            _completionSource.TrySetResult(completionMessage.Result);
        }

        /// <summary>
        /// Marks the activity (if any) as failed and faults the task with <paramref name="exception"/>.
        /// </summary>
        public override void Fail(Exception exception)
        {
            Log.InvocationFailed(Logger, InvocationId);

            if (TryBeginStopActivity())
            {
                Activity.SetStatus(ActivityStatusCode.Error);
                Activity.SetTag("error.type", exception.GetType().FullName);
                Activity.Stop();
            }

            _completionSource.TrySetException(exception);
        }

        /// <summary>
        /// Handles a stream item arriving for a non-streamed invocation: logs the mismatch and faults
        /// the task with an <see cref="InvalidOperationException"/>.
        /// </summary>
        /// <returns>Always <see langword="true"/>.</returns>
        public override ValueTask<bool> StreamItem(object? item)
        {
            Log.StreamItemOnNonStreamInvocation(Logger, InvocationId);
            _completionSource.TrySetException(new InvalidOperationException($"Streaming hub methods must be invoked with the '{nameof(HubConnection)}.{nameof(HubConnectionExtensions.StreamAsChannelAsync)}' method."));

            // We "delivered" the stream item successfully as far as the caller cares
            return new ValueTask<bool>(true);
        }

        /// <summary>
        /// Marks the activity (if any) as canceled and transitions the task to the canceled state.
        /// </summary>
        protected override void Cancel()
        {
            if (TryBeginStopActivity())
            {
                Activity.SetStatus(ActivityStatusCode.Error);
                Activity.SetTag("error.type", typeof(OperationCanceledException).FullName);
                Activity.Stop();
            }

            _completionSource.TrySetCanceled();
        }
    }

    /// <summary>Source-generated log messages used by the invocation request types.</summary>
    private static partial class Log
    {
        // Category: Streaming and NonStreaming

        [LoggerMessage(1, LogLevel.Trace, "Invocation {InvocationId} created.", EventName = "InvocationCreated")]
        public static partial void InvocationCreated(ILogger logger, string invocationId);

        [LoggerMessage(2, LogLevel.Trace, "Invocation {InvocationId} disposed.", EventName = "InvocationDisposed")]
        public static partial void InvocationDisposed(ILogger logger, string invocationId);

        [LoggerMessage(3, LogLevel.Trace, "Invocation {InvocationId} marked as completed.", EventName = "InvocationCompleted")]
        public static partial void InvocationCompleted(ILogger logger, string invocationId);

        [LoggerMessage(4, LogLevel.Trace, "Invocation {InvocationId} marked as failed.", EventName = "InvocationFailed")]
        public static partial void InvocationFailed(ILogger logger, string invocationId);

        // Category: Streaming

        [LoggerMessage(5, LogLevel.Error, "Invocation {InvocationId} caused an error trying to write a stream item.", EventName = "ErrorWritingStreamItem")]
        public static partial void ErrorWritingStreamItem(ILogger logger, string invocationId, Exception exception);

        [LoggerMessage(6, LogLevel.Error, "Invocation {InvocationId} received a completion result, but was invoked as a streaming invocation.", EventName = "ReceivedUnexpectedComplete")]
        public static partial void ReceivedUnexpectedComplete(ILogger logger, string invocationId);

        // Category: NonStreaming

        [LoggerMessage(7, LogLevel.Error, "Invocation {InvocationId} received stream item but was invoked as a non-streamed invocation.", EventName = "StreamItemOnNonStreamInvocation")]
        public static partial void StreamItemOnNonStreamInvocation(ILogger logger, string invocationId);
    }
}