File: Protocol\MessagePackHubProtocolWorker.cs
Web Access
Project: src\src\SignalR\common\Protocols.MessagePack\src\Microsoft.AspNetCore.SignalR.Protocols.MessagePack.csproj (Microsoft.AspNetCore.SignalR.Protocols.MessagePack)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#pragma warning disable IDE0005 // This file is shared across multiple projects making it ugly to ignore unused usings
 
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Runtime.ExceptionServices;
using System.Text;
using MessagePack;
using Microsoft.AspNetCore.Internal;
 
namespace Microsoft.AspNetCore.SignalR.Protocol;
 
/// <summary>
/// Implements support for MessagePackHubProtocol. This code is shared between SignalR and Blazor.
/// </summary>
internal abstract class MessagePackHubProtocolWorker
{
    // Discriminators written after the invocation ID of a Completion message to say which payload (if any) follows.
    private const int ErrorResult = 1;
    private const int VoidResult = 2;
    private const int NonVoidResult = 3;

    /// <summary>
    /// Attempts to read one length-prefixed MessagePack frame from <paramref name="input"/> and decode it into a <see cref="HubMessage"/>.
    /// </summary>
    /// <param name="input">Buffered bytes; handed by reference to <c>BinaryMessageParser.TryParseMessage</c>.</param>
    /// <param name="binder">Supplies the parameter, stream-item and return types needed to deserialize payload values.</param>
    /// <param name="message">The decoded message when this method returns <see langword="true"/>; otherwise <see langword="null"/>.</param>
    /// <returns>
    /// <see langword="false"/> when no complete frame could be extracted, and also when a complete frame carries a
    /// message type this version does not recognize (such frames are ignored for forward compatibility).
    /// </returns>
    public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, [NotNullWhen(true)] out HubMessage? message)
    {
        if (!BinaryMessageParser.TryParseMessage(ref input, out var payload))
        {
            message = null;
            return false;
        }

        var reader = new MessagePackReader(payload);
        message = ParseMessage(ref reader, binder);
        return message != null;
    }

    /// <summary>
    /// Decodes a single message payload: an array whose first element is the message type, followed by type-specific fields.
    /// </summary>
    /// <returns>The decoded message, or <see langword="null"/> for an unrecognized message type.</returns>
    private HubMessage? ParseMessage(ref MessagePackReader reader, IInvocationBinder binder)
    {
        // The array length is forwarded to readers whose trailing fields are optional (stream IDs, allowReconnect).
        var itemCount = reader.ReadArrayHeader();

        var messageType = ReadInt32(ref reader, "messageType");

        switch (messageType)
        {
            case HubProtocolConstants.InvocationMessageType:
                return CreateInvocationMessage(ref reader, binder, itemCount);
            case HubProtocolConstants.StreamInvocationMessageType:
                return CreateStreamInvocationMessage(ref reader, binder, itemCount);
            case HubProtocolConstants.StreamItemMessageType:
                return CreateStreamItemMessage(ref reader, binder);
            case HubProtocolConstants.CompletionMessageType:
                return CreateCompletionMessage(ref reader, binder);
            case HubProtocolConstants.CancelInvocationMessageType:
                return CreateCancelInvocationMessage(ref reader);
            case HubProtocolConstants.PingMessageType:
                return PingMessage.Instance;
            case HubProtocolConstants.CloseMessageType:
                return CreateCloseMessage(ref reader, itemCount);
            case HubProtocolConstants.AckMessageType:
                return CreateAckMessage(ref reader);
            case HubProtocolConstants.SequenceMessageType:
                return CreateSequenceMessage(ref reader);
            default:
                // Future protocol changes can add message types, old clients can ignore them
                return null;
        }
    }

    /// <summary>
    /// Reads an Invocation message: headers, invocation ID, target, arguments and (for newer senders) stream IDs.
    /// </summary>
    /// <returns>
    /// An <see cref="InvocationMessage"/>, or an <see cref="InvocationBindingFailureMessage"/> when the target's
    /// parameter types cannot be resolved or the arguments cannot be bound to them.
    /// </returns>
    private HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, int itemCount)
    {
        var headers = ReadHeaders(ref reader);
        var invocationId = ReadInvocationId(ref reader);

        // For MsgPack, we represent an empty invocation ID as an empty string,
        // so we need to normalize that to "null", which is what indicates a non-blocking invocation.
        if (string.IsNullOrEmpty(invocationId))
        {
            invocationId = null;
        }

        var target = ReadString(ref reader, binder, "target");
        ThrowIfNullOrEmpty(target, "target for Invocation message");

        object?[]? arguments;
        try
        {
            var parameterTypes = binder.GetParameterTypes(target);
            arguments = BindArguments(ref reader, parameterTypes);
        }
        catch (Exception ex)
        {
            // Binding problems are reported as a message rather than thrown, so the caller can answer this invocation.
            return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
        }

        string[]? streams = null;
        // Previous clients will send 5 items, so we check if they sent a stream array or not
        if (itemCount > 5)
        {
            streams = ReadStreamIds(ref reader);
        }

        return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams));
    }

    /// <summary>
    /// Reads a StreamInvocation message. Same layout as an Invocation, except that the invocation ID is mandatory.
    /// </summary>
    /// <returns>
    /// A <see cref="StreamInvocationMessage"/>, or an <see cref="InvocationBindingFailureMessage"/> when argument binding fails.
    /// </returns>
    private HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, int itemCount)
    {
        var headers = ReadHeaders(ref reader);
        var invocationId = ReadInvocationId(ref reader);
        ThrowIfNullOrEmpty(invocationId, "invocation ID for StreamInvocation message");

        // Use the binder-aware reader, exactly like CreateInvocationMessage, so the binder can supply the target string.
        var target = ReadString(ref reader, binder, "target");
        ThrowIfNullOrEmpty(target, "target for StreamInvocation message");

        object?[] arguments;
        try
        {
            var parameterTypes = binder.GetParameterTypes(target);
            arguments = BindArguments(ref reader, parameterTypes);
        }
        catch (Exception ex)
        {
            return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
        }

        string[]? streams = null;
        // Previous clients will send 5 items, so we check if they sent a stream array or not
        if (itemCount > 5)
        {
            streams = ReadStreamIds(ref reader);
        }

        return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
    }

    /// <summary>
    /// Reads a StreamItem message. The item is deserialized using the type the binder reports for the stream.
    /// </summary>
    /// <returns>
    /// A <see cref="StreamItemMessage"/>, or a <see cref="StreamBindingFailureMessage"/> when the item type cannot
    /// be resolved or the item cannot be deserialized.
    /// </returns>
    private HubMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder)
    {
        var headers = ReadHeaders(ref reader);
        var invocationId = ReadInvocationId(ref reader);
        ThrowIfNullOrEmpty(invocationId, "invocation ID for StreamItem message");

        object? value;
        try
        {
            var itemType = binder.GetStreamItemType(invocationId);
            value = DeserializeObject(ref reader, itemType, "item");
        }
        catch (Exception ex)
        {
            return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
        }

        return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
    }

    /// <summary>
    /// Reads a Completion message. The payload after the result kind is an error string, a result value, or nothing.
    /// </summary>
    /// <exception cref="InvalidDataException">The result kind is not one of the known discriminators.</exception>
    private CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder)
    {
        var headers = ReadHeaders(ref reader);
        var invocationId = ReadInvocationId(ref reader);
        ThrowIfNullOrEmpty(invocationId, "invocation ID for Completion message");

        var resultKind = ReadInt32(ref reader, "resultKind");

        string? error = null;
        object? result = null;
        var hasResult = false;

        switch (resultKind)
        {
            case ErrorResult:
                error = ReadString(ref reader, "error");
                break;
            case NonVoidResult:
                hasResult = true;
                var itemType = ProtocolHelper.TryGetReturnType(binder, invocationId);
                if (itemType is null)
                {
                    // No known return type: step over the value so the reader stays aligned. Result stays null.
                    reader.Skip();
                }
                else
                {
                    if (itemType == typeof(RawResult))
                    {
                        // The caller asked for the still-serialized MessagePack bytes rather than a deserialized object.
                        result = new RawResult(reader.ReadRaw());
                    }
                    else
                    {
                        try
                        {
                            result = DeserializeObject(ref reader, itemType, "argument");
                        }
                        catch (Exception ex)
                        {
                            // A result that cannot be deserialized becomes an error completion instead of failing the parse.
                            error = $"Error trying to deserialize result to {itemType.Name}. {ex.Message}";
                            hasResult = false;
                        }
                    }
                }
                break;
            case VoidResult:
                hasResult = false;
                break;
            default:
                throw new InvalidDataException("Invalid invocation result kind.");
        }

        return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult));
    }

    /// <summary>Reads a CancelInvocation message (headers plus a required invocation ID).</summary>
    private static CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader)
    {
        var headers = ReadHeaders(ref reader);
        var invocationId = ReadInvocationId(ref reader);
        ThrowIfNullOrEmpty(invocationId, "invocation ID for CancelInvocation message");

        return ApplyHeaders(headers, new CancelInvocationMessage(invocationId));
    }

    /// <summary>
    /// Reads a Close message. The allowReconnect flag is only present when the array has more than two items.
    /// </summary>
    /// <returns>The shared <c>CloseMessage.Empty</c> instance when there is no error and reconnect is not allowed.</returns>
    private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount)
    {
        var error = ReadString(ref reader, "error");
        var allowReconnect = false;

        if (itemCount > 2)
        {
            allowReconnect = ReadBoolean(ref reader, "allowReconnect");
        }

        // An empty string is still an error
        if (error == null && !allowReconnect)
        {
            return CloseMessage.Empty;
        }

        return new CloseMessage(error, allowReconnect);
    }

    /// <summary>
    /// Reads the headers map. Keys and values must be non-empty strings.
    /// </summary>
    /// <returns>An ordinal-keyed dictionary, or <see langword="null"/> when the map is empty.</returns>
    private static Dictionary<string, string>? ReadHeaders(ref MessagePackReader reader)
    {
        var headerCount = ReadMapLength(ref reader, "headers");
        if (headerCount > 0)
        {
            var headers = new Dictionary<string, string>(StringComparer.Ordinal);

            for (var i = 0; i < headerCount; i++)
            {
                var key = ReadString(ref reader, $"headers[{i}].Key");
                ThrowIfNullOrEmpty(key, "key in header");

                var value = ReadString(ref reader, $"headers[{i}].Value");
                ThrowIfNullOrEmpty(value, "value in header");

                headers.Add(key, value);
            }
            return headers;
        }
        else
        {
            return null;
        }
    }

    /// <summary>
    /// Reads the stream ID array that trails (Stream)Invocation messages from newer senders.
    /// </summary>
    /// <returns>The stream IDs, or <see langword="null"/> when the array is empty.</returns>
    /// <exception cref="InvalidDataException">An entry is not a string, or is null or empty.</exception>
    private static string[]? ReadStreamIds(ref MessagePackReader reader)
    {
        var streamIdCount = ReadArrayLength(ref reader, "streamIds");
        List<string>? streams = null;

        if (streamIdCount > 0)
        {
            streams = new List<string>();
            for (var i = 0; i < streamIdCount; i++)
            {
                // Go through the wrapping helper (as ReadHeaders does) so a malformed entry surfaces as an
                // InvalidDataException naming the offending field rather than a raw reader exception.
                var id = ReadString(ref reader, $"streamIds[{i}]");
                ThrowIfNullOrEmpty(id, "value in streamIds received");

                streams.Add(id);
            }
        }

        return streams?.ToArray();
    }

    /// <summary>Reads an Ack message, which carries only a sequence ID.</summary>
    private static AckMessage CreateAckMessage(ref MessagePackReader reader)
    {
        return new AckMessage(ReadInt64(ref reader, "sequenceId"));
    }

    /// <summary>Reads a Sequence message, which carries only a sequence ID.</summary>
    private static SequenceMessage CreateSequenceMessage(ref MessagePackReader reader)
    {
        return new SequenceMessage(ReadInt64(ref reader, "sequenceId"));
    }

    /// <summary>
    /// Reads the arguments array and deserializes each element to the matching entry of <paramref name="parameterTypes"/>.
    /// </summary>
    /// <exception cref="InvalidDataException">
    /// The argument count does not match the parameter count, or an argument fails to deserialize
    /// (the original exception is kept as the inner exception).
    /// </exception>
    private object?[] BindArguments(ref MessagePackReader reader, IReadOnlyList<Type> parameterTypes)
    {
        var argumentCount = ReadArrayLength(ref reader, "arguments");

        if (parameterTypes.Count != argumentCount)
        {
            throw new InvalidDataException(
                $"Invocation provides {argumentCount} argument(s) but target expects {parameterTypes.Count}.");
        }

        try
        {
            var arguments = new object?[argumentCount];
            for (var i = 0; i < argumentCount; i++)
            {
                arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument");
            }

            return arguments;
        }
        catch (Exception ex)
        {
            throw new InvalidDataException("Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", ex);
        }
    }

    /// <summary>
    /// Deserializes the next value from <paramref name="reader"/> as <paramref name="type"/>. Implemented by derived workers.
    /// </summary>
    /// <param name="field">Name of the field being read; presumably used in error messages — confirm in implementations.</param>
    protected abstract object? DeserializeObject(ref MessagePackReader reader, Type type, string field);

    /// <summary>
    /// Attaches <paramref name="source"/> to <paramref name="destination"/> only when it contains at least one header.
    /// </summary>
    private static T ApplyHeaders<T>(IDictionary<string, string>? source, T destination) where T : HubInvocationMessage
    {
        if (source != null && source.Count > 0)
        {
            destination.Headers = source;
        }

        return destination;
    }

    /// <summary>
    /// Serializes <paramref name="message"/> and writes it to <paramref name="output"/> as a length-prefixed frame.
    /// </summary>
    /// <exception cref="InvalidDataException">The message type is not supported by this protocol.</exception>
    public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
    {
        var memoryBufferWriter = MemoryBufferWriter.Get();

        try
        {
            var writer = new MessagePackWriter(memoryBufferWriter);

            // Write message to a buffer so we can get its length
            WriteMessageCore(message, ref writer);

            // Write length then message to output
            BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output);
            memoryBufferWriter.CopyTo(output);
        }
        finally
        {
            MemoryBufferWriter.Return(memoryBufferWriter);
        }
    }

    /// <summary>
    /// Serializes <paramref name="message"/> into a newly allocated, length-prefixed byte array.
    /// </summary>
    /// <exception cref="InvalidDataException">The message type is not supported by this protocol.</exception>
    public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
    {
        var memoryBufferWriter = MemoryBufferWriter.Get();

        try
        {
            var writer = new MessagePackWriter(memoryBufferWriter);

            // Write message to a buffer so we can get its length
            WriteMessageCore(message, ref writer);

            var dataLength = memoryBufferWriter.Length;
            var prefixLength = BinaryMessageFormatter.LengthPrefixLength(dataLength);

            // Size the array exactly: prefix followed by payload.
            var array = new byte[dataLength + prefixLength];
            var span = array.AsSpan();

            // Write length then message to output
            var written = BinaryMessageFormatter.WriteLengthPrefix(dataLength, span);
            Debug.Assert(written == prefixLength);
            memoryBufferWriter.CopyTo(span.Slice(prefixLength));

            return array;
        }
        finally
        {
            MemoryBufferWriter.Return(memoryBufferWriter);
        }
    }

    /// <summary>
    /// Dispatches on the concrete message type, writes its MessagePack array, and flushes <paramref name="writer"/>.
    /// </summary>
    /// <exception cref="InvalidDataException">The message type is not supported by this protocol.</exception>
    private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer)
    {
        switch (message)
        {
            case InvocationMessage invocationMessage:
                WriteInvocationMessage(invocationMessage, ref writer);
                break;
            case StreamInvocationMessage streamInvocationMessage:
                WriteStreamInvocationMessage(streamInvocationMessage, ref writer);
                break;
            case StreamItemMessage streamItemMessage:
                WriteStreamingItemMessage(streamItemMessage, ref writer);
                break;
            case CompletionMessage completionMessage:
                WriteCompletionMessage(completionMessage, ref writer);
                break;
            case CancelInvocationMessage cancelInvocationMessage:
                WriteCancelInvocationMessage(cancelInvocationMessage, ref writer);
                break;
            case PingMessage:
                WritePingMessage(ref writer);
                break;
            case CloseMessage closeMessage:
                WriteCloseMessage(closeMessage, ref writer);
                break;
            case AckMessage ackMessage:
                WriteAckMessage(ackMessage, ref writer);
                break;
            case SequenceMessage sequenceMessage:
                WriteSequenceMessage(sequenceMessage, ref writer);
                break;
            default:
                throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
        }

        writer.Flush();
    }

    /// <summary>
    /// Writes an Invocation as a 6-item array: type, headers, invocation ID, target, arguments, stream IDs.
    /// </summary>
    private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer)
    {
        writer.WriteArrayHeader(6);

        writer.Write(HubProtocolConstants.InvocationMessageType);
        PackHeaders(message.Headers, ref writer);
        // A missing ID marks a non-blocking invocation and is written as nil; the reader also treats "" as no ID.
        if (string.IsNullOrEmpty(message.InvocationId))
        {
            writer.WriteNil();
        }
        else
        {
            writer.Write(message.InvocationId);
        }
        writer.Write(message.Target);

        if (message.Arguments is null)
        {
            writer.WriteArrayHeader(0);
        }
        else
        {
            writer.WriteArrayHeader(message.Arguments.Length);
            foreach (var arg in message.Arguments)
            {
                WriteArgument(arg, ref writer);
            }
        }

        WriteStreamIds(message.StreamIds, ref writer);
    }

    /// <summary>
    /// Writes a StreamInvocation as a 6-item array: type, headers, invocation ID, target, arguments, stream IDs.
    /// </summary>
    private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer)
    {
        writer.WriteArrayHeader(6);

        writer.Write(HubProtocolConstants.StreamInvocationMessageType);
        PackHeaders(message.Headers, ref writer);
        writer.Write(message.InvocationId);
        writer.Write(message.Target);

        if (message.Arguments is null)
        {
            writer.WriteArrayHeader(0);
        }
        else
        {
            writer.WriteArrayHeader(message.Arguments.Length);
            foreach (var arg in message.Arguments)
            {
                WriteArgument(arg, ref writer);
            }
        }

        WriteStreamIds(message.StreamIds, ref writer);
    }

    /// <summary>Writes a StreamItem as a 4-item array: type, headers, invocation ID, item.</summary>
    private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer)
    {
        writer.WriteArrayHeader(4);
        writer.Write(HubProtocolConstants.StreamItemMessageType);
        PackHeaders(message.Headers, ref writer);
        writer.Write(message.InvocationId);
        WriteArgument(message.Item, ref writer);
    }

    /// <summary>
    /// Writes one value: nil for <see langword="null"/>, the stored bytes verbatim for a <see cref="RawResult"/>,
    /// otherwise via <see cref="Serialize"/> using the value's runtime type.
    /// </summary>
    private void WriteArgument(object? argument, ref MessagePackWriter writer)
    {
        if (argument == null)
        {
            writer.WriteNil();
        }
        else if (argument is RawResult result)
        {
            writer.WriteRaw(result.RawSerializedData);
        }
        else
        {
            Serialize(ref writer, argument.GetType(), argument);
        }
    }

    /// <summary>Serializes <paramref name="value"/> as <paramref name="type"/>. Implemented by derived workers.</summary>
    protected abstract void Serialize(ref MessagePackWriter writer, Type type, object value);

    /// <summary>Writes the stream ID array; a <see langword="null"/> array is written as an empty one.</summary>
    private static void WriteStreamIds(string[]? streamIds, ref MessagePackWriter writer)
    {
        if (streamIds != null)
        {
            writer.WriteArrayHeader(streamIds.Length);
            foreach (var streamId in streamIds)
            {
                writer.Write(streamId);
            }
        }
        else
        {
            writer.WriteArrayHeader(0);
        }
    }

    /// <summary>
    /// Writes a Completion: type, headers, invocation ID, result kind, then the error or result when present
    /// (5 items), or nothing more for a void result (4 items).
    /// </summary>
    private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer)
    {
        // An error takes precedence over a result; with neither, the completion is void.
        var resultKind =
            message.Error != null ? ErrorResult :
            message.HasResult ? NonVoidResult :
            VoidResult;

        writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0));
        writer.Write(HubProtocolConstants.CompletionMessageType);
        PackHeaders(message.Headers, ref writer);
        writer.Write(message.InvocationId);
        writer.Write(resultKind);
        switch (resultKind)
        {
            case ErrorResult:
                writer.Write(message.Error);
                break;
            case NonVoidResult:
                WriteArgument(message.Result, ref writer);
                break;
        }
    }

    /// <summary>Writes a CancelInvocation as a 3-item array: type, headers, invocation ID.</summary>
    private static void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer)
    {
        writer.WriteArrayHeader(3);
        writer.Write(HubProtocolConstants.CancelInvocationMessageType);
        PackHeaders(message.Headers, ref writer);
        writer.Write(message.InvocationId);
    }

    /// <summary>
    /// Writes a Close as a 3-item array: type, error (nil when null or empty), allowReconnect.
    /// </summary>
    private static void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer)
    {
        writer.WriteArrayHeader(3);
        writer.Write(HubProtocolConstants.CloseMessageType);
        if (string.IsNullOrEmpty(message.Error))
        {
            writer.WriteNil();
        }
        else
        {
            writer.Write(message.Error);
        }

        writer.Write(message.AllowReconnect);
    }

    /// <summary>Writes a Ping, which consists of the message type only.</summary>
    private static void WritePingMessage(ref MessagePackWriter writer)
    {
        writer.WriteArrayHeader(1);
        writer.Write(HubProtocolConstants.PingMessageType);
    }

    /// <summary>Writes an Ack as a 2-item array: type, sequence ID.</summary>
    private static void WriteAckMessage(AckMessage message, ref MessagePackWriter writer)
    {
        writer.WriteArrayHeader(2);
        writer.Write(HubProtocolConstants.AckMessageType);
        writer.Write(message.SequenceId);
    }

    /// <summary>Writes a Sequence as a 2-item array: type, sequence ID.</summary>
    private static void WriteSequenceMessage(SequenceMessage message, ref MessagePackWriter writer)
    {
        writer.WriteArrayHeader(2);
        writer.Write(HubProtocolConstants.SequenceMessageType);
        writer.Write(message.SequenceId);
    }

    /// <summary>Writes the headers as a string-to-string map; <see langword="null"/> is written as an empty map.</summary>
    private static void PackHeaders(IDictionary<string, string>? headers, ref MessagePackWriter writer)
    {
        if (headers != null)
        {
            writer.WriteMapHeader(headers.Count);
            // Skip enumeration entirely for an empty map; enumerating through the IDictionary interface
            // obtains an enumerator object that is not needed when there is nothing to write.
            if (headers.Count > 0)
            {
                foreach (var header in headers)
                {
                    writer.Write(header.Key);
                    writer.Write(header.Value);
                }
            }
        }
        else
        {
            writer.WriteMapHeader(0);
        }
    }

    /// <summary>Reads the invocation ID field; may return <see langword="null"/> when it was written as nil.</summary>
    private static string? ReadInvocationId(ref MessagePackReader reader) =>
        ReadString(ref reader, "invocationId");

    /// <summary>Reads a Boolean, wrapping any reader failure in an <see cref="InvalidDataException"/> naming <paramref name="field"/>.</summary>
    private static bool ReadBoolean(ref MessagePackReader reader, string field)
    {
        try
        {
            return reader.ReadBoolean();
        }
        catch (Exception ex)
        {
            throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex);
        }
    }

    /// <summary>Reads an Int32, wrapping any reader failure in an <see cref="InvalidDataException"/> naming <paramref name="field"/>.</summary>
    private static int ReadInt32(ref MessagePackReader reader, string field)
    {
        try
        {
            return reader.ReadInt32();
        }
        catch (Exception ex)
        {
            throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex);
        }
    }

    /// <summary>Reads an Int64, wrapping any reader failure in an <see cref="InvalidDataException"/> naming <paramref name="field"/>.</summary>
    private static long ReadInt64(ref MessagePackReader reader, string field)
    {
        try
        {
            return reader.ReadInt64();
        }
        catch (Exception ex)
        {
            throw new InvalidDataException($"Reading '{field}' as Int64 failed.", ex);
        }
    }

    /// <summary>
    /// Reads a string, first offering the binder a chance to supply it. On NETCOREAPP, when the string is available
    /// as a contiguous UTF-8 span, <c>binder.GetTarget</c> is consulted and UTF-8 decoding is used only if it returns
    /// <see langword="null"/> (presumably so known target names can reuse an existing string instance).
    /// </summary>
    /// <exception cref="InvalidDataException">The value could not be read as a string.</exception>
    protected static string? ReadString(ref MessagePackReader reader, IInvocationBinder binder, string field)
    {
        try
        {
#if NETCOREAPP
            if (reader.TryReadStringSpan(out var span))
            {
                return binder.GetTarget(span) ?? Encoding.UTF8.GetString(span);
            }
            return reader.ReadString();
#else
            return reader.ReadString();
#endif
        }
        catch (Exception ex)
        {
            throw new InvalidDataException($"Reading '{field}' as String failed.", ex);
        }
    }

    /// <summary>Reads a string (nil yields <see langword="null"/>), wrapping reader failures in an <see cref="InvalidDataException"/>.</summary>
    protected static string? ReadString(ref MessagePackReader reader, string field)
    {
        try
        {
            return reader.ReadString();
        }
        catch (Exception ex)
        {
            throw new InvalidDataException($"Reading '{field}' as String failed.", ex);
        }
    }

    /// <summary>Reads a map header, wrapping reader failures in an <see cref="InvalidDataException"/>.</summary>
    private static long ReadMapLength(ref MessagePackReader reader, string field)
    {
        try
        {
            return reader.ReadMapHeader();
        }
        catch (Exception ex)
        {
            throw new InvalidDataException($"Reading map length for '{field}' failed.", ex);
        }
    }

    /// <summary>Reads an array header, wrapping reader failures in an <see cref="InvalidDataException"/>.</summary>
    private static long ReadArrayLength(ref MessagePackReader reader, string field)
    {
        try
        {
            return reader.ReadArrayHeader();
        }
        catch (Exception ex)
        {
            throw new InvalidDataException($"Reading array length for '{field}' failed.", ex);
        }
    }

    /// <summary>
    /// Throws an <see cref="InvalidDataException"/> when <paramref name="target"/> is null or empty; afterwards the
    /// compiler treats <paramref name="target"/> as non-null.
    /// </summary>
    private static void ThrowIfNullOrEmpty([NotNull] string? target, string message)
    {
        if (string.IsNullOrEmpty(target))
        {
            throw new InvalidDataException($"Null or empty {message}.");
        }
    }
}