File: NativeAotTests.cs
Web Access
Project: src\src\SignalR\server\SignalR\test\Microsoft.AspNetCore.SignalR.Tests\Microsoft.AspNetCore.SignalR.Tests.csproj (Microsoft.AspNetCore.SignalR.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Globalization;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Channels;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.DotNet.RemoteExecutor;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
 
namespace Microsoft.AspNetCore.SignalR.Tests;
 
public partial class NativeAotTests : FunctionalTestBase
{
    /// <summary>
    /// Starts an in-process server hosting <see cref="AsyncMethodHub"/> inside a remote process configured by
    /// <see cref="RunNativeAotTest"/>, connects a client, and verifies that invocations and streams in both
    /// directions produce the expected results.
    /// </summary>
    [ConditionalFact]
    [RemoteExecutionSupported]
    public void CanCallAsyncMethods()
    {
        RunNativeAotTest(static async () =>
        {
            var loggerFactory = new StringLoggerFactory();
            await using (var server = await InProcessTestServer<Startup<AsyncMethodHub>>.StartServer(loggerFactory))
            {
                var hubConnectionBuilder = new HubConnectionBuilder()
                    .WithUrl(server.Url + "/hub");
                AppJsonSerializerContext.AddToJsonHubProtocol(hubConnectionBuilder.Services);

                // Dispose the connection when this block exits (before the server is torn down) so the
                // client side is shut down explicitly instead of being left open.
                await using var connection = hubConnectionBuilder.Build();

                await connection.StartAsync().DefaultTimeout();

                // Non-generic invocations: the hub logs a marker that is captured by the string logger.
                await connection.InvokeAsync(nameof(AsyncMethodHub.TaskMethod)).DefaultTimeout();
                Assert.Contains("TaskMethod called", loggerFactory.ToString());

                await connection.InvokeAsync(nameof(AsyncMethodHub.ValueTaskMethod)).DefaultTimeout();
                Assert.Contains("ValueTaskMethod called", loggerFactory.ToString());

                await connection.InvokeAsync(nameof(AsyncMethodHub.CustomTaskMethod)).DefaultTimeout();
                Assert.Contains("CustomTaskMethod called", loggerFactory.ToString());

                // Invocations that return a value.
                var result = await connection.InvokeAsync<int>(nameof(AsyncMethodHub.TaskValueMethod)).DefaultTimeout();
                Assert.Equal(42, result);

                result = await connection.InvokeAsync<int>(nameof(AsyncMethodHub.ValueTaskValueMethod)).DefaultTimeout();
                Assert.Equal(43, result);

                result = await connection.InvokeAsync<int>(nameof(AsyncMethodHub.CustomTaskValueMethod)).DefaultTimeout();
                Assert.Equal(44, result);

                // Server-to-client streaming of reference types.
                var counterResults = new List<string>();
                await foreach (var item in connection.StreamAsync<string>(nameof(AsyncMethodHub.CounterAsyncEnumerable), 4))
                {
                    counterResults.Add(item);
                }
                Assert.Equal(["0", "1", "2", "3"], counterResults);

                counterResults.Clear();
                await foreach (var item in connection.StreamAsync<string>(nameof(AsyncMethodHub.CounterAsyncEnumerableImpl), 5))
                {
                    counterResults.Add(item);
                }
                Assert.Equal(["0", "1", "2", "3", "4"], counterResults);

                // Bidirectional streaming: the client streams items up and the hub echoes them back.
                var echoResults = new List<string>();
                var asyncEnumerable = connection.StreamAsync<string>(nameof(AsyncMethodHub.StreamEchoAsyncEnumerable), StreamMessages());
                await foreach (var item in asyncEnumerable)
                {
                    echoResults.Add(item);
                }
                Assert.Equal(["echo:message one", "echo:message two"], echoResults);

                echoResults.Clear();
                var channel = Channel.CreateBounded<string>(10);
                var echoResponseReader = await connection.StreamAsChannelAsync<string>(nameof(AsyncMethodHub.StreamEcho), channel.Reader);
                await channel.Writer.WriteAsync("some data");
                await channel.Writer.WriteAsync("some more data");
                await channel.Writer.WriteAsync("even more data");
                channel.Writer.Complete();

                await foreach (var item in echoResponseReader.ReadAllAsync())
                {
                    echoResults.Add(item);
                }
                Assert.Equal(["echo:some data", "echo:some more data", "echo:even more data"], echoResults);

                // Server-to-client streaming of value types.
                var streamValueTypeResults = new List<int>();
                await foreach (var item in connection.StreamAsync<int>(nameof(AsyncMethodHub.ReturnEnumerableValueType)))
                {
                    streamValueTypeResults.Add(item);
                }
                Assert.Equal([1, 2], streamValueTypeResults);

                var returnChannelValueTypeResults = new List<char>();
                var returnChannelValueTypeReader = await connection.StreamAsChannelAsync<char>(nameof(AsyncMethodHub.ReturnChannelValueType), "Hello");
                await foreach (var item in returnChannelValueTypeReader.ReadAllAsync())
                {
                    returnChannelValueTypeResults.Add(item);
                }
                Assert.Equal(['H', 'e', 'l', 'l', 'o'], returnChannelValueTypeResults);

                // Even though SignalR server doesn't support Hub methods with streaming value types in native AOT (https://github.com/dotnet/aspnetcore/issues/56179),
                // still test that the client can send them.
                var stringResult = await connection.InvokeAsync<string>(nameof(AsyncMethodHub.EnumerableIntParameter), StreamInts());
                Assert.Equal("1, 2, 3", stringResult);

                var channelShorts = Channel.CreateBounded<short>(10);
                await channelShorts.Writer.WriteAsync(9);
                await channelShorts.Writer.WriteAsync(8);
                await channelShorts.Writer.WriteAsync(7);
                channelShorts.Writer.Complete();

                stringResult = await connection.InvokeAsync<string>(nameof(AsyncMethodHub.ChannelShortParameter), channelShorts.Reader);
                Assert.Equal("9, 8, 7", stringResult);
            }
        });
    }
 
    /// <summary>
    /// Client-side stream of two fixed strings, yielding to the scheduler before each item so the
    /// sequence is produced asynchronously.
    /// </summary>
    private static async IAsyncEnumerable<string> StreamMessages()
    {
        var messages = new[] { "message one", "message two" };

        foreach (var message in messages)
        {
            await Task.Yield();
            yield return message;
        }
    }
 
    /// <summary>
    /// Client-side stream of the integers 1, 2 and 3, yielding to the scheduler before each item.
    /// </summary>
    private static async IAsyncEnumerable<int> StreamInts()
    {
        for (var value = 1; value <= 3; value++)
        {
            await Task.Yield();
            yield return value;
        }
    }
 
    /// <summary>
    /// Verifies that starting a server whose hub takes a streaming parameter of a value type fails with an
    /// <see cref="InvalidOperationException"/> naming the offending method and parameter type. Each hub is
    /// checked in its own remote process.
    /// </summary>
    [ConditionalFact]
    [RemoteExecutionSupported]
    public void UsingValueTypesInStreamingThrows()
    {
        // ChannelReader<double> parameter.
        RunNativeAotTest(static async () =>
        {
            const string expectedMessage = "Method 'Microsoft.AspNetCore.SignalR.Tests.NativeAotTests+ChannelValueTypeMethodHub.StreamValueType' is not supported with native AOT because it has a parameter of type 'System.Threading.Channels.ChannelReader`1[System.Double]'.";

            var exception = await Assert.ThrowsAsync<InvalidOperationException>(
                () => InProcessTestServer<Startup<ChannelValueTypeMethodHub>>.StartServer(NullLoggerFactory.Instance));

            Assert.Contains(expectedMessage, exception.Message);
        });

        // IAsyncEnumerable<float> parameter.
        RunNativeAotTest(static async () =>
        {
            const string expectedMessage = "Method 'Microsoft.AspNetCore.SignalR.Tests.NativeAotTests+EnumerableValueTypeMethodHub.StreamValueType' is not supported with native AOT because it has a parameter of type 'System.Collections.Generic.IAsyncEnumerable`1[System.Single]'.";

            var exception = await Assert.ThrowsAsync<InvalidOperationException>(
                () => InProcessTestServer<Startup<EnumerableValueTypeMethodHub>>.StartServer(NullLoggerFactory.Instance));

            Assert.Contains(expectedMessage, exception.Message);
        });
    }
 
    /// <summary>
    /// Runs <paramref name="test"/> through <see cref="RemoteExecutor"/> with the runtime configuration
    /// switches below set to "false".
    /// </summary>
    /// <param name="test">The test body to execute remotely.</param>
    private static void RunNativeAotTest(Func<Task> test)
    {
        // Every switch listed here is added with the value "false", in this order.
        var disabledSwitches = new[]
        {
            "System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported",
            "Microsoft.AspNetCore.SignalR.Hub.IsCustomAwaitableSupported",
            "System.Text.Json.JsonSerializer.IsReflectionEnabledByDefault",
        };

        var options = new RemoteInvokeOptions();
        foreach (var switchName in disabledSwitches)
        {
            options.RuntimeConfigurationOptions.Add(switchName, "false");
        }

        // The handle is disposed when this method returns.
        // NOTE(review): presumably disposal waits for the remote process and surfaces its failures - confirm against RemoteExecutor.
        using var remoteHandle = RemoteExecutor.Invoke(test, options);
    }
 
    /// <summary>
    /// Startup class that hosts a single hub of type <typeparamref name="THub"/> at "/hub" with detailed
    /// errors enabled and the source-generated JSON context registered on the JSON hub protocol.
    /// </summary>
    /// <typeparam name="THub">The hub type to map.</typeparam>
    public class Startup<THub> where THub : Hub
    {
        /// <summary>Registers connections, SignalR and the JSON serializer context.</summary>
        public void ConfigureServices(IServiceCollection services)
        {
            services.AddConnections();
            services.AddSignalR(hubOptions => hubOptions.EnableDetailedErrors = true);
            AppJsonSerializerContext.AddToJsonHubProtocol(services);
        }

        /// <summary>Adds routing and maps <typeparamref name="THub"/> to "/hub".</summary>
        public void Configure(IApplicationBuilder app)
        {
            app.UseRouting();

            app.UseEndpoints(routes => routes.MapHub<THub>("/hub"));
        }
    }
 
    /// <summary>
    /// Hub used by the CanCallAsyncMethods test. It exposes methods with a variety of asynchronous return
    /// types (Task, ValueTask, Task-derived types, IAsyncEnumerable, ChannelReader) and streaming parameters.
    /// </summary>
    public class AsyncMethodHub : TestHub
    {
        /// <summary>Logs "TaskMethod called" after yielding.</summary>
        public async Task TaskMethod(ILogger<AsyncMethodHub> logger)
        {
            await Task.Yield(); // ensure the returned Task gets awaited correctly
            logger.LogInformation("TaskMethod called");
        }

        /// <summary>Logs "ValueTaskMethod called" after yielding.</summary>
        public async ValueTask ValueTaskMethod(ILogger<AsyncMethodHub> logger)
        {
            await Task.Yield(); // ensure the returned Task gets awaited correctly
            logger.LogInformation("ValueTaskMethod called");
        }

        /// <summary>Starts and returns a <see cref="TaskDerivedType"/>, logging "CustomTaskMethod called".</summary>
        public TaskDerivedType CustomTaskMethod(ILogger<AsyncMethodHub> logger)
        {
            var task = new TaskDerivedType();
            task.Start();
            logger.LogInformation("CustomTaskMethod called");
            return task;
        }

        /// <summary>Returns 42 after yielding.</summary>
        public async Task<int> TaskValueMethod()
        {
            await Task.Yield(); // ensure the returned Task gets awaited correctly
            return 42;
        }

        /// <summary>Returns 43 after yielding.</summary>
        public async ValueTask<int> ValueTaskValueMethod()
        {
            await Task.Yield(); // ensure the returned Task gets awaited correctly
            return 43;
        }

        /// <summary>Starts and returns a <see cref="TaskOfTDerivedType{T}"/> whose result is 44.</summary>
        public TaskOfTDerivedType<int> CustomTaskValueMethod()
        {
            var task = new TaskOfTDerivedType<int>(44);
            task.Start();
            return task;
        }

        /// <summary>Streams the invariant-culture strings "0" through <paramref name="count"/> - 1.</summary>
        public async IAsyncEnumerable<string> CounterAsyncEnumerable(int count)
        {
            for (int i = 0; i < count; i++)
            {
                await Task.Yield();
                yield return i.ToString(CultureInfo.InvariantCulture);
            }
        }

        /// <summary>
        /// Same sequence as <see cref="CounterAsyncEnumerable"/>, wrapped in
        /// <see cref="StreamingHub.AsyncEnumerableImpl{T}"/> so the declared return type is not the interface itself.
        /// </summary>
        public StreamingHub.AsyncEnumerableImpl<string> CounterAsyncEnumerableImpl(int count)
        {
            return new StreamingHub.AsyncEnumerableImpl<string>(CounterAsyncEnumerable(count));
        }

        /// <summary>
        /// Returns a channel that receives "echo:" + item for every item read from <paramref name="source"/>.
        /// </summary>
        public ChannelReader<string> StreamEcho(ChannelReader<string> source)
        {
            Channel<string> output = Channel.CreateUnbounded<string>();

            _ = Task.Run(async () =>
            {
                try
                {
                    await foreach (var item in source.ReadAllAsync())
                    {
                        await output.Writer.WriteAsync("echo:" + item);
                    }

                    output.Writer.TryComplete();
                }
                catch (Exception ex)
                {
                    // Complete the output with the failure so the reader observes it instead of
                    // waiting on a channel that would otherwise never be completed.
                    output.Writer.TryComplete(ex);
                }
            });

            return output.Reader;
        }

        /// <summary>Streams "echo:" + item for every item read from <paramref name="source"/>.</summary>
        public async IAsyncEnumerable<string> StreamEchoAsyncEnumerable(IAsyncEnumerable<string> source)
        {
            await foreach (var item in source)
            {
                yield return "echo:" + item;
            }
        }

        /// <summary>Streams the value-type items 1 and 2.</summary>
        public async IAsyncEnumerable<int> ReturnEnumerableValueType()
        {
            await Task.Yield();
            yield return 1;
            await Task.Yield();
            yield return 2;
        }

        /// <summary>Returns a channel that receives each character of <paramref name="source"/> in order.</summary>
        public ChannelReader<char> ReturnChannelValueType(string source)
        {
            Channel<char> output = Channel.CreateUnbounded<char>();

            _ = Task.Run(async () =>
            {
                try
                {
                    foreach (var item in source)
                    {
                        await Task.Yield();
                        await output.Writer.WriteAsync(item);
                    }

                    output.Writer.TryComplete();
                }
                catch (Exception ex)
                {
                    // Complete the output with the failure so the reader observes it instead of
                    // waiting on a channel that would otherwise never be completed.
                    output.Writer.TryComplete(ex);
                }
            });

            return output.Reader;
        }

        /// <summary>Joins the streamed values, read as 32-bit integers, with ", ".</summary>
        // using 'object' as the streaming parameter type because streaming ValueTypes is not supported on the server
        public async Task<string> EnumerableIntParameter(IAsyncEnumerable<object> source)
        {
            var result = new StringBuilder();
            // These get deserialized as JsonElement since the streaming parameter is 'object'
            await foreach (JsonElement item in source)
            {
                // Every item appends at least one character, so a non-empty builder means this is not the first item.
                if (result.Length > 0)
                {
                    result.Append(", ");
                }

                result.Append(item.GetInt32());
            }
            return result.ToString();
        }

        /// <summary>Joins the streamed values, read as 16-bit integers, with ", ".</summary>
        // using 'object' as the streaming parameter type because streaming ValueTypes is not supported on the server
        public async Task<string> ChannelShortParameter(ChannelReader<object> source)
        {
            var result = new StringBuilder();
            // These get deserialized as JsonElement since the streaming parameter is 'object'
            await foreach (JsonElement item in source.ReadAllAsync())
            {
                // Every item appends at least one character, so a non-empty builder means this is not the first item.
                if (result.Length > 0)
                {
                    result.Append(", ");
                }

                result.Append(item.GetInt16());
            }
            return result.ToString();
        }
    }
 
    /// <summary>
    /// Hub whose only method takes a <see cref="ChannelReader{T}"/> of a value type. The
    /// UsingValueTypesInStreamingThrows test expects starting a server with this hub to throw.
    /// </summary>
    public class ChannelValueTypeMethodHub : TestHub
    {
        /// <summary>Logs every value read from <paramref name="source"/>.</summary>
        public async Task StreamValueType(ILogger<ChannelValueTypeMethodHub> logger, ChannelReader<double> source)
        {
            // Drain the reader: wait until data is available, then read everything that is buffered.
            while (await source.WaitToReadAsync())
            {
                while (source.TryRead(out var value))
                {
                    logger.LogInformation("Received: {item}", value);
                }
            }
        }
    }
 
    /// <summary>
    /// Hub whose only method takes an <see cref="IAsyncEnumerable{T}"/> of a value type. The
    /// UsingValueTypesInStreamingThrows test expects starting a server with this hub to throw.
    /// </summary>
    public class EnumerableValueTypeMethodHub : TestHub
    {
        /// <summary>Logs every value produced by <paramref name="source"/>.</summary>
        public async Task StreamValueType(ILogger<EnumerableValueTypeMethodHub> logger, IAsyncEnumerable<float> source)
        {
            await using var values = source.GetAsyncEnumerator();

            while (await values.MoveNextAsync())
            {
                logger.LogInformation("Received: {item}", values.Current);
            }
        }
    }
 
    /// <summary>
    /// A <see cref="Task"/> subclass wrapping an empty action. It is created unstarted; callers start it themselves.
    /// </summary>
    public class TaskDerivedType : Task
    {
        /// <summary>Creates the task around a no-op action.</summary>
        public TaskDerivedType()
            : base(static () => { })
        {
        }
    }
 
    /// <summary>
    /// A <see cref="Task{TResult}"/> subclass whose result is the value passed to the constructor.
    /// It is created unstarted; callers start it themselves.
    /// </summary>
    /// <typeparam name="T">The result type.</typeparam>
    public class TaskOfTDerivedType<T> : Task<T>
    {
        /// <summary>Creates the task so that it produces <paramref name="input"/> when run.</summary>
        public TaskOfTDerivedType(T input)
            : base(Constant(input))
        {
        }

        // Builds the delegate that hands back the captured value.
        private static Func<T> Constant(T value) => () => value;
    }
 
    /// <summary>
    /// Logger factory that accumulates every log entry written through its loggers into a single in-memory
    /// string, retrievable via <see cref="ToString"/>.
    /// </summary>
    /// <remarks>
    /// Loggers created here can be called from whichever threads the in-process server uses while the test
    /// reads the accumulated text, and <see cref="StringBuilder"/> instance members are not thread-safe, so
    /// all access to the buffer is serialized with a lock.
    /// </remarks>
    public class StringLoggerFactory : ILoggerFactory
    {
        // Guards every read and write of _log.
        private readonly object _gate = new object();
        private readonly StringBuilder _log = new StringBuilder();

        /// <summary>No-op; providers are not used by this factory.</summary>
        public void AddProvider(ILoggerProvider provider) { }

        /// <summary>Creates a logger that appends to this factory's buffer.</summary>
        public ILogger CreateLogger(string name)
        {
            return new StringLogger(name, this);
        }

        /// <summary>No-op; there is nothing to release.</summary>
        public void Dispose() { }

        /// <summary>Returns everything logged so far.</summary>
        public override string ToString()
        {
            lock (_gate)
            {
                return _log.ToString();
            }
        }

        // Appends one formatted entry followed by a line terminator.
        private void Append(string message)
        {
            lock (_gate)
            {
                _log.AppendLine(message);
            }
        }

        private sealed class StringLogger : ILogger
        {
            private readonly StringLoggerFactory _factory;
            private readonly string _name;

            public StringLogger(string name, StringLoggerFactory factory)
            {
                _name = name;
                _factory = factory;
            }

            public IDisposable BeginScope<TState>(TState state)
            {
                return new DummyDisposable();
            }

            public bool IsEnabled(LogLevel logLevel) => true;

            public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func<TState, Exception, string> formatter)
            {
                // Format outside the lock; only the append touches shared state.
                string message = string.Format(CultureInfo.InvariantCulture,
                    "Provider: {0}" + Environment.NewLine +
                    "Log level: {1}" + Environment.NewLine +
                    "Event id: {2}" + Environment.NewLine +
                    "Exception: {3}" + Environment.NewLine +
                    "Message: {4}", _name, logLevel, eventId, exception?.ToString(), formatter(state, exception));
                _factory.Append(message);
            }

            private sealed class DummyDisposable : IDisposable
            {
                public void Dispose()
                {
                    // no-op
                }
            }
        }
    }
 
    /// <summary>
    /// Source-generated JSON serializer context covering the payload types declared in the attributes below.
    /// </summary>
    [JsonSerializable(typeof(object))]
    [JsonSerializable(typeof(string))]
    [JsonSerializable(typeof(int))]
    [JsonSerializable(typeof(short))]
    [JsonSerializable(typeof(char))]
    internal partial class AppJsonSerializerContext : JsonSerializerContext
    {
        /// <summary>
        /// Configures <see cref="JsonHubProtocolOptions"/> on <paramref name="services"/> so that this context
        /// is inserted at the front of the payload serializer's type-info resolver chain.
        /// </summary>
        public static void AddToJsonHubProtocol(IServiceCollection services)
            => services.Configure<JsonHubProtocolOptions>(protocolOptions =>
                protocolOptions.PayloadSerializerOptions.TypeInfoResolverChain.Insert(0, Default));
    }
}