File: Internal\WebTransport\WebTransportSession.cs
Web Access
Project: src\src\Servers\Kestrel\Core\src\Microsoft.AspNetCore.Server.Kestrel.Core.csproj (Microsoft.AspNetCore.Server.Kestrel.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Net.Http;
using System.Threading.Channels;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3;
using Microsoft.AspNetCore.Server.Kestrel.Core.WebTransport;
 
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.WebTransport;
 
#pragma warning disable CA2252 // WebTransport is a preview feature
/// <summary>
/// Server-side state for one WebTransport session carried over an HTTP/3 connection.
/// The session is identified by the stream id of its CONNECT stream, tracks every stream
/// that belongs to it, and queues incoming streams until the application accepts them.
/// </summary>
internal sealed class WebTransportSession : IWebTransportSession
{
    private static readonly IStreamDirectionFeature _outputStreamDirectionFeature = new DefaultStreamDirectionFeature(canRead: false, canWrite: true);

    private readonly CancellationTokenRegistration _connectionClosedRegistration;

    // stores all created streams (pending or accepted)
    private readonly ConcurrentDictionary<long, WebTransportStream> _openStreams = new();
    // stores all pending streams that have not been accepted yet
    private readonly Channel<WebTransportStream> _pendingStreams;

    private readonly Http3Connection _connection;
    private readonly Http3Stream _connectStream;

    // 0 while the session is open, 1 once teardown has begun. Only changed through Interlocked so that
    // exactly one of Abort / OnClientConnectionClosed performs teardown, even when the connection-closed
    // callback races an explicit abort (a second teardown would otherwise re-complete the channel).
    private int _closingState;

    private static readonly ReadOnlyMemory<byte> OutputStreamHeader = new(new byte[] {
            0x40 /*quic variable-length integer length*/,
            (byte)Http3StreamType.WebTransportUnidirectional,
            0x00 /*body*/});

    internal const string WebTransportProtocolValue = "webtransport";
    internal const string VersionEnabledIndicator = "1";
    internal const string SecPrefix = "sec-webtransport-http3-";
    internal const string VersionHeaderPrefix = $"{SecPrefix}draft";
    internal const string CurrentSupportedVersionSuffix = "draft02";
    internal const string CurrentSupportedVersion = $"{SecPrefix}{CurrentSupportedVersionSuffix}";

    /// <summary>The session id, which is the stream id of the CONNECT stream that created the session.</summary>
    public long SessionId => _connectStream.StreamId;

    // True once teardown (abort or connection close) has started.
    private bool IsClosing => Volatile.Read(ref _closingState) != 0;

    // Atomically claims the right to tear the session down; returns false if teardown already started.
    private bool TryBeginClosing() => Interlocked.Exchange(ref _closingState, 1) == 0;

    /// <summary>
    /// Creates a session bound to <paramref name="connectStream"/> on <paramref name="connection"/> and
    /// registers a callback that tears the session down when the connection closes.
    /// </summary>
    internal WebTransportSession(Http3Connection connection, Http3Stream connectStream)
    {
        _connection = connection;
        _connectStream = connectStream;
        // unbounded as limits to number of streams is enforced elsewhere
        // (must be created before Register below: an already-canceled token runs the callback synchronously)
        _pendingStreams = Channel.CreateUnbounded<WebTransportStream>();

        // listener to abort if this connection is closed
        _connectionClosedRegistration = connection._multiplexedContext.ConnectionClosed.Register(static state =>
        {
            var session = (WebTransportSession)state!;
            session.OnClientConnectionClosed();
        }, this);
    }

    void IWebTransportSession.Abort(int errorCode)
    {
        Abort(new(), (Http3ErrorCode)errorCode);
    }

    /// <summary>
    /// Tears the session down after the underlying connection closed: synchronously disposes every
    /// tracked stream and completes the pending-stream queue. Runs at most once per session.
    /// </summary>
    internal void OnClientConnectionClosed()
    {
        if (!TryBeginClosing())
        {
            return;
        }

        _connectionClosedRegistration.Dispose();

        lock (_openStreams)
        {
            foreach (var stream in _openStreams)
            {
                stream.Value.DisposeAsync().AsTask().GetAwaiter().GetResult();
            }

            _openStreams.Clear();
        }

        _pendingStreams.Writer.TryComplete();
    }

    /// <summary>
    /// Aborts the CONNECT stream with <paramref name="error"/>, aborts every tracked stream with a
    /// <see cref="ConnectionAbortedException"/> carrying the same message (and inner exception, if any),
    /// and completes the pending-stream queue. Runs at most once per session.
    /// </summary>
    internal void Abort(ConnectionAbortedException exception, Http3ErrorCode error)
    {
        if (!TryBeginClosing())
        {
            return;
        }

        _connectionClosedRegistration.Dispose();

        lock (_openStreams)
        {
            _connectStream.Abort(exception, error);
            foreach (var stream in _openStreams)
            {
                // Each stream gets its own exception instance.
                stream.Value.Abort(CreateStreamAbortException(exception));
            }
            _openStreams.Clear();
        }

        _pendingStreams.Writer.TryComplete();

        // Copies message/inner exception; the inner exception is only passed when present (nullable-safe).
        static ConnectionAbortedException CreateStreamAbortException(ConnectionAbortedException source)
            => source.InnerException is { } inner
                ? new ConnectionAbortedException(source.Message, inner)
                : new ConnectionAbortedException(source.Message);
    }

    /// <summary>
    /// Opens a new server-initiated unidirectional (output) stream for this session and writes the
    /// WebTransport unidirectional stream header to it.
    /// </summary>
    /// <returns>The new stream, or <c>null</c> if the session is closing.</returns>
    /// <remarks>
    /// If writing the header throws (including cancellation), the stream is untracked and aborted
    /// before the exception propagates, because the caller never receives it.
    /// </remarks>
    public async ValueTask<ConnectionContext?> OpenUnidirectionalStreamAsync(CancellationToken cancellationToken)
    {
        if (IsClosing)
        {
            return null;
        }

        // create the stream
        var features = new FeatureCollection();
        features.Set(_outputStreamDirectionFeature);
        var connectionContext = await _connection._multiplexedContext.ConnectAsync(features, cancellationToken);
        var streamContext = _connection.CreateHttpStreamContext(connectionContext);
        var stream = new WebTransportStream(streamContext, WebTransportStreamType.Output);

        // The session may have started closing while ConnectAsync was pending. Teardown sets the flag and
        // then sweeps _openStreams under this lock, so checking here means the stream is either swept by
        // teardown or rejected now - never left tracked after the sweep.
        bool sessionClosing;
        lock (_openStreams)
        {
            sessionClosing = IsClosing;
            if (!sessionClosing)
            {
                var success = _openStreams.TryAdd(stream.StreamId, stream);
                Debug.Assert(success);
            }
        }

        if (sessionClosing)
        {
            stream.Abort();
            return null;
        }

        try
        {
            // send the stream header
            // https://ietf-wg-webtrans.github.io/draft-ietf-webtrans-http3/draft-ietf-webtrans-http3.html#name-unidirectional-streams
            await stream.Transport.Output.WriteAsync(OutputStreamHeader, cancellationToken);
        }
        catch
        {
            // Untrack and abort the stream unless session teardown already took ownership of it.
            bool removed;
            lock (_openStreams)
            {
                removed = _openStreams.TryRemove(stream.StreamId, out _);
            }

            if (removed)
            {
                stream.Abort();
            }

            throw;
        }

        return stream;
    }

    /// <summary>
    /// Registers an incoming stream with the session and queues it for <see cref="AcceptStreamAsync"/>.
    /// The stream is aborted instead if the session is closing, if a stream with the same id is
    /// already tracked, or if it cannot be written to the pending queue.
    /// </summary>
    internal void AddStream(WebTransportStream stream)
    {
        if (IsClosing)
        {
            stream.Abort();
            return;
        }

        var addedToOpenStreams = _openStreams.TryAdd(stream.StreamId, stream);

        if (!addedToOpenStreams || !_pendingStreams.Writer.TryWrite(stream))
        {
            if (addedToOpenStreams)
            {
                _openStreams.TryRemove(stream.StreamId, out _);
            }

            stream.Abort(new ConnectionAbortedException(CoreStrings.WebTransportFailedToAddStreamToPendingQueue));
        }
    }

    /// <summary>
    /// Waits for the next incoming stream that has not been accepted yet.
    /// </summary>
    /// <returns>
    /// The stream, or <c>null</c> if the session is closing or the pending queue has been completed.
    /// Cancellation of <paramref name="cancellationToken"/> propagates as an exception.
    /// </returns>
    public async ValueTask<ConnectionContext?> AcceptStreamAsync(CancellationToken cancellationToken)
    {
        if (IsClosing)
        {
            return null;
        }

        try
        {
            return await _pendingStreams.Reader.ReadAsync(cancellationToken);
        }
        catch (ChannelClosedException)
        {
            return null;
        }
    }

    /// <summary>
    /// Stops tracking the stream with <paramref name="streamId"/> and synchronously disposes it.
    /// </summary>
    /// <returns><c>true</c> if a tracked stream was removed; otherwise <c>false</c>.</returns>
    internal bool TryRemoveStream(long streamId)
    {
        // ConcurrentDictionary.TryRemove is atomic: if two callers race for the same id, only one wins
        // and disposes the stream (the IDictionary Remove(key, out value) extension is not atomic).
        if (!_openStreams.TryRemove(streamId, out var stream))
        {
            return false;
        }

        stream.DisposeAsync().AsTask().GetAwaiter().GetResult();
        return true;
    }
}
#pragma warning restore CA2252 // WebTransport is a preview feature