File: HubProtocolVersionTests.cs
Web Access
Project: src\src\SignalR\clients\csharp\Client\test\FunctionalTests\Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj (Microsoft.AspNetCore.SignalR.Client.FunctionalTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Net;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Client;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Microsoft.Extensions.Options;
using Newtonsoft.Json.Linq;
using Xunit;
 
namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests;
 
public class HubProtocolVersionTestsCollection : ICollectionFixture<InProcessTestServer<VersionStartup>>
{
    public const string Name = nameof(HubProtocolVersionTestsCollection);
}
 
[Collection(HubProtocolVersionTestsCollection.Name)]
public class HubProtocolVersionTests : FunctionalTestBase
{
    [Theory]
    [MemberData(nameof(TransportTypes))]
    public async Task ClientUsingOldCallWithOriginalProtocol(HttpTransportType transportType)
    {
        await using (var server = await StartServer<VersionStartup>())
        {
            var connectionBuilder = new HubConnectionBuilder()
                .WithLoggerFactory(LoggerFactory)
                .WithUrl(server.Url + "/version", transportType);
 
            var connection = connectionBuilder.Build();
 
            try
            {
                await connection.StartAsync().DefaultTimeout();
 
                var result = await connection.InvokeAsync<string>(nameof(VersionHub.Echo), "Hello World!").DefaultTimeout();
 
                Assert.Equal("Hello World!", result);
            }
            catch (Exception ex)
            {
                LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
                throw;
            }
            finally
            {
                await connection.DisposeAsync().DefaultTimeout();
            }
        }
    }
 
    [Theory]
    [MemberData(nameof(TransportTypes))]
    public async Task ClientUsingOldCallWithNewProtocol(HttpTransportType transportType)
    {
        await using (var server = await StartServer<VersionStartup>())
        {
            var connectionBuilder = new HubConnectionBuilder()
                .WithLoggerFactory(LoggerFactory)
                .WithUrl(server.Url + "/version", transportType);
            connectionBuilder.Services.AddSingleton<IHubProtocol>(new VersionedJsonHubProtocol(1000));
 
            var connection = connectionBuilder.Build();
 
            try
            {
                await connection.StartAsync().DefaultTimeout();
 
                var result = await connection.InvokeAsync<string>(nameof(VersionHub.Echo), "Hello World!").DefaultTimeout();
 
                Assert.Equal("Hello World!", result);
            }
            catch (Exception ex)
            {
                LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
                throw;
            }
            finally
            {
                await connection.DisposeAsync().DefaultTimeout();
            }
        }
    }
 
    [Theory]
    [MemberData(nameof(TransportTypes))]
    public async Task ClientUsingNewCallWithNewProtocol(HttpTransportType transportType)
    {
        await using (var server = await StartServer<VersionStartup>())
        {
            var httpConnectionFactory = new HttpConnectionFactory(
                Options.Create(new HttpConnectionOptions
                {
                    Transports = transportType,
                    DefaultTransferFormat = TransferFormat.Text
                }),
                LoggerFactory);
            var tcs = new TaskCompletionSource();
 
            var proxyConnectionFactory = new ProxyConnectionFactory(httpConnectionFactory);
 
            var connectionBuilder = new HubConnectionBuilder()
                .WithUrl(new Uri(server.Url + "/version"))
                .WithLoggerFactory(LoggerFactory);
            connectionBuilder.Services.AddSingleton<IHubProtocol>(new VersionedJsonHubProtocol(1000));
            connectionBuilder.Services.AddSingleton<IConnectionFactory>(proxyConnectionFactory);
 
            var connection = connectionBuilder.Build();
            connection.On("NewProtocolMethodClient", () =>
            {
                tcs.SetResult();
            });
 
            try
            {
                await connection.StartAsync().DefaultTimeout();
 
                // Task should already have been awaited in StartAsync
                var connectionContext = await proxyConnectionFactory.ConnectTask.DefaultTimeout();
 
                // Simulate a new call from the client
                var messageToken = new JObject
                {
                    ["type"] = int.MaxValue
                };
 
                connectionContext.Transport.Output.Write(Encoding.UTF8.GetBytes(messageToken.ToString()));
                connectionContext.Transport.Output.Write(new[] { (byte)0x1e });
                await connectionContext.Transport.Output.FlushAsync().DefaultTimeout();
 
                await tcs.Task.DefaultTimeout();
            }
            catch (Exception ex)
            {
                LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
                throw;
            }
            finally
            {
                await connection.DisposeAsync().DefaultTimeout();
            }
        }
    }
 
    [Theory]
    [MemberData(nameof(TransportTypes))]
    [LogLevel(LogLevel.Trace)]
    public async Task ClientWithUnsupportedProtocolVersionDoesNotConnect(HttpTransportType transportType)
    {
        bool ExpectedErrors(WriteContext writeContext)
        {
            return writeContext.LoggerName == typeof(HubConnection).FullName;
        }
 
        await using (var server = await StartServer<VersionStartup>(ExpectedErrors))
        {
            var connectionBuilder = new HubConnectionBuilder()
                .WithLoggerFactory(LoggerFactory)
                .WithUrl(server.Url + "/version", transportType);
            connectionBuilder.Services.AddSingleton<IHubProtocol>(new SingleVersionHubProtocol(new VersionedJsonHubProtocol(int.MaxValue), int.MaxValue));
 
            var connection = connectionBuilder.Build();
 
            try
            {
                await ExceptionAssert.ThrowsAsync<HubException>(
                    () => connection.StartAsync(),
                    "Unable to complete handshake with the server due to an error: The server does not support version 2147483647 of the 'json' protocol.").DefaultTimeout();
            }
            catch (Exception ex)
            {
                LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
                throw;
            }
            finally
            {
                await connection.DisposeAsync().DefaultTimeout();
            }
        }
    }
 
    public class SingleVersionHubProtocol : IHubProtocol
    {
        private readonly IHubProtocol _protocol;
        private readonly int _version;
 
        public SingleVersionHubProtocol(IHubProtocol inner, int version)
        {
            _protocol = inner;
            _version = version;
        }
 
        public string Name => _protocol.Name;
 
        public int Version => _version;
 
        public TransferFormat TransferFormat => _protocol.TransferFormat;
 
        public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message) => _protocol.GetMessageBytes(message);
 
        public bool IsVersionSupported(int version)
        {
            return version == _version;
        }
 
        public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, [NotNullWhen(true)] out HubMessage message)
        {
            return _protocol.TryParseMessage(ref input, binder, out message);
        }
 
        public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
        {
            _protocol.WriteMessage(message, output);
        }
    }
 
    private class ProxyConnectionFactory : IConnectionFactory
    {
        private readonly IConnectionFactory _innerFactory;
        public ValueTask<ConnectionContext> ConnectTask { get; private set; }
 
        public ProxyConnectionFactory(IConnectionFactory innerFactory)
        {
            _innerFactory = innerFactory;
        }
 
        public ValueTask<ConnectionContext> ConnectAsync(EndPoint endPoint, CancellationToken cancellationToken = default)
        {
            ConnectTask = _innerFactory.ConnectAsync(endPoint, cancellationToken);
            return ConnectTask;
        }
    }
 
    public static IEnumerable<object[]> TransportTypes()
    {
        if (TestHelpers.IsWebSocketsSupported())
        {
            yield return new object[] { HttpTransportType.WebSockets };
        }
        yield return new object[] { HttpTransportType.ServerSentEvents };
        yield return new object[] { HttpTransportType.LongPolling };
    }
}