File: WebSocketsTests.cs
Web Access
Project: src\src\SignalR\common\Http.Connections\test\Microsoft.AspNetCore.Http.Connections.Tests.csproj (Microsoft.AspNetCore.Http.Connections.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Buffers;
using System.IO.Pipelines;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Microsoft.AspNetCore.Http.Connections.Internal.Transports;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Net.Http.Headers;
using Xunit;
 
namespace Microsoft.AspNetCore.Http.Connections.Tests;
 
public class WebSocketsTests : VerifiableLoggedTest
{
    // Verifies that a frame sent by the client is surfaced to the application through
    // connection.Transport.Input, for both the Text and Binary message types.
    // Using nameof with WebSocketMessageType because it is a GACed type and xunit can't serialize it
    [Theory]
    [InlineData(nameof(WebSocketMessageType.Text))]
    [InlineData(nameof(WebSocketMessageType.Binary))]
    public async Task ReceivedFramesAreWrittenToChannel(string webSocketMessageType)
    {
        using (StartVerifiableLog())
        {
            var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
            var connection = CreateHttpConnectionContext(pair, loggerName: "HttpConnectionContext1");

            using (var feature = new TestWebSocketConnectionFeature())
            {
                var ws = new WebSocketsServerTransport(new WebSocketOptions(), connection.Application, connection, LoggerFactory);

                // Give the server socket to the transport and run it
                var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());

                // Run the client socket
                var client = feature.Client.ExecuteAndCaptureFramesAsync();

                // The theory passes the enum member name as a string; turn it back into the enum value
                var messageType = Enum.Parse<WebSocketMessageType>(webSocketMessageType);

                // Send a frame, then close. Every await below is bounded by DefaultTimeout so that a
                // transport regression fails the test instead of hanging the whole run.
                await feature.Client.SendAsync(
                    buffer: new ArraySegment<byte>(Encoding.UTF8.GetBytes("Hello")),
                    messageType: messageType,
                    endOfMessage: true,
                    cancellationToken: CancellationToken.None).DefaultTimeout();
                await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None).DefaultTimeout();

                // ReadAsync returns a ValueTask, so convert it to a Task before applying the timeout
                var result = await connection.Transport.Input.ReadAsync().AsTask().DefaultTimeout();
                var buffer = result.Buffer;
                Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.ToArray()));
                connection.Transport.Input.AdvanceTo(buffer.End);

                connection.Transport.Output.Complete();

                // The transport should finish now
                await transport.DefaultTimeout();

                // The connection should close after this, which means the client will get a close frame.
                var clientSummary = await client.DefaultTimeout();

                Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.CloseStatus);
            }
        }
    }
 
    // Verifies that data written by the application reaches the client as a single, complete frame
    // whose message type matches the connection's ActiveFormat.
    // Using nameof with WebSocketMessageType because it is a GACed type and xunit can't serialize it
    [Theory]
    [InlineData(TransferFormat.Text, nameof(WebSocketMessageType.Text))]
    [InlineData(TransferFormat.Binary, nameof(WebSocketMessageType.Binary))]
    public async Task WebSocketTransportSetsMessageTypeBasedOnTransferFormatFeature(TransferFormat transferFormat, string expectedMessageType)
    {
        using (StartVerifiableLog())
        {
            var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
            var connection = CreateHttpConnectionContext(pair, loggerName: "HttpConnectionContext1");

            using (var feature = new TestWebSocketConnectionFeature())
            {
                connection.ActiveFormat = transferFormat;
                var ws = new WebSocketsServerTransport(new WebSocketOptions(), connection.Application, connection, LoggerFactory);

                // Give the server socket to the transport and run it
                var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());

                // Run the client socket
                var client = feature.Client.ExecuteAndCaptureFramesAsync();

                // Write to the output channel, and then complete it.
                // WriteAsync returns a ValueTask, so convert it to a Task before applying the timeout.
                await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello")).AsTask().DefaultTimeout();
                connection.Transport.Output.Complete();

                // The client should finish now, as should the server. The awaits are bounded so a
                // regression fails the test instead of hanging it.
                var clientSummary = await client.DefaultTimeout();
                await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None).DefaultTimeout();
                await transport.DefaultTimeout();

                // Exactly one frame is expected; Assert.Single checks the count and hands back the frame
                var frame = Assert.Single(clientSummary.Received);
                Assert.True(frame.EndOfMessage);
                Assert.Equal(Enum.Parse<WebSocketMessageType>(expectedMessageType), frame.MessageType);
                Assert.Equal("Hello", Encoding.UTF8.GetString(frame.Buffer));
            }
        }
    }
 
    // Builds an HttpConnectionContext with the fixed id "foo" over the supplied pipe pair.
    // The logger category is loggerName when given, otherwise nameof(HttpConnectionContext).
    private HttpConnectionContext CreateHttpConnectionContext(DuplexPipe.DuplexPipePair pair, string loggerName = null)
    {
        var categoryName = loggerName ?? nameof(HttpConnectionContext);
        var logger = LoggerFactory.CreateLogger(categoryName);

        return new HttpConnectionContext(
            "foo",
            connectionToken: null,
            logger,
            metricsContext: default,
            pair.Transport,
            pair.Application,
            new(),
            useStatefulReconnect: false);
    }
 
    // Aborts the client side of the socket and checks that both the transport and the client loop
    // finish, and that any error surfaced to the application pipe is a WebSocketException.
    [Fact]
    public async Task TransportCommunicatesErrorToApplicationWhenClientDisconnectsAbnormally()
    {
        using (StartVerifiableLog())
        {
            var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
            var connection = CreateHttpConnectionContext(pair, loggerName: "HttpConnectionContext1");

            using (var feature = new TestWebSocketConnectionFeature())
            {
                async Task CompleteApplicationAfterTransportCompletes()
                {
                    try
                    {
                        // Wait until the transport completes so that we can end the application
                        var result = await connection.Transport.Input.ReadAsync();
                        connection.Transport.Input.AdvanceTo(result.Buffer.End);
                    }
                    catch (Exception ex)
                    {
                        // WebSocketError is an enum and can never be the runtime type of an exception;
                        // the exception type that carries a WebSocketError code is WebSocketException.
                        Assert.IsType<WebSocketException>(ex);
                    }
                    finally
                    {
                        // Complete the application so that the connection unwinds without aborting
                        connection.Transport.Output.Complete();
                    }
                }

                var ws = new WebSocketsServerTransport(new WebSocketOptions(), connection.Application, connection, LoggerFactory);

                // Give the server socket to the transport and run it
                var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());

                // Run the client socket
                var client = feature.Client.ExecuteAndCaptureFramesAsync();

                // When the close frame is received, we complete the application so the send
                // loop unwinds. Keep the task so its outcome can be observed below.
                var application = CompleteApplicationAfterTransportCompletes();

                // Terminate the client to server channel with an exception
                feature.Client.SendAbort();

                // Wait for the transport
                await transport.DefaultTimeout();

                await client.DefaultTimeout();

                // Observe the helper task: an assertion failure inside it would otherwise be lost
                await application.DefaultTimeout();
            }
        }
    }
 
    // Completes the application output with an exception and checks that the client observes a
    // close result with status InternalServerError.
    [Fact]
    public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails()
    {
        using var logScope = StartVerifiableLog();

        var pipes = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
        var context = CreateHttpConnectionContext(pipes);

        using var socketFeature = new TestWebSocketConnectionFeature();

        var serverTransport = new WebSocketsServerTransport(new WebSocketOptions(), context.Application, context, LoggerFactory);

        // Hand the accepted server socket to the transport and start processing it
        var acceptedSocket = await socketFeature.AcceptAsync();
        var processing = serverTransport.ProcessSocketAsync(acceptedSocket);

        // Start the client loop that records what the server sends
        var clientLoop = socketFeature.Client.ExecuteAndCaptureFramesAsync();

        // Simulate an application failure
        var failure = new InvalidOperationException("Catastrophic failure.");
        context.Transport.Output.Complete(failure);

        var summary = await clientLoop.DefaultTimeout();
        Assert.Equal(WebSocketCloseStatus.InternalServerError, summary.CloseResult.CloseStatus);

        // The client closes its side
        await socketFeature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);

        await processing.DefaultTimeout();
    }
 
    // Uses a one second CloseTimeout and never runs a client loop; after the application ends,
    // the transport is expected to finish and the server socket to be in the Aborted state.
    [Fact]
    public async Task TransportClosesOnCloseTimeoutIfClientDoesNotSendCloseFrame()
    {
        using var logScope = StartVerifiableLog();

        var pipes = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
        var context = CreateHttpConnectionContext(pipes);

        using var socketFeature = new TestWebSocketConnectionFeature();

        var transportOptions = new WebSocketOptions();
        transportOptions.CloseTimeout = TimeSpan.FromSeconds(1);

        var serverTransport = new WebSocketsServerTransport(transportOptions, context.Application, context, LoggerFactory);

        // Hand the accepted server socket to the transport and start processing it
        var acceptedSocket = await socketFeature.AcceptAsync();
        var processing = serverTransport.ProcessSocketAsync(acceptedSocket);

        // End the application side
        context.Transport.Output.Complete();

        var upperBound = TimeSpan.FromSeconds(10);
        await processing.DefaultTimeout(upperBound);

        // The socket must have been aborted rather than closed by a handshake
        Assert.Equal(WebSocketState.Aborted, acceptedSocket.State);

        acceptedSocket.Dispose();
    }
 
    // Completes the application output with an exception while the client never sends a close
    // frame; with a one second CloseTimeout the server socket is expected to end up Aborted.
    [Fact]
    public async Task TransportFailsOnTimeoutWithErrorWhenApplicationFailsAndClientDoesNotSendCloseFrame()
    {
        using var logScope = StartVerifiableLog();

        var pipes = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
        var context = CreateHttpConnectionContext(pipes);

        using var socketFeature = new TestWebSocketConnectionFeature();

        var transportOptions = new WebSocketOptions();
        transportOptions.CloseTimeout = TimeSpan.FromSeconds(1);

        var serverTransport = new WebSocketsServerTransport(transportOptions, context.Application, context, LoggerFactory);

        // Hand the accepted server socket to the transport and start processing it
        var acceptedSocket = await socketFeature.AcceptAsync();
        var processing = serverTransport.ProcessSocketAsync(acceptedSocket);

        // Start the client loop; its result is not inspected by this test
        var clientLoop = socketFeature.Client.ExecuteAndCaptureFramesAsync();

        // Fail the application output
        var failure = new Exception();
        context.Transport.Output.Complete(failure);

        await processing.DefaultTimeout();

        Assert.Equal(WebSocketState.Aborted, acceptedSocket.State);
    }
 
    // Ordering under test: the application ends first, then the client sends its close frame.
    // The server socket is expected to report NormalClosure as its close status.
    [Fact]
    public async Task ServerGracefullyClosesWhenApplicationEndsThenClientSendsCloseFrame()
    {
        using var logScope = StartVerifiableLog();

        var pipes = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
        var context = CreateHttpConnectionContext(pipes);

        using var socketFeature = new TestWebSocketConnectionFeature();

        // A long close timeout keeps the timeout path from influencing the outcome
        var transportOptions = new WebSocketOptions();
        transportOptions.CloseTimeout = TimeSpan.FromSeconds(20);

        var serverTransport = new WebSocketsServerTransport(transportOptions, context.Application, context, LoggerFactory);

        // Hand the accepted server socket to the transport and start processing it
        var acceptedSocket = await socketFeature.AcceptAsync();
        var processing = serverTransport.ProcessSocketAsync(acceptedSocket);

        // Start the client loop
        var clientLoop = socketFeature.Client.ExecuteAndCaptureFramesAsync();

        // Step 1: the application ends
        context.Transport.Output.Complete();

        _ = await clientLoop.DefaultTimeout();

        // Step 2: the client sends its close frame
        var clientClose = socketFeature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None);
        await clientClose.DefaultTimeout();

        await processing.DefaultTimeout();

        Assert.Equal(WebSocketCloseStatus.NormalClosure, acceptedSocket.CloseStatus);
    }
 
    // Ordering under test: the client sends its close frame first, then the application ends.
    // The server socket is expected to report NormalClosure as its close status.
    [Fact]
    public async Task ServerGracefullyClosesWhenClientSendsCloseFrameThenApplicationEnds()
    {
        using var logScope = StartVerifiableLog();

        var pipes = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
        var context = CreateHttpConnectionContext(pipes);

        using var socketFeature = new TestWebSocketConnectionFeature();

        // A long close timeout keeps the timeout path from influencing the outcome
        var transportOptions = new WebSocketOptions();
        transportOptions.CloseTimeout = TimeSpan.FromSeconds(20);

        var serverTransport = new WebSocketsServerTransport(transportOptions, context.Application, context, LoggerFactory);

        // Hand the accepted server socket to the transport and start processing it
        var acceptedSocket = await socketFeature.AcceptAsync();
        var processing = serverTransport.ProcessSocketAsync(acceptedSocket);

        // Start the client loop
        var clientLoop = socketFeature.Client.ExecuteAndCaptureFramesAsync();

        // Step 1: the client sends its close frame
        var clientClose = socketFeature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None);
        await clientClose.DefaultTimeout();

        // Step 2: the application ends
        context.Transport.Output.Complete();

        _ = await clientLoop.DefaultTimeout();

        await processing.DefaultTimeout();

        Assert.Equal(WebSocketCloseStatus.NormalClosure, acceptedSocket.CloseStatus);
    }
 
    // Runs the transport through ProcessRequestAsync with a request that offers two sub-protocols and
    // checks that the configured SubProtocolSelector sees them and that its return value is the
    // sub-protocol recorded on the WebSocket feature.
    [Fact]
    public async Task SubProtocolSelectorIsUsedToSelectSubProtocol()
    {
        const string ExpectedSubProtocol = "expected";
        var providedSubProtocols = new[] { "provided1", "provided2" };

        using (StartVerifiableLog())
        {
            var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
            var connection = CreateHttpConnectionContext(pair);

            using (var feature = new TestWebSocketConnectionFeature())
            {
                var options = new WebSocketOptions
                {
                    // We want to verify behavior without timeout affecting it
                    CloseTimeout = TimeSpan.FromSeconds(20),
                    SubProtocolSelector = protocols =>
                    {
                        Assert.Equal(providedSubProtocols, protocols.ToArray());
                        return ExpectedSubProtocol;
                    },
                };

                var ws = new WebSocketsServerTransport(options, connection.Application, connection, LoggerFactory);

                // Create an HttpContext
                var context = new DefaultHttpContext();

                // Set the header through the indexer rather than IDictionary.Add, which throws when the
                // key is already present. The array converts implicitly, so no copy is needed.
                context.Request.Headers[HeaderNames.WebSocketSubProtocols] = providedSubProtocols;
                context.Features.Set<IHttpWebSocketFeature>(feature);
                var transport = ws.ProcessRequestAsync(context, CancellationToken.None);

                await feature.Accepted.OrThrowIfOtherFails(transport);

                // Assert the feature got the right subprotocol
                Assert.Equal(ExpectedSubProtocol, feature.SubProtocol);

                // Run the client socket
                var client = feature.Client.ExecuteAndCaptureFramesAsync();

                await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).DefaultTimeout();

                // close the client to server channel
                connection.Transport.Output.Complete();

                _ = await client.DefaultTimeout();

                await transport.DefaultTimeout();
            }
        }
    }
 
    // Sends a two-segment sequence through the server socket and checks that the client receives
    // exactly two frames, with only the second one flagged as end of message.
    [Fact]
    public async Task MultiSegmentSendWillNotSendEmptyEndOfMessageFrame()
    {
        using var socketFeature = new TestWebSocketConnectionFeature();

        var acceptedSocket = await socketFeature.AcceptAsync();

        // Two one-byte segments, so the payload is guaranteed to be multi-segment
        var payload = ReadOnlySequenceFactory.CreateSegments(new byte[] { 1 }, new byte[] { 15 });
        Assert.False(payload.IsSingleSegment);

        await acceptedSocket.SendAsync(payload, WebSocketMessageType.Text);

        // Start the client loop that records what the server sends
        var clientLoop = socketFeature.Client.ExecuteAndCaptureFramesAsync();

        await acceptedSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", default);

        var summary = await clientLoop.DefaultTimeout();
        Assert.Equal(2, summary.Received.Count);

        // Leading frame: a single byte with value 1, not the end of the message
        var leading = summary.Received[0];
        Assert.Single(leading.Buffer);
        Assert.Equal(1, leading.Buffer[0]);
        Assert.False(leading.EndOfMessage);

        // Trailing frame: a single byte with value 15, marks the end of the message
        var trailing = summary.Received[1];
        Assert.Single(trailing.Buffer);
        Assert.Equal(15, trailing.Buffer[0]);
        Assert.True(trailing.EndOfMessage);
    }
}