// File: Internal\HttpConnectionDispatcher.cs
// Web Access
// Project: src\src\SignalR\common\Http.Connections\src\Microsoft.AspNetCore.Http.Connections.csproj (Microsoft.AspNetCore.Http.Connections)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Globalization;
using System.Security.Claims;
using System.Security.Principal;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal.Transports;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Timeouts;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
 
namespace Microsoft.AspNetCore.Http.Connections.Internal;
 
internal sealed partial class HttpConnectionDispatcher
{
    // Transport descriptors advertised by the negotiate response. The same instances are reused for every
    // negotiate response, so they are not expected to be modified.
    private static readonly AvailableTransport _webSocketAvailableTransport = new()
    {
        Transport = nameof(HttpTransportType.WebSockets),
        TransferFormats = new List<string> { nameof(TransferFormat.Text), nameof(TransferFormat.Binary) }
    };

    // Server-Sent Events is a text-only protocol, so only the Text format is advertised.
    private static readonly AvailableTransport _serverSentEventsAvailableTransport = new()
    {
        Transport = nameof(HttpTransportType.ServerSentEvents),
        TransferFormats = new List<string> { nameof(TransferFormat.Text) }
    };

    private static readonly AvailableTransport _longPollingAvailableTransport = new()
    {
        Transport = nameof(HttpTransportType.LongPolling),
        TransferFormats = new List<string> { nameof(TransferFormat.Text), nameof(TransferFormat.Binary) }
    };

    // Dependencies supplied through the constructor.
    private readonly HttpConnectionManager _manager;
    private readonly ILoggerFactory _loggerFactory;
    private readonly HttpConnectionsMetrics _metrics;
    private readonly ILogger _logger;

    // Highest negotiate protocol version this server speaks; clients asking for a newer one are capped to it.
    private const int _protocolVersion = 1;

    // Values for the no-cache response headers (presumably written by AddNoCacheHeaders, defined elsewhere).
    // This should be kept in sync with CookieAuthenticationHandler
    private const string HeaderValueNoCache = "no-cache";
    private const string HeaderValueNoCacheNoStore = "no-cache, no-store";
    private const string HeaderValueEpochDate = "Thu, 01 Jan 1970 00:00:00 GMT";
 
    /// <summary>
    /// Creates a dispatcher that serves the connections tracked by <paramref name="manager"/>.
    /// </summary>
    /// <param name="manager">Tracks, looks up and disposes connections.</param>
    /// <param name="loggerFactory">Used for the dispatcher's own logger and handed to the transports it creates.</param>
    /// <param name="metrics">Passed to connections when their transport is set.</param>
    public HttpConnectionDispatcher(HttpConnectionManager manager, ILoggerFactory loggerFactory, HttpConnectionsMetrics metrics)
    {
        _logger = loggerFactory.CreateLogger<HttpConnectionDispatcher>();
        _loggerFactory = loggerFactory;
        _metrics = metrics;
        _manager = manager;
    }
 
    /// <summary>
    /// Entry point for all non-negotiate requests on /{path}. Dispatches on the HTTP method:
    /// POST sends data, GET/CONNECT runs a transport, DELETE terminates a long-polling connection,
    /// and any other method gets 405 Method Not Allowed.
    /// </summary>
    /// <param name="context">The current HTTP request.</param>
    /// <param name="options">Options for this endpoint.</param>
    /// <param name="connectionDelegate">The application delegate for connections served by this endpoint.</param>
    public async Task ExecuteAsync(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionDelegate connectionDelegate)
    {
        // Look the connection up before opening the log scope so that as many log entries as possible carry
        // the Connection ID. Negotiate requests do not come through here and set the ID later themselves.
        var token = GetConnectionToken(context);
        HttpConnectionContext? existingConnection = null;
        if (!StringValues.IsNullOrEmpty(token))
        {
            // ToString is required: IsNullOrEmpty tells the compiler nothing about the implicit string conversion.
            _manager.TryGetConnection(token.ToString(), out existingConnection);
        }

        var logScope = new ConnectionLogScope(existingConnection?.ConnectionId);
        using var scope = _logger.BeginScope(logScope);

        var method = context.Request.Method;
        if (HttpMethods.IsPost(method))
        {
            // POST /{path}
            await ProcessSend(context);
        }
        else if (HttpMethods.IsGet(method) || HttpMethods.IsConnect(method))
        {
            // GET (or CONNECT) /{path}
            await ExecuteAsync(context, connectionDelegate, options, logScope);
        }
        else if (HttpMethods.IsDelete(method))
        {
            // DELETE /{path}
            await ProcessDeleteAsync(context);
        }
        else
        {
            context.Response.ContentType = "text/plain";
            context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
        }
    }
 
    /// <summary>
    /// Entry point for /{path}/negotiate. Only POST is accepted; any other method gets 405 Method Not Allowed.
    /// </summary>
    /// <param name="context">The current HTTP request.</param>
    /// <param name="options">Options for this endpoint.</param>
    public async Task ExecuteNegotiateAsync(HttpContext context, HttpConnectionDispatcherOptions options)
    {
        // No connection exists yet, so the scope starts with an empty Connection ID; ProcessNegotiate fills it in
        // once the connection has been created.
        var logScope = new ConnectionLogScope(connectionId: string.Empty);
        using (_logger.BeginScope(logScope))
        {
            if (!HttpMethods.IsPost(context.Request.Method))
            {
                context.Response.ContentType = "text/plain";
                context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
                return;
            }

            // POST /{path}/negotiate
            await ProcessNegotiate(context, options, logScope);
        }
    }
 
    /// <summary>
    /// Handles a GET (or CONNECT) request on /{path} by running a transport for a connection.
    /// Requests whose Accept header contains "text/event-stream" use Server-Sent Events, WebSocket requests use
    /// WebSockets, and every other request is treated as a long poll.
    /// </summary>
    /// <param name="context">The current HTTP request.</param>
    /// <param name="connectionDelegate">The application delegate passed to the connection when it is activated.</param>
    /// <param name="options">Endpoint options: allowed transports plus WebSocket and long-polling settings.</param>
    /// <param name="logScope">The request's log scope; EnsureConnectionStateAsync sets its Connection ID.</param>
    private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connectionDelegate, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope)
    {
        // set a tag to allow Application Performance Management tools to differentiate long running requests for reporting purposes
        context.Features.Get<IHttpActivityFeature>()?.Activity.AddTag("http.long_running", "true");

        var supportedTransports = options.Transports;

        // Server sent events transport
        // GET /{path}
        // Accept: text/event-stream
        var headers = context.Request.GetTypedHeaders();
        if (headers.Accept?.Contains(new Net.Http.Headers.MediaTypeHeaderValue("text/event-stream")) == true)
        {
            // Connection must already exist
            var connection = await GetConnectionAsync(context);
            if (connection == null)
            {
                // No such connection, GetConnection already set the response status code
                return;
            }

            if (!await EnsureConnectionStateAsync(connection, context, HttpTransportType.ServerSentEvents, supportedTransports, logScope))
            {
                // Bad connection state. It's already set the response status code.
                return;
            }

            Log.EstablishedConnection(_logger);

            // ServerSentEvents is a text protocol only
            connection.SupportedFormats = TransferFormat.Text;

            // We only need to provide the Input channel since writing to the application is handled through /send.
            var sse = new ServerSentEventsServerTransport(connection.Application.Input, connection.ConnectionId, connection, _loggerFactory);

            if (connection.TryActivatePersistentConnection(connectionDelegate, sse, Task.CompletedTask, context, _logger))
            {
                // This request stays open until the application or the transport finishes.
                await DoPersistentConnection(connection);
            }
        }
        else
        {
            // GET /{path} maps to long polling or WebSockets

            HttpConnectionContext? connection;
            var transport = HttpTransportType.LongPolling;
            if (context.WebSockets.IsWebSocketRequest)
            {
                transport = HttpTransportType.WebSockets;
                // Unlike SSE and long polling, a WebSocket request may look up or create the connection.
                connection = await GetOrCreateConnectionAsync(context, options);

                if (connection is not null)
                {
                    Log.EstablishedConnection(_logger);

                    // Allow the reads to be canceled
                    connection.Cancellation ??= new CancellationTokenSource();
                }
            }
            else
            {
                AddNoCacheHeaders(context.Response);
                // Connection must already exist
                connection = await GetConnectionAsync(context);
            }

            if (connection == null)
            {
                // No such connection, GetConnection already set the response status code
                return;
            }

            if (!await EnsureConnectionStateAsync(connection, context, transport, supportedTransports, logScope))
            {
                // Bad connection state. It's already set the response status code.
                return;
            }

            // Long polling (and WebSockets with stateful reconnect) can see a previous request for the same
            // connection still in flight; CancelPreviousPoll is called before this request takes over.
            if (connection.TransportType != HttpTransportType.WebSockets || connection.UseStatefulReconnect)
            {
                if (!await connection.CancelPreviousPoll(context))
                {
                    // Connection closed. It's already set the response status code.
                    return;
                }
            }

            // Create a new Tcs every poll to keep track of the poll finishing, so we can properly wait on previous polls
            var currentRequestTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

            var reconnectTask = Task.CompletedTask;

            switch (transport)
            {
                case HttpTransportType.None:
                    // Not expected: transport is only ever set to LongPolling or WebSockets above.
                    break;
                case HttpTransportType.WebSockets:
                    // An existing ApplicationTask means the application already ran for this connection, i.e. a reconnect.
                    var isReconnect = connection.ApplicationTask is not null;
                    var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory);
                    if (!connection.TryActivatePersistentConnection(connectionDelegate, ws, currentRequestTcs.Task, context, _logger))
                    {
                        // NOTE(review): currentRequestTcs is not completed on this early return (the finally below is
                        // not reached yet); confirm nothing waits on it when activation fails.
                        return;
                    }

                    if (connection.UseStatefulReconnect && isReconnect)
                    {
                        // Should call this after the transport has started, otherwise we'll be writing to a Pipe that isn't being read from
                        reconnectTask = connection.NotifyOnReconnect?.Invoke(connection.Transport.Output) ?? Task.CompletedTask;
                    }
                    break;
                case HttpTransportType.LongPolling:
                    if (!connection.TryActivateLongPollingConnection(
                        connectionDelegate, context, options.LongPolling.PollTimeout,
                        currentRequestTcs.Task, _loggerFactory, _logger))
                    {
                        // NOTE(review): same as above — currentRequestTcs is left incomplete on this return.
                        return;
                    }
                    break;
                default:
                    break;
            }

            // The request is expected to stay open (long poll / WebSocket), so any request timeout is turned off.
            context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();

            try
            {
                await reconnectTask;
            }
            catch (Exception ex)
            {
                // MessageBuffer shouldn't throw from the callback
                // But users can technically add a callback, we don't want to trust them not to throw
                Log.NotifyOnReconnectError(_logger, ex);
            }

            // Whichever of the application or the transport finishes first decides how this request ends.
            var resultTask = await Task.WhenAny(connection.ApplicationTask!, connection.TransportTask!);

            try
            {
                // If the application ended before the transport task then we potentially need to end the connection
                if (resultTask == connection.ApplicationTask)
                {
                    // Complete the transport (notifying it of the application error if there is one)
                    connection.Transport.Output.Complete(connection.ApplicationTask.Exception);

                    // Wait for the transport to run
                    // Ignore exceptions, it has been logged if there is one and the application has finished
                    // So there is no one to give the exception to
                    await ((Task)connection.TransportTask!).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);

                    // If the status code is a 204 it means the connection is done
                    if (context.Response.StatusCode == StatusCodes.Status204NoContent)
                    {
                        // Cancel current request to release any waiting poll and let dispose acquire the lock
                        currentRequestTcs.TrySetCanceled();

                        // We should be able to safely dispose because there's no more data being written
                        // We don't need to wait for close here since we've already waited for both sides
                        await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false, HttpConnectionStopStatus.NormalClosure);
                    }
                    else
                    {
                        if (transport != HttpTransportType.LongPolling)
                        {
                            await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false, HttpConnectionStopStatus.NormalClosure);
                        }
                        else
                        {
                            // Only allow repoll if we aren't removing the connection.
                            connection.MarkInactive();
                        }
                    }
                }
                else if (resultTask.IsFaulted || resultTask.IsCanceled)
                {
                    // Cancel current request to release any waiting poll and let dispose acquire the lock
                    currentRequestTcs.TrySetCanceled();
                    // We should be able to safely dispose because there's no more data being written
                    // We don't need to wait for close here since we've already waited for both sides
                    await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false, HttpConnectionStopStatus.NormalClosure);
                }
                else
                {
                    // If false then the transport was ungracefully closed, this can mean a temporary network disconnection
                    // We'll mark the connection as inactive and allow the connection to reconnect if that's the case.
                    if (await connection.TransportTask!
                        // If acks aren't enabled we can close the connection immediately (not LongPolling)
                        || !connection.ClientReconnectExpected())
                    {
                        await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure);
                    }
                    else
                    {
                        // Only allow repoll if we aren't removing the connection.
                        connection.MarkInactive();
                    }
                }
            }
            finally
            {
                // Artificial task queue
                // This will cause incoming polls to wait until the previous poll has finished updating internal state info
                currentRequestTcs.TrySetResult();
            }
        }
    }
 
    /// <summary>
    /// Waits until either the application or the transport of an activated connection completes, then disposes
    /// and removes the connection with a graceful close.
    /// </summary>
    /// <param name="connection">A connection whose ApplicationTask and TransportTask have been started.</param>
    private async Task DoPersistentConnection(HttpConnectionContext connection)
    {
        var applicationTask = connection.ApplicationTask!;
        var transportTask = connection.TransportTask!;

        // WhenAny never throws, so a faulted side simply ends the wait.
        await Task.WhenAny(applicationTask, transportTask);

        await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure);
    }
 
    /// <summary>
    /// Handles POST /{path}/negotiate. Validates the client's requested negotiate protocol version and creates a
    /// connection when the request is acceptable. Writes the JSON negotiate response with an explicit Content-Length.
    /// </summary>
    /// <remarks>
    /// Version problems (a non-integer value, or a version below <c>options.MinimumProtocolVersion</c>) do not change
    /// the HTTP status code. They are reported in the payload's error field, and no connection is created.
    /// </remarks>
    /// <param name="context">The current HTTP request.</param>
    /// <param name="options">Options for this endpoint (minimum version, stateful reconnect, transports).</param>
    /// <param name="logScope">Log scope whose Connection ID is set to the new connection's ID (or null on error).</param>
    private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope)
    {
        context.Response.ContentType = "application/json";
        string? error = null;
        var clientProtocolVersion = 0;
        if (context.Request.Query.TryGetValue("NegotiateVersion", out var queryStringVersion))
        {
            var queryStringVersionValue = queryStringVersion.ToString();

            // The version is a wire-protocol value, so parse it culture-invariantly rather than with whatever
            // culture the request thread happens to run under.
            if (!int.TryParse(queryStringVersionValue, NumberStyles.Integer, CultureInfo.InvariantCulture, out clientProtocolVersion))
            {
                error = "The client requested a non-integer protocol version.";
                Log.InvalidNegotiateProtocolVersion(_logger, queryStringVersionValue);
            }
            else if (clientProtocolVersion < options.MinimumProtocolVersion)
            {
                error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version.";
                Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion);
            }
            else if (clientProtocolVersion > _protocolVersion)
            {
                // Answer newer clients with the highest protocol version this server speaks.
                clientProtocolVersion = _protocolVersion;
            }
        }
        else if (options.MinimumProtocolVersion > 0)
        {
            // NegotiateVersion wasn't parsed meaning the client requests version 0.
            error = "The client requested version '0', but the server does not support this version.";
            Log.NegotiateProtocolVersionMismatch(_logger, 0);
        }

        // Stateful reconnect is used only when the endpoint allows it and the client explicitly asks for it.
        // A value that is not a valid bool leaves it disabled.
        var useStatefulReconnect = false;
        if (options.AllowStatefulReconnects == true && context.Request.Query.TryGetValue("UseStatefulReconnect", out var useStatefulReconnectValue))
        {
            bool.TryParse(useStatefulReconnectValue.ToString(), out useStatefulReconnect);
        }

        // Establish the connection
        HttpConnectionContext? connection = null;
        if (error == null)
        {
            connection = CreateConnection(options, clientProtocolVersion, useStatefulReconnect);
        }

        // Set the Connection ID on the logging scope so that logs from now on will have the
        // Connection ID metadata set.
        logScope.ConnectionId = connection?.ConnectionId;

        // Don't use thread static instance here because writer is used with async
        var writer = new MemoryBufferWriter();

        try
        {
            // Get the bytes for the connection id
            WriteNegotiatePayload(writer, connection?.ConnectionId, connection?.ConnectionToken, context, options,
                clientProtocolVersion, error, useStatefulReconnect);

            Log.NegotiationRequest(_logger);

            // Write it out to the response with the right content length
            context.Response.ContentLength = writer.Length;
            await writer.CopyToAsync(context.Response.Body);
        }
        finally
        {
            // Release the writer's buffered segments whether or not the write succeeded.
            writer.Reset();
        }
    }
 
    /// <summary>
    /// Serializes the negotiate response into <paramref name="writer"/>.
    /// </summary>
    /// <remarks>
    /// When <paramref name="error"/> is non-empty, only the error is written. Otherwise the response carries:
    /// <list type="bullet">
    /// <item>the negotiated version;</item>
    /// <item>the connection ID and token;</item>
    /// <item>the stateful-reconnect flag;</item>
    /// <item>every transport enabled in <paramref name="options"/>, in the order WebSockets, Server-Sent Events,
    /// long polling. WebSockets is listed only when the server exposes <see cref="IHttpWebSocketFeature"/>.</item>
    /// </list>
    /// </remarks>
    private static void WriteNegotiatePayload(IBufferWriter<byte> writer, string? connectionId, string? connectionToken, HttpContext context, HttpConnectionDispatcherOptions options,
        int clientProtocolVersion, string? error, bool useStatefulReconnect)
    {
        NegotiationResponse response;

        if (string.IsNullOrEmpty(error))
        {
            var transports = new List<AvailableTransport>();
            var enabledTransports = options.Transports;

            if ((enabledTransports & HttpTransportType.WebSockets) != 0 && ServerHasWebSockets(context.Features))
            {
                transports.Add(_webSocketAvailableTransport);
            }

            if ((enabledTransports & HttpTransportType.ServerSentEvents) != 0)
            {
                transports.Add(_serverSentEventsAvailableTransport);
            }

            if ((enabledTransports & HttpTransportType.LongPolling) != 0)
            {
                transports.Add(_longPollingAvailableTransport);
            }

            response = new NegotiationResponse
            {
                Version = clientProtocolVersion,
                ConnectionId = connectionId,
                ConnectionToken = connectionToken,
                AvailableTransports = transports,
                UseStatefulReconnect = useStatefulReconnect,
            };
        }
        else
        {
            response = new NegotiationResponse { Error = error };
        }

        NegotiateProtocol.WriteResponse(response, writer);
    }
 
    // True when the server exposes IHttpWebSocketFeature for this request, i.e. it can accept WebSockets.
    private static bool ServerHasWebSockets(IFeatureCollection features)
        => features.Get<IHttpWebSocketFeature>() is not null;
 
    // Requests for an existing connection carry its token in the "id" query string value (empty when absent).
    private static StringValues GetConnectionToken(HttpContext context)
    {
        var query = context.Request.Query;
        return query["id"];
    }
 
    /// <summary>
    /// Handles POST /{path}: copies the request body into the connection's application stream.
    /// </summary>
    /// <remarks>
    /// <para>
    /// If the connection is not found, GetConnectionAsync has already set the response status code.
    /// </para>
    /// <para>
    /// Other outcomes:
    /// <list type="bullet">
    /// <item>WebSocket connections get 405.</item>
    /// <item>Disposed connections, or connections disposed while the copy is in progress, get 404.</item>
    /// <item>Failure to read the request body gets 400.</item>
    /// </list>
    /// </para>
    /// <para>
    /// The copy runs while holding <c>connection.WriteLock</c>.
    /// </para>
    /// </remarks>
    /// <param name="context">The current HTTP request.</param>
    private async Task ProcessSend(HttpContext context)
    {
        var connection = await GetConnectionAsync(context);
        if (connection == null)
        {
            // No such connection, GetConnection already set the response status code
            return;
        }

        context.Response.ContentType = "text/plain";

        if (connection.TransportType == HttpTransportType.WebSockets)
        {
            Log.PostNotAllowedForWebSockets(_logger);
            context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
            await context.Response.WriteAsync("POST requests are not allowed for WebSocket connections.");
            return;
        }

        // Buffer size used when copying the request body into the application stream.
        const int bufferSize = 4096;

        await connection.WriteLock.WaitAsync();

        try
        {
            if (connection.Status == HttpConnectionStatus.Disposed)
            {
                Log.ConnectionDisposed(_logger, connection.ConnectionId);

                // The connection was disposed
                context.Response.StatusCode = StatusCodes.Status404NotFound;
                context.Response.ContentType = "text/plain";
                return;
            }

            try
            {
                try
                {
                    await context.Request.Body.CopyToAsync(connection.ApplicationStream, bufferSize);
                }
                catch (InvalidOperationException ex)
                {
                    // PipeWriter will throw an error if it is written to while dispose is in progress and the writer has been completed
                    // Dispose isn't taking WriteLock because it could be held because of backpressure, and calling CancelPendingFlush
                    // then taking the lock introduces a race condition that could lead to a deadlock
                    Log.ConnectionDisposedWhileWriteInProgress(_logger, connection.ConnectionId, ex);

                    context.Response.StatusCode = StatusCodes.Status404NotFound;
                    context.Response.ContentType = "text/plain";
                    return;
                }
                catch (OperationCanceledException)
                {
                    // CancelPendingFlush has canceled pending writes caused by backpressure
                    Log.ConnectionDisposed(_logger, connection.ConnectionId);

                    context.Response.StatusCode = StatusCodes.Status404NotFound;
                    context.Response.ContentType = "text/plain";

                    // There are no writes anymore (since this is the write "loop")
                    // So it is safe to complete the writer
                    // We complete the writer here because we already have the WriteLock acquired
                    // and it's unsafe to complete outside of the lock
                    // Other code isn't guaranteed to be able to acquire the lock before another write
                    // even if CancelPendingFlush is called, and the other write could hang if there is backpressure
                    connection.Application.Output.Complete();
                    return;
                }
                catch (IOException ex)
                {
                    // Can occur when the HTTP request is canceled by the client
                    Log.FailedToReadHttpRequestBody(_logger, connection.ConnectionId, ex);

                    context.Response.StatusCode = StatusCodes.Status400BadRequest;
                    context.Response.ContentType = "text/plain";
                    return;
                }

                // ApplicationStream.Length is the number of bytes copied by this request (it is reset below).
                Log.ReceivedBytes(_logger, connection.ApplicationStream.Length);
            }
            finally
            {
                // Clear the amount of read bytes so logging is accurate
                connection.ApplicationStream.Reset();
            }
        }
        finally
        {
            // Runs after ApplicationStream.Reset above, so the reset also happens under the lock.
            connection.WriteLock.Release();
        }
    }
 
    /// <summary>
    /// Handles DELETE /{path}, which lets a long-polling client terminate its connection.
    /// </summary>
    /// <remarks>
    /// For long-polling connections, this starts disposing the connection without waiting for it to finish and
    /// responds 202 Accepted. Connections on any other transport get 400 Bad Request.
    /// </remarks>
    /// <param name="context">The current HTTP request.</param>
    private async Task ProcessDeleteAsync(HttpContext context)
    {
        var connection = await GetConnectionAsync(context);
        if (connection is null)
        {
            // GetConnectionAsync already set the response status code.
            return;
        }

        if (connection.TransportType == HttpTransportType.LongPolling)
        {
            Log.TerminatingConnection(_logger);

            // Disposal is not awaited; the task is stored on the connection so tests can wait for it.
            connection.DisposeAndRemoveTask = _manager.DisposeAndRemoveAsync(connection, closeGracefully: false, HttpConnectionStopStatus.NormalClosure);

            context.Response.StatusCode = StatusCodes.Status202Accepted;
            context.Response.ContentType = "text/plain";
            return;
        }

        Log.ReceivedDeleteRequestForUnsupportedTransport(_logger, connection.TransportType);
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        context.Response.ContentType = "text/plain";
        await context.Response.WriteAsync("Cannot terminate this connection using the DELETE endpoint.");
    }
 
    /// <summary>
    /// Checks that <paramref name="transportType"/> may be used for <paramref name="connection"/> on this request.
    /// If it may, binds the request's HTTP state to the connection: the connection feature, the HttpContext,
    /// the user and the authentication expiration.
    /// </summary>
    /// <param name="connection">The connection the request targets.</param>
    /// <param name="context">The current HTTP request.</param>
    /// <param name="transportType">The transport this request wants to use.</param>
    /// <param name="supportedTransports">The transports enabled on this endpoint.</param>
    /// <param name="logScope">Log scope that receives the connection's ID on success.</param>
    /// <returns>
    /// <see langword="true"/> when the request may proceed. <see langword="false"/> when a rejection has already
    /// been written: 404 for a disabled transport, 409 for an already-active connection, 400 for a transport change.
    /// </returns>
    private async Task<bool> EnsureConnectionStateAsync(HttpConnectionContext connection, HttpContext context, HttpTransportType transportType, HttpTransportType supportedTransports, ConnectionLogScope logScope)
    {
        // The endpoint must be configured to allow this transport at all.
        var isTransportEnabled = (supportedTransports & transportType) != 0;
        if (!isTransportEnabled)
        {
            context.Response.ContentType = "text/plain";
            context.Response.StatusCode = StatusCodes.Status404NotFound;
            Log.TransportNotSupported(_logger, transportType);
            await context.Response.WriteAsync($"{transportType} transport not supported by this end point type");
            return false;
        }

        var setTransportState = connection.TrySetTransport(transportType, _metrics);

        if (setTransportState == HttpConnectionContext.SetTransportState.AlreadyActive)
        {
            Log.ConnectionAlreadyActive(_logger, connection.ConnectionId, context.TraceIdentifier);

            // The connection reports it is already active: reject with 409 Conflict.
            context.Response.StatusCode = StatusCodes.Status409Conflict;
            context.Response.ContentType = "text/plain";
            return false;
        }

        if (setTransportState == HttpConnectionContext.SetTransportState.CannotChange)
        {
            context.Response.ContentType = "text/plain";
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            Log.CannotChangeTransport(_logger, connection.TransportType, transportType);
            await context.Response.WriteAsync("Cannot change transports mid-connection");
            return false;
        }

        // Set the IHttpConnectionFeature now that we can access it.
        connection.Features.Set(context.Features.Get<IHttpConnectionFeature>());

        if (transportType != HttpTransportType.LongPolling)
        {
            // Other transports use the current request's HttpContext directly.
            connection.HttpContext = context;
        }
        else
        {
            connection.HasInherentKeepAlive = true;

            // Long-polling requests come and go while the connection lives on. So the first poll gets a private
            // copy of the HttpContext that later requests cannot affect. Later polls only refresh the trace
            // identifier and (for non-Windows identities) the user.
            var existingContext = connection.HttpContext;
            if (existingContext is null)
            {
                CloneHttpContext(context, connection);
            }
            else
            {
                existingContext.TraceIdentifier = context.TraceIdentifier;

                // A WindowsIdentity is cloned once on the first poll. Swapping it here would require disposing the
                // old identity, which could race with the application still using it.
                if (context.User.Identity is not WindowsIdentity)
                {
                    existingContext.User = context.User;
                }
            }
        }

        var previousUser = connection.User;
        if (previousUser is not null)
        {
            var originalName = previousUser.FindFirst(ClaimTypes.NameIdentifier)?.Value;
            var newName = connection.HttpContext?.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
            if (!string.Equals(originalName, newName, StringComparison.Ordinal))
            {
                // Log warning, different user
                Log.UserNameChanged(_logger, originalName, newName);
            }
        }

        // Setup the connection state from the http context
        connection.User = connection.HttpContext?.User;

        UpdateExpiration(connection, context);

        // From here on, logs in this scope carry the Connection ID.
        logScope.ConnectionId = connection.ConnectionId;

        return true;
    }
 
    /// <summary>
    /// Copies the current request's authentication expiry onto the connection.
    /// </summary>
    /// <remarks>
    /// Does nothing when the request has no <see cref="IAuthenticateResultFeature"/>. When the feature exists but has
    /// no expiry, the connection's expiration is set to <see cref="DateTimeOffset.MaxValue"/>.
    /// </remarks>
    private static void UpdateExpiration(HttpConnectionContext connection, HttpContext context)
    {
        var feature = context.Features.Get<IAuthenticateResultFeature>();
        if (feature is null)
        {
            return;
        }

        var expiresUtc = feature.AuthenticateResult?.Properties?.ExpiresUtc;
        connection.AuthenticationExpiration = expiresUtc ?? DateTimeOffset.MaxValue;
    }
 
    /// <summary>
    /// Copies the user from <paramref name="oldContext"/> to <paramref name="newContext"/>. Identities are deep-cloned
    /// when the primary identity is a <see cref="WindowsIdentity"/>; otherwise the same principal is shared.
    /// </summary>
    /// <remarks>
    /// WindowsIdentity uses SafeHandles that are disposed at the end of the request. Because the identity can be
    /// accessed outside of the request scope, it must be cloned.
    /// </remarks>
    private static void CloneUser(HttpContext newContext, HttpContext oldContext)
    {
        if (oldContext.User.Identity is not WindowsIdentity windowsIdentity)
        {
            newContext.User = oldContext.User;
            return;
        }

        // How many leading identities of the old principal are already represented in the new one.
        var identitiesToSkip = 0;
        if (OperatingSystem.IsWindows() && oldContext.User is WindowsPrincipal)
        {
            // Build a real WindowsPrincipal (not just a ClaimsPrincipal) so that its overrides, such as 'IsInRole',
            // behave as expected. It is seeded with a clone of the primary WindowsIdentity, so the first identity
            // of the old principal (assumed to be that same identity) is not copied again below.
            newContext.User = new WindowsPrincipal((WindowsIdentity)windowsIdentity.Clone());
            identitiesToSkip = 1;
        }
        else
        {
            newContext.User = new ClaimsPrincipal();
        }

        var position = 0;
        foreach (var identity in oldContext.User.Identities)
        {
            if (position++ < identitiesToSkip)
            {
                continue;
            }

            newContext.User.AddIdentity(identity.Clone());
        }
    }
 
    /// <summary>
    /// Builds a detached copy of <paramref name="context"/> and stores it on <paramref name="connection"/>.
    /// </summary>
    /// <remarks>
    /// The low-level features (request line, headers, connection info) are copied instead of the
    /// <see cref="HttpContext"/> properties, so that <see cref="DefaultHttpContext"/> can rebuild the
    /// higher-level view from them.
    /// The copied parts are:
    /// <list type="bullet">
    /// <item>the response body, which is backed by <see cref="Stream.Null"/>;</item>
    /// <item>the trace identifier and the endpoint;</item>
    /// <item>the user, via <c>CloneUser</c>;</item>
    /// <item>the <see cref="HttpContext.Items"/>, copied into a new dictionary.</item>
    /// </list>
    /// A new async service scope is created from the request's services and assigned to the connection.
    /// </remarks>
    /// <param name="context">The current request's context, used as the source of the copy.</param>
    /// <param name="connection">The connection that receives the cloned context and the new service scope.</param>
    private static void CloneHttpContext(HttpContext context, HttpConnectionContext connection)
    {
        var sourceRequest = context.Features.GetRequiredFeature<IHttpRequestFeature>();

        // Copy headers into a fresh case-insensitive dictionary so the clone does not share
        // header storage with the original request.
        var headerCopy = new Dictionary<string, StringValues>(sourceRequest.Headers.Count, StringComparer.OrdinalIgnoreCase);
        foreach (var (headerName, headerValues) in sourceRequest.Headers)
        {
            headerCopy[headerName] = headerValues;
        }

        var clonedRequest = new HttpRequestFeature
        {
            Protocol = sourceRequest.Protocol,
            Method = sourceRequest.Method,
            Scheme = sourceRequest.Scheme,
            Path = sourceRequest.Path,
            PathBase = sourceRequest.PathBase,
            QueryString = sourceRequest.QueryString,
            RawTarget = sourceRequest.RawTarget,
            Headers = new HeaderDictionary(headerCopy),
        };

        // Connection info is optional; when it is absent the clone keeps a default feature.
        var clonedConnection = new HttpConnectionFeature();
        if (context.Features.Get<IHttpConnectionFeature>() is { } sourceConnection)
        {
            clonedConnection.ConnectionId = sourceConnection.ConnectionId;
            clonedConnection.LocalIpAddress = sourceConnection.LocalIpAddress;
            clonedConnection.LocalPort = sourceConnection.LocalPort;
            clonedConnection.RemoteIpAddress = sourceConnection.RemoteIpAddress;
            clonedConnection.RemotePort = sourceConnection.RemotePort;
        }

        // The response is a placeholder: anything written to its body goes to Stream.Null.
        var featureSet = new FeatureCollection();
        featureSet.Set<IHttpRequestFeature>(clonedRequest);
        featureSet.Set<IHttpResponseFeature>(new HttpResponseFeature());
        featureSet.Set<IHttpResponseBodyFeature>(new StreamResponseBodyFeature(Stream.Null));
        featureSet.Set<IHttpConnectionFeature>(clonedConnection);

        // REVIEW: We could strategically look at adding other features but it might be better
        // if we expose a callback that would allow the user to preserve HttpContext properties.

        var clonedContext = new DefaultHttpContext(featureSet)
        {
            TraceIdentifier = context.TraceIdentifier,
        };
        clonedContext.SetEndpoint(context.GetEndpoint());

        CloneUser(clonedContext, context);

        // The clone resolves services from its own scope, which is stored on the connection.
        var connectionScope = context.RequestServices.CreateAsyncScope();
        connection.ServiceScope = connectionScope;
        clonedContext.RequestServices = connectionScope.ServiceProvider;

        // REVIEW: This extends the lifetime of anything that got put into HttpContext.Items
        clonedContext.Items = new Dictionary<object, object?>(context.Items);

        connection.HttpContext = clonedContext;
    }
 
    /// <summary>
    /// Looks up the existing connection identified by the request's connection token.
    /// </summary>
    /// <param name="context">The current request.</param>
    /// <returns>
    /// The matching connection. Returns <see langword="null"/> after writing a plain-text 400 response
    /// when the token is missing or empty. Returns <see langword="null"/> after writing a plain-text 404
    /// response when no connection has that token.
    /// </returns>
    private async Task<HttpConnectionContext?> GetConnectionAsync(HttpContext context)
    {
        var token = GetConnectionToken(context);
        var response = context.Response;

        if (StringValues.IsNullOrEmpty(token))
        {
            // Missing connection ID: bad request.
            response.StatusCode = StatusCodes.Status400BadRequest;
            response.ContentType = "text/plain";
            await response.WriteAsync("Connection ID required");
            return null;
        }

        // ToString is used because IsNullOrEmpty does not tell the compiler the value converts to a non-null string.
        if (_manager.TryGetConnection(token.ToString(), out var existing))
        {
            return existing;
        }

        // Unknown connection ID: not found.
        response.StatusCode = StatusCodes.Status404NotFound;
        response.ContentType = "text/plain";
        await response.WriteAsync("No Connection with that ID");
        return null;
    }
 
    /// <summary>
    /// Resolves the connection for a request that may connect directly without negotiating.
    /// This is only used for WebSockets connections.
    /// </summary>
    /// <param name="context">The current request.</param>
    /// <param name="options">Options used when a brand-new connection has to be created.</param>
    /// <returns>
    /// A newly created connection when the request carries no connection token, or the existing
    /// connection matching the token. Returns <see langword="null"/> after writing a plain-text 404
    /// response when the token does not match any connection.
    /// </returns>
    private async Task<HttpConnectionContext?> GetOrCreateConnectionAsync(HttpContext context, HttpConnectionDispatcherOptions options)
    {
        var connectionToken = GetConnectionToken(context);
        HttpConnectionContext? connection;

        // There's no connection id so this is a brand new connection
        if (StringValues.IsNullOrEmpty(connectionToken))
        {
            connection = CreateConnection(options);
        }
        // Use ToString; IsNullOrEmpty doesn't tell the compiler anything about implicit conversion to string.
        else if (!_manager.TryGetConnection(connectionToken.ToString(), out connection))
        {
            // No connection with that ID: Not Found.
            // Label the body as plain text, matching the 404 written by GetConnectionAsync.
            context.Response.StatusCode = StatusCodes.Status404NotFound;
            context.Response.ContentType = "text/plain";
            await context.Response.WriteAsync("No Connection with that ID");
            return null;
        }

        return connection;
    }
 
    /// <summary>
    /// Creates a new connection by delegating to the connection manager.
    /// </summary>
    /// <param name="options">Dispatcher options for the new connection.</param>
    /// <param name="clientProtocolVersion">Protocol version reported by the client; defaults to 0.</param>
    /// <param name="useStatefulReconnect">Whether stateful reconnect is requested; defaults to false.</param>
    /// <returns>The connection returned by the manager.</returns>
    private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int clientProtocolVersion = 0, bool useStatefulReconnect = false)
        => _manager.CreateConnection(options, clientProtocolVersion, useStatefulReconnect);
 
    /// <summary>
    /// Sets the Cache-Control, Pragma and Expires headers on <paramref name="response"/> to the
    /// no-cache values defined on this class.
    /// </summary>
    /// <param name="response">The response whose headers are modified.</param>
    private static void AddNoCacheHeaders(HttpResponse response)
    {
        var headers = response.Headers;
        headers.CacheControl = HeaderValueNoCacheNoStore;
        headers.Pragma = HeaderValueNoCache;
        headers.Expires = HeaderValueEpochDate;
    }
 
    /// <summary>
    /// An <see cref="IServiceProvider"/> that resolves nothing: every lookup returns <see langword="null"/>.
    /// </summary>
    private sealed class EmptyServiceProvider : IServiceProvider
    {
        /// <summary>The shared instance; the class carries no state, so one instance suffices.</summary>
        public static EmptyServiceProvider Instance { get; } = new();

        /// <summary>Always returns <see langword="null"/>, regardless of <paramref name="serviceType"/>.</summary>
        public object? GetService(Type serviceType)
        {
            return null;
        }
    }
}