File: MapConnectionHandlerTests.cs
Web Access
Project: src\src\SignalR\common\Http.Connections\test\Microsoft.AspNetCore.Http.Connections.Tests.csproj (Microsoft.AspNetCore.Http.Connections.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Net.WebSockets;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Cors;
using Microsoft.AspNetCore.Cors.Infrastructure;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.Http.Timeouts;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Xunit.Abstractions;
 
namespace Microsoft.AspNetCore.Http.Connections.Tests;
 
public class MapConnectionHandlerTests
{
    private readonly ITestOutputHelper _output;
 
    public MapConnectionHandlerTests(ITestOutputHelper output)
    {
        _output = output;
    }
 
    [Fact]
    public void MapConnectionHandlerFindsMetadataPolicyOnEndPoint()
    {
        var authCount = 0;
        var policy1 = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build();
        var req = new TestRequirement();
        using (var host = BuildWebHost<AuthConnectionHandler>("/auth",
            options => authCount += options.AuthorizationData.Count,
            endpoints => endpoints.RequireAuthorization(policy1).RequireAuthorization(pb => pb.Requirements.Add(req))))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/auth/negotiate", endpoint.DisplayName);
                    Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
                    var policies = endpoint.Metadata.GetOrderedMetadata<AuthorizationPolicy>();
                    Assert.Equal(2, policies.Count);
                    Assert.Equal(policy1, policies[0]);
                    Assert.Single(policies[1].Requirements);
                    Assert.Equal(req, policies[1].Requirements.First());
                },
                endpoint =>
                {
                    Assert.Equal("/auth", endpoint.DisplayName);
                    Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
                    var policies = endpoint.Metadata.GetOrderedMetadata<AuthorizationPolicy>();
                    Assert.Equal(2, policies.Count);
                    Assert.Equal(policy1, policies[0]);
                    Assert.Single(policies[1].Requirements);
                    Assert.Equal(req, policies[1].Requirements.First());
                });
        }
 
        Assert.Equal(0, authCount);
    }
 
    [Fact]
    public void MapConnectionHandlerFindsAuthAttributeOnEndPoint()
    {
        var authCount = 0;
        using (var host = BuildWebHost<AuthConnectionHandler>("/auth",
            options => authCount += options.AuthorizationData.Count))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/auth/negotiate", endpoint.DisplayName);
                    Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
                },
                endpoint =>
                {
                    Assert.Equal("/auth", endpoint.DisplayName);
                    Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
                });
        }
 
        Assert.Equal(0, authCount);
    }
 
    [Fact]
    public void MapConnectionHandlerFindsAuthAttributeOnInheritedEndPoint()
    {
        var authCount = 0;
        using (var host = BuildWebHost<InheritedAuthConnectionHandler>("/auth",
            options => authCount += options.AuthorizationData.Count))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/auth/negotiate", endpoint.DisplayName);
                    Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
                },
                endpoint =>
                {
                    Assert.Equal("/auth", endpoint.DisplayName);
                    Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
                });
        }
 
        Assert.Equal(0, authCount);
    }
 
    [Fact]
 
    public void MapConnectionHandlerFindsAuthAttributesOnDoubleAuthEndPoint()
    {
        var authCount = 0;
        using (var host = BuildWebHost<DoubleAuthConnectionHandler>("/auth",
            options => authCount += options.AuthorizationData.Count))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/auth/negotiate", endpoint.DisplayName);
                    Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
                },
                endpoint =>
                {
                    Assert.Equal("/auth", endpoint.DisplayName);
                    Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
                });
        }
 
        Assert.Equal(0, authCount);
    }
 
    [Fact]
    public void MapConnectionHandlerFindsAttributesFromEndPointAndOptions()
    {
        var authCount = 0;
        using (var host = BuildWebHost<AuthConnectionHandler>("/auth",
            options =>
            {
                authCount += options.AuthorizationData.Count;
                options.AuthorizationData.Add(new AuthorizeAttribute());
            }))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/auth/negotiate", endpoint.DisplayName);
                    Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
                },
                endpoint =>
                {
                    Assert.Equal("/auth", endpoint.DisplayName);
                    Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
                });
        }
 
        Assert.Equal(0, authCount);
    }
 
    [Fact]
    public void MapConnectionHandlerEndPointRoutingFindsAttributesOnHub()
    {
        var authCount = 0;
        using (var host = BuildWebHost<AuthConnectionHandler>("/path", options =>
        {
            authCount += options.AuthorizationData.Count;
        }))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/path/negotiate", endpoint.DisplayName);
                    Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
                },
                endpoint =>
                {
                    Assert.Equal("/path", endpoint.DisplayName);
                    Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
                });
        }
 
        Assert.Equal(0, authCount);
    }
 
    [Fact]
    public void MapConnectionHandlerEndPointRoutingFindsAttributesFromOptions()
    {
        var authCount = 0;
        using (var host = BuildWebHost<AuthConnectionHandler>("/path", options =>
        {
            authCount += options.AuthorizationData.Count;
            options.AuthorizationData.Add(new AuthorizeAttribute());
        }))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/path/negotiate", endpoint.DisplayName);
                    Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
                },
                endpoint =>
                {
                    Assert.Equal("/path", endpoint.DisplayName);
                    Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
                });
        }
 
        Assert.Equal(0, authCount);
    }
 
    [Fact]
    public void MapConnectionHandlerEndPointRoutingAppliesAttributesBeforeConventions()
    {
        void ConfigureRoutes(IEndpointRouteBuilder endpoints)
        {
            // This "Foo" policy should override the default auth attribute
            endpoints.MapConnectionHandler<AuthConnectionHandler>("/path")
                  .RequireAuthorization(new AuthorizeAttribute("Foo"));
        }
 
        using (var host = BuildWebHost(ConfigureRoutes))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/path/negotiate", endpoint.DisplayName);
                    Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
                        auth => { },
                        auth =>
                        {
                            Assert.Equal("Foo", auth?.Policy);
                        });
                },
                endpoint =>
                {
                    Assert.Equal("/path", endpoint.DisplayName);
                    Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
                        auth => { },
                        auth =>
                        {
                            Assert.Equal("Foo", auth?.Policy);
                        });
                });
        }
    }
 
    [Fact]
    public void MapConnectionHandlerEndPointRoutingAppliesNegotiateMetadata()
    {
        void ConfigureRoutes(IEndpointRouteBuilder endpoints)
        {
            endpoints.MapConnectionHandler<AuthConnectionHandler>("/path");
        }
 
        using (var host = BuildWebHost(ConfigureRoutes))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/path/negotiate", endpoint.DisplayName);
                    var metaData = endpoint.Metadata.GetMetadata<NegotiateMetadata>();
                    Assert.NotNull(metaData);
                    var optionsMetaData = endpoint.Metadata.GetMetadata<HttpConnectionDispatcherOptions>();
                    Assert.NotNull(optionsMetaData);
                },
                endpoint =>
                {
                    Assert.Equal("/path", endpoint.DisplayName);
                    Assert.Null(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
                });
        }
    }
 
    [Fact]
    public void MapConnectionHandlerNegotiateMetadataContainsOptions()
    {
        void ConfigureRoutes(IEndpointRouteBuilder endpoints)
        {
            endpoints.MapConnectionHandler<AuthConnectionHandler>("/path", options =>
            {
                options.Transports = HttpTransportType.ServerSentEvents;
                options.ApplicationMaxBufferSize = 2;
                options.CloseOnAuthenticationExpiration = true;
            });
        }
 
        using (var host = BuildWebHost(ConfigureRoutes))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/path/negotiate", endpoint.DisplayName);
                    var metaData = endpoint.Metadata.GetMetadata<NegotiateMetadata>();
                    Assert.NotNull(metaData);
                    var optionsMetaData = endpoint.Metadata.GetMetadata<HttpConnectionDispatcherOptions>();
                    Assert.NotNull(optionsMetaData);
                    Assert.Equal(HttpTransportType.ServerSentEvents, optionsMetaData.Transports);
                    Assert.Equal(2, optionsMetaData.ApplicationMaxBufferSize);
                    Assert.True(optionsMetaData.CloseOnAuthenticationExpiration);
                },
                endpoint =>
                {
                    Assert.Equal("/path", endpoint.DisplayName);
                    Assert.Null(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
                });
        }
    }
 
    [Fact]
    public void MapConnectionHandlerEndPointRoutingAppliesCorsMetadata()
    {
        void ConfigureRoutes(IEndpointRouteBuilder endpoints)
        {
            endpoints.MapConnectionHandler<CorsConnectionHandler>("/path");
        }
 
        using (var host = BuildWebHost(ConfigureRoutes))
        {
            host.Start();
 
            var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
            // We register 2 endpoints (/negotiate and /)
            Assert.Collection(dataSource.Endpoints,
                endpoint =>
                {
                    Assert.Equal("/path/negotiate", endpoint.DisplayName);
                    Assert.NotNull(endpoint.Metadata.GetMetadata<IEnableCorsAttribute>());
                },
                endpoint =>
                {
                    Assert.Equal("/path", endpoint.DisplayName);
                    Assert.NotNull(endpoint.Metadata.GetMetadata<IEnableCorsAttribute>());
                });
        }
    }
 
    [ConditionalFact]
    [WebSocketsSupportedCondition]
    public async Task MapConnectionHandlerWithWebSocketSubProtocolSetsProtocol()
    {
        using var host = BuildWebHost<MyConnectionHandler>("/socket",
            options => options.WebSockets.SubProtocolSelector = subprotocols =>
            {
                Assert.Equal(new[] { "protocol1", "protocol2" }, subprotocols.ToArray());
                return "protocol1";
            });
 
        await host.StartAsync();
 
        var feature = host.Services.GetService<IServer>().Features.Get<IServerAddressesFeature>();
        var address = feature.Addresses.First().Replace("http", "ws") + "/socket";
 
        var client = new ClientWebSocket();
        client.Options.AddSubProtocol("protocol1");
        client.Options.AddSubProtocol("protocol2");
        await client.ConnectAsync(new Uri(address), CancellationToken.None);
        Assert.Equal("protocol1", client.SubProtocol);
        await client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None).DefaultTimeout();
        var result = await client.ReceiveAsync(new ArraySegment<byte>(new byte[1024]), CancellationToken.None).DefaultTimeout();
        Assert.Equal(WebSocketMessageType.Close, result.MessageType);
    }
 
    [Fact]
    public void MapConnectionHandlerAddsDisableRequestTimeoutMetadata()
    {
        using var host = BuildWebHost<MyConnectionHandler>("/test", o => { });
        host.Start();
 
        var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
        // We register 2 endpoints (/negotiate and /)
        Assert.Collection(dataSource.Endpoints,
            endpoint =>
            {
                Assert.Equal("/test/negotiate", endpoint.DisplayName);
                Assert.Empty(endpoint.Metadata.GetOrderedMetadata<DisableRequestTimeoutAttribute>());
            },
            endpoint =>
            {
                Assert.Equal("/test", endpoint.DisplayName);
                Assert.Single(endpoint.Metadata.GetOrderedMetadata<DisableRequestTimeoutAttribute>());
            });
    }
 
    private class MyConnectionHandler : ConnectionHandler
    {
        public override async Task OnConnectedAsync(ConnectionContext connection)
        {
            while (true)
            {
                var result = await connection.Transport.Input.ReadAsync();
 
                if (result.IsCompleted)
                {
                    break;
                }
 
                // Consume nothing
                connection.Transport.Input.AdvanceTo(result.Buffer.Start);
            }
        }
    }
 
    [EnableCors]
    private class CorsConnectionHandler : ConnectionHandler
    {
        public override Task OnConnectedAsync(ConnectionContext connection)
        {
            throw new NotImplementedException();
        }
    }
 
    private class InheritedAuthConnectionHandler : AuthConnectionHandler
    {
        public override Task OnConnectedAsync(ConnectionContext connection)
        {
            throw new NotImplementedException();
        }
    }
 
    [Authorize]
    private class DoubleAuthConnectionHandler : AuthConnectionHandler
    {
    }
 
    [Authorize]
    private class AuthConnectionHandler : ConnectionHandler
    {
        public override Task OnConnectedAsync(ConnectionContext connection)
        {
            throw new NotImplementedException();
        }
    }
 
    private class TestRequirement : IAuthorizationRequirement
    {
    }
 
    private IHost BuildWebHost(Action<IEndpointRouteBuilder> configure)
    {
        return new HostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                .UseKestrel()
                .ConfigureServices(services =>
                {
                    services.AddConnections();
                })
                .Configure(app =>
                {
                    app.UseRouting();
                    app.UseEndpoints(endpoints => configure(endpoints));
                })
                .UseUrls("http://127.0.0.1:0");
            })
            .Build();
    }
 
    private IHost BuildWebHost<TConnectionHandler>(string path, Action<HttpConnectionDispatcherOptions> configureOptions, Action<IEndpointConventionBuilder> configureEndpoints = null) where TConnectionHandler : ConnectionHandler
    {
        return new HostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                .UseUrls("http://127.0.0.1:0")
                .UseKestrel()
                .ConfigureServices(services =>
                {
                    services.AddConnections();
                })
                .Configure(app =>
                {
                    app.UseRouting();
                    app.UseEndpoints(routes =>
                    {
                        var builder = routes.MapConnectionHandler<TConnectionHandler>(path, configureOptions);
                        if (configureEndpoints != null)
                        {
                            configureEndpoints(builder);
                        }
                    });
                })
                .ConfigureLogging(factory =>
                {
                    factory.AddXunit(_output, LogLevel.Trace);
                });
            })
            .Build();
    }
}