File: QuicTestHelpers.cs
Web Access
Project: src\src\Servers\Kestrel\Transport.Quic\test\Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Tests.csproj (Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Quic;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Https;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Internal;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Xunit;
 
namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Tests;
 
/// <summary>
/// Helpers shared by the Kestrel QUIC transport tests: creating transport factories and bound
/// listeners, building client connection options, and running a bidirectional stream through a
/// complete client-to-server exchange.
/// </summary>
internal static class QuicTestHelpers
{
    // Payload written by the client in CreateAndCompleteBidirectionalStreamGracefully.
    private static readonly byte[] TestData = Encoding.UTF8.GetBytes("Hello world");

    /// <summary>
    /// Creates a <see cref="QuicTransportFactory"/> configured for tests, allowing 200
    /// bidirectional and 200 unidirectional streams.
    /// </summary>
    /// <param name="loggerFactory">Logger factory for the transport; <see cref="NullLoggerFactory.Instance"/> is used when <c>null</c>.</param>
    /// <param name="timeProvider">Time provider for the transport options; when <c>null</c> the options keep their own default.</param>
    /// <param name="defaultCloseErrorCode">Value assigned to <c>QuicTransportOptions.DefaultCloseErrorCode</c>.</param>
    /// <returns>A new transport factory wrapping the configured options.</returns>
    public static QuicTransportFactory CreateTransportFactory(
        ILoggerFactory loggerFactory = null,
        TimeProvider timeProvider = null,
        long defaultCloseErrorCode = 0)
    {
        var quicTransportOptions = new QuicTransportOptions
        {
            MaxBidirectionalStreamCount = 200,
            MaxUnidirectionalStreamCount = 200,
            DefaultCloseErrorCode = defaultCloseErrorCode,
        };

        // Only override the options' time provider when the caller supplies one.
        if (timeProvider != null)
        {
            quicTransportOptions.TimeProvider = timeProvider;
        }

        return new QuicTransportFactory(loggerFactory ?? NullLoggerFactory.Instance, Options.Create(quicTransportOptions));
    }

    /// <summary>
    /// Binds a <see cref="QuicConnectionListener"/> on the IPv4 loopback address using the
    /// test-certificate features produced by <see cref="CreateBindAsyncFeatures"/>.
    /// </summary>
    /// <param name="loggerFactory">Forwarded to <see cref="CreateTransportFactory"/>.</param>
    /// <param name="timeProvider">Forwarded to <see cref="CreateTransportFactory"/>.</param>
    /// <param name="clientCertificateRequired">Value for <c>SslServerAuthenticationOptions.ClientCertificateRequired</c> on the server.</param>
    /// <param name="defaultCloseErrorCode">Forwarded to <see cref="CreateTransportFactory"/>.</param>
    /// <param name="port">Port to bind; 0 (the default) presumably lets the OS pick a free port.</param>
    /// <returns>The listener returned by <c>BindAsync</c>, cast to <see cref="QuicConnectionListener"/>.</returns>
    public static async Task<QuicConnectionListener> CreateConnectionListenerFactory(
        ILoggerFactory loggerFactory = null,
        TimeProvider timeProvider = null,
        bool clientCertificateRequired = false,
        long defaultCloseErrorCode = 0,
        int port = 0)
    {
        var transportFactory = CreateTransportFactory(
            loggerFactory,
            timeProvider,
            defaultCloseErrorCode: defaultCloseErrorCode);

        // Same loopback address as the TlsConnectionCallbackOptions overload below.
        var endpoint = new IPEndPoint(IPAddress.Loopback, port);

        var features = CreateBindAsyncFeatures(clientCertificateRequired);
        return (QuicConnectionListener)await transportFactory.BindAsync(endpoint, features, cancellationToken: CancellationToken.None);
    }

    /// <summary>
    /// Binds a <see cref="QuicConnectionListener"/> on the IPv4 loopback address using
    /// caller-supplied TLS callback options instead of the default test certificate setup.
    /// </summary>
    /// <param name="tlsConnectionOptions">The only feature placed in the bind feature collection.</param>
    /// <param name="loggerFactory">Forwarded to <see cref="CreateTransportFactory"/>.</param>
    /// <param name="timeProvider">Forwarded to <see cref="CreateTransportFactory"/>.</param>
    /// <param name="port">Port to bind; 0 (the default) presumably lets the OS pick a free port.</param>
    /// <returns>The listener returned by <c>BindAsync</c>, cast to <see cref="QuicConnectionListener"/>.</returns>
    public static async Task<QuicConnectionListener> CreateConnectionListenerFactory(
        TlsConnectionCallbackOptions tlsConnectionOptions,
        ILoggerFactory loggerFactory = null,
        TimeProvider timeProvider = null,
        int port = 0)
    {
        var transportFactory = CreateTransportFactory(loggerFactory, timeProvider);

        var endpoint = new IPEndPoint(IPAddress.Loopback, port);

        var features = new FeatureCollection();
        features.Set(tlsConnectionOptions);
        return (QuicConnectionListener)await transportFactory.BindAsync(endpoint, features, cancellationToken: CancellationToken.None);
    }

    /// <summary>
    /// Builds the feature collection passed to <c>BindAsync</c>. The collection carries a
    /// <see cref="TlsConnectionCallbackOptions"/> that advertises HTTP/3 and hands every incoming
    /// connection the same server TLS options, which use the test certificate and accept any
    /// client certificate.
    /// </summary>
    /// <param name="clientCertificateRequired">Value for <c>SslServerAuthenticationOptions.ClientCertificateRequired</c>.</param>
    /// <returns>A new feature collection containing the TLS callback options.</returns>
    public static FeatureCollection CreateBindAsyncFeatures(bool clientCertificateRequired = false)
    {
        var cert = TestResources.GetTestCertificate();

        var sslServerAuthenticationOptions = new SslServerAuthenticationOptions();
        sslServerAuthenticationOptions.ApplicationProtocols = new List<SslApplicationProtocol>() { SslApplicationProtocol.Http3 };
        sslServerAuthenticationOptions.ServerCertificate = cert;
        sslServerAuthenticationOptions.RemoteCertificateValidationCallback = RemoteCertificateValidationCallback;
        sslServerAuthenticationOptions.ClientCertificateRequired = clientCertificateRequired;

        var features = new FeatureCollection();
        features.Set(new TlsConnectionCallbackOptions
        {
            ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols,
            // Every connection receives the same options instance.
            OnConnection = (context, cancellationToken) => ValueTask.FromResult(sslServerAuthenticationOptions)
        });

        return features;
    }

    /// <summary>
    /// Accepts the next connection from <paramref name="listener"/> and attaches a no-op
    /// <see cref="IConnectionHeartbeatFeature"/> to it.
    /// </summary>
    /// <param name="listener">Listener to accept from.</param>
    /// <returns>The accepted connection, or <c>null</c> when <c>AcceptAsync</c> returns <c>null</c>.</returns>
    public static async ValueTask<MultiplexedConnectionContext> AcceptAndAddFeatureAsync(this IMultiplexedConnectionListener listener)
    {
        var connection = await listener.AcceptAsync();
        connection?.Features.Set<IConnectionHeartbeatFeature>(new TestConnectionHeartbeatFeature());
        return connection;
    }

    /// <summary>
    /// Heartbeat feature that ignores registrations: <see cref="OnHeartbeat"/> neither stores
    /// nor invokes the supplied callback.
    /// </summary>
    private class TestConnectionHeartbeatFeature : IConnectionHeartbeatFeature
    {
        public void OnHeartbeat(Action<object> action, object state)
        {
        }
    }

    // Certificate validation callback shared by client and server options: accepts any
    // certificate regardless of the reported policy errors.
    private static bool RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
    {
        return true;
    }

    /// <summary>
    /// Creates client connection options targeting <paramref name="remoteEndPoint"/>. The options
    /// allow 200 inbound streams of each type, advertise HTTP/3, and use error code 0 for
    /// streams and connection close.
    /// </summary>
    /// <param name="remoteEndPoint">Server endpoint to connect to.</param>
    /// <param name="ignoreInvalidCertificate">
    /// Unless this is explicitly <c>false</c>, the server certificate is accepted without validation.
    /// </param>
    /// <returns>The configured client connection options.</returns>
    public static QuicClientConnectionOptions CreateClientConnectionOptions(EndPoint remoteEndPoint, bool? ignoreInvalidCertificate = null)
    {
        var options = new QuicClientConnectionOptions
        {
            MaxInboundBidirectionalStreams = 200,
            MaxInboundUnidirectionalStreams = 200,
            RemoteEndPoint = remoteEndPoint,
            ClientAuthenticationOptions = new SslClientAuthenticationOptions
            {
                ApplicationProtocols = new List<SslApplicationProtocol>
                {
                    SslApplicationProtocol.Http3
                }
            },
            DefaultStreamErrorCode = 0,
            DefaultCloseErrorCode = 0,
        };

        // null is treated the same as true.
        if (ignoreInvalidCertificate ?? true)
        {
            options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = RemoteCertificateValidationCallback;
        }

        return options;
    }

    /// <summary>
    /// Opens a bidirectional stream from the client and runs it through a graceful exchange:
    /// <list type="number">
    /// <item>The client writes <c>TestData</c> and completes its writes.</item>
    /// <item>The server accepts the stream, reads the data, and observes input completion.</item>
    /// <item>The server completes both pipes, waits for the stream's processing task, and disposes the stream context.</item>
    /// </list>
    /// Every await is bounded by <c>DefaultTimeout</c>, so a missing FIN or stalled open fails
    /// the test instead of hanging it.
    /// </summary>
    /// <param name="clientConnection">Client connection used to open the stream.</param>
    /// <param name="serverConnection">Server connection that accepts the stream.</param>
    /// <param name="logger">Logger for progress messages.</param>
    /// <returns>The server-side <see cref="QuicStreamContext"/>, already disposed.</returns>
    public static async Task<QuicStreamContext> CreateAndCompleteBidirectionalStreamGracefully(QuicConnection clientConnection, MultiplexedConnectionContext serverConnection, ILogger logger)
    {
        logger.LogInformation("Client starting stream.");
        var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional).DefaultTimeout();

        logger.LogInformation("Client sending data.");
        await clientStream.WriteAsync(TestData, completeWrites: true).DefaultTimeout();

        logger.LogInformation("Server accepting stream.");
        var serverStream = await serverConnection.AcceptAsync().DefaultTimeout();

        logger.LogInformation("Server reading data.");
        var readResult = await serverStream.Transport.Input.ReadAtLeastAsync(TestData.Length).DefaultTimeout();
        serverStream.Transport.Input.AdvanceTo(readResult.Buffer.End);

        // Input should be completed because the client sent with completeWrites: true.
        // Bounded so that a lost FIN fails the test rather than hanging it.
        readResult = await serverStream.Transport.Input.ReadAsync().DefaultTimeout();
        Assert.True(readResult.IsCompleted);

        // Complete reading and writing.
        logger.LogInformation("Server completing input and output.");
        await serverStream.Transport.Input.CompleteAsync();
        await serverStream.Transport.Output.CompleteAsync();

        var quicStreamContext = Assert.IsType<QuicStreamContext>(serverStream);

        // Both send and receive loops have exited.
        logger.LogInformation("Server verifying stream is finished.");
        await quicStreamContext._processingTask.DefaultTimeout();
        Assert.True(quicStreamContext.CanWrite);
        Assert.True(quicStreamContext.CanRead);

        logger.LogInformation("Server disposing stream.");
        // NOTE(review): both DisposeAsync and Dispose are called; presumably they release
        // different resources in QuicStreamContext — confirm before collapsing into one call.
        await quicStreamContext.DisposeAsync();
        quicStreamContext.Dispose();

        return quicStreamContext;
    }
}