File: WebSocketMiddleware.cs
Web Access
Project: src\src\Middleware\WebSockets\src\Microsoft.AspNetCore.WebSockets.csproj (Microsoft.AspNetCore.WebSockets)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Linq;
using System.Net.WebSockets;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Timeouts;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
 
namespace Microsoft.AspNetCore.WebSockets;
 
/// <summary>
/// Enables accepting WebSocket requests by adding a <see cref="IHttpWebSocketFeature"/>
/// to the <see cref="HttpContext"/> if the request is a valid WebSocket request.
/// </summary>
public partial class WebSocketMiddleware
{
    // Next middleware in the pipeline; Invoke calls it unless the request is rejected by the origin check.
    private readonly RequestDelegate _next;
    // The WebSocketOptions instance read from IOptions<WebSocketOptions>.Value in the constructor.
    private readonly WebSocketOptions _options;
    // Logger created for WebSocketMiddleware; also handed to each per-request WebSocketHandshake.
    private readonly ILogger _logger;
    // True when no allowed origins are configured or the configured list contains "*".
    // When true, Invoke skips the Origin header check entirely.
    private readonly bool _anyOriginAllowed;
    // Copy of the configured allowed origins, lower-cased with the invariant culture at construction time.
    private readonly List<string> _allowedOrigins;
 
    /// <summary>
    /// Creates a new instance of the <see cref="WebSocketMiddleware"/>.
    /// </summary>
    /// <param name="next">The next middleware in the pipeline.</param>
    /// <param name="options">The configuration options.</param>
    /// <param name="loggerFactory">An <see cref="ILoggerFactory"/> instance used to create loggers.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="next"/>, <paramref name="options"/> or <paramref name="loggerFactory"/> is <see langword="null"/>.
    /// </exception>
    public WebSocketMiddleware(RequestDelegate next, IOptions<WebSocketOptions> options, ILoggerFactory loggerFactory)
    {
        ArgumentNullException.ThrowIfNull(next);
        ArgumentNullException.ThrowIfNull(options);
        // Guard the logger factory like the other arguments so a missing dependency fails
        // with a descriptive ArgumentNullException instead of a NullReferenceException below.
        ArgumentNullException.ThrowIfNull(loggerFactory);

        _next = next;
        _options = options.Value;

        // Keep a lower-cased snapshot of the configured origins for the per-request Origin check.
        _allowedOrigins = _options.AllowedOrigins.Select(o => o.ToLowerInvariant()).ToList();

        // An empty list or an explicit "*" entry means the Origin header is never checked.
        _anyOriginAllowed = _options.AllowedOrigins.Count == 0 || _options.AllowedOrigins.Contains("*", StringComparer.Ordinal);

        _logger = loggerFactory.CreateLogger<WebSocketMiddleware>();

        // TODO: validate options.
    }
 
    /// <summary>
    /// Processes a request to determine if it is a WebSocket request, and if so,
    /// sets the <see cref="IHttpWebSocketFeature"/> on the <see cref="HttpContext.Features"/>.
    /// </summary>
    /// <remarks>
    /// When specific origins are configured and a WebSocket request carries an <c>Origin</c> header
    /// that does not match any of them, the response status is set to 403 and the rest of the
    /// pipeline is not invoked. Origins are matched case-insensitively.
    /// </remarks>
    /// <param name="context">The <see cref="HttpContext"/> representing the request.</param>
    /// <returns>The <see cref="Task"/> that represents the completion of the middleware pipeline.</returns>
    public Task Invoke(HttpContext context)
    {
        // Detect if an opaque upgrade is available. If so, add a websocket upgrade.
        var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>();
        var connectFeature = context.Features.Get<IHttpExtendedConnectFeature>();
        if ((upgradeFeature != null || connectFeature != null) && context.Features.Get<IHttpWebSocketFeature>() == null)
        {
            var webSocketFeature = new WebSocketHandshake(context, upgradeFeature, connectFeature, _options, _logger);
            context.Features.Set<IHttpWebSocketFeature>(webSocketFeature);
            if (!_anyOriginAllowed)
            {
                // Check for Origin header
                var originHeader = context.Request.Headers.Origin;

                // Requests without an Origin header are let through; only WebSocket requests are filtered.
                if (!StringValues.IsNullOrEmpty(originHeader) && webSocketFeature.IsWebSocketRequest)
                {
                    var origin = originHeader.ToString();

                    // The configured origins were lower-cased at construction, so the header must be
                    // matched ignoring case; an ordinal match would reject an allowed origin whose
                    // scheme or host arrives in upper or mixed case.
                    if (!_allowedOrigins.Contains(origin, StringComparer.OrdinalIgnoreCase))
                    {
                        if (_logger.IsEnabled(LogLevel.Debug))
                        {
                            _logger.LogDebug("Request origin {Origin} is not in the list of allowed origins.", origin);
                        }
                        context.Response.StatusCode = StatusCodes.Status403Forbidden;
                        return Task.CompletedTask;
                    }
                }
            }
        }

        return _next(context);
    }
 
    private sealed class WebSocketHandshake : IHttpWebSocketFeature
    {
        // The request this handshake object was created for.
        private readonly HttpContext _context;
        // Upgrade feature of the request, if any; used to accept when the request is not an HTTP/2 WebSocket.
        private readonly IHttpUpgradeFeature? _upgradeFeature;
        // Extended CONNECT feature of the request, if any; used to accept HTTP/2 WebSockets.
        private readonly IHttpExtendedConnectFeature? _connectFeature;
        // Middleware options; supply the default keep-alive interval and timeout in AcceptAsync.
        private readonly WebSocketOptions _options;
        private readonly ILogger _logger;
        // Cached result of IsWebSocketRequest; null until the property is first evaluated.
        private bool? _isWebSocketRequest;
        // Set by IsWebSocketRequest when the request is an extended CONNECT that passed the HTTP/2 checks.
        private bool _isH2WebSocket;
 
        /// <summary>
        /// Captures the request and the features/options needed to later evaluate and accept the handshake.
        /// No validation or header inspection happens here.
        /// </summary>
        public WebSocketHandshake(HttpContext context, IHttpUpgradeFeature? upgradeFeature, IHttpExtendedConnectFeature? connectFeature, WebSocketOptions options, ILogger logger)
        {
            (_context, _upgradeFeature, _connectFeature, _options, _logger) =
                (context, upgradeFeature, connectFeature, options, logger);
        }
 
        /// <summary>
        /// Gets whether the current request is a supported WebSocket handshake.
        /// The answer is computed on first access and cached for subsequent reads.
        /// </summary>
        public bool IsWebSocketRequest
        {
            get
            {
                // Fast path: already evaluated for this request.
                if (_isWebSocketRequest is bool cached)
                {
                    return cached;
                }

                bool supported;
                if (_connectFeature is { IsExtendedConnect: true })
                {
                    // Extended CONNECT: remember that the HTTP/2 accept path must be used.
                    _isH2WebSocket = CheckSupportedWebSocketRequestH2(_context.Request.Method, _connectFeature.Protocol, _context.Request.Headers);
                    supported = _isH2WebSocket;
                }
                else if (_upgradeFeature is { IsUpgradableRequest: true })
                {
                    supported = CheckSupportedWebSocketRequest(_context.Request.Method, _context.Request.Headers);
                }
                else
                {
                    supported = false;
                }

                _isWebSocketRequest = supported;
                return supported;
            }
        }
 
        /// <summary>
        /// Accepts the WebSocket handshake for the current request and returns the resulting socket.
        /// </summary>
        /// <param name="acceptContext">
        /// Optional per-request settings (sub-protocol, compression, keep-alive). When <see langword="null"/>,
        /// compression is off and the keep-alive values come from the middleware options.
        /// </param>
        /// <returns>A server-side <see cref="WebSocket"/> created over the accepted transport stream.</returns>
        /// <exception cref="InvalidOperationException">The request is not a WebSocket request.</exception>
        public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
        {
            if (!IsWebSocketRequest)
            {
                throw new InvalidOperationException("Not a WebSocket request.");
            }

            // Middleware-wide defaults, used whenever the accept context does not override them.
            var defaultKeepAliveInterval = _options.KeepAliveInterval;
            var defaultKeepAliveTimeout = _options.KeepAliveTimeout;

            var subProtocol = acceptContext?.SubProtocol;
            var enableCompression = acceptContext?.DangerousEnableCompression ?? false;
            var serverContextTakeover = !(acceptContext?.DisableServerContextTakeover ?? false);
            var serverMaxWindowBits = acceptContext?.ServerMaxWindowBits ?? 15;
            var keepAliveInterval = acceptContext?.KeepAliveInterval ?? defaultKeepAliveInterval;
            var keepAliveTimeout = acceptContext?.KeepAliveTimeout ?? defaultKeepAliveTimeout;

            // The obsolete extended context carries its own keep-alive interval, which wins when set.
#pragma warning disable CS0618 // Type or member is obsolete
            if (acceptContext is ExtendedWebSocketAcceptContext { KeepAliveInterval: { } legacyKeepAliveInterval })
#pragma warning restore CS0618 // Type or member is obsolete
            {
                keepAliveInterval = legacyKeepAliveInterval;
            }

            // Handshake response headers are written before the transport is taken over below.
            HandshakeHelpers.GenerateResponseHeaders(!_isH2WebSocket, _context.Request.Headers, subProtocol, _context.Response.Headers);

            var deflateOptions = enableCompression
                ? NegotiateCompression(serverContextTakeover, serverMaxWindowBits)
                : null;

            // HTTP/2 accepts the extended CONNECT (sending the response headers);
            // HTTP/1.1 performs the upgrade, which sets the status code to 101.
            Stream opaqueTransport = _isH2WebSocket
                ? await _connectFeature!.AcceptAsync()
                : await _upgradeFeature!.UpgradeAsync();

            // Disable request timeout, if there is one, after the websocket has been accepted
            _context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();

            var abortStream = new AbortStream(_context, opaqueTransport);
            var creationOptions = new WebSocketCreationOptions
            {
                IsServer = true,
                KeepAliveInterval = keepAliveInterval,
                KeepAliveTimeout = keepAliveTimeout,
                SubProtocol = subProtocol,
                DangerousDeflateOptions = deflateOptions
            };

            var wrappedSocket = WebSocket.CreateFromStream(abortStream, creationOptions);
            abortStream.WebSocket = wrappedSocket;
            return wrappedSocket;
        }

        // Picks the first acceptable "permessage-deflate" offer from the request's
        // Sec-WebSocket-Extensions header, writes the matching response header and returns
        // the parsed options. Returns null when the header is absent or no offer is acceptable.
        private WebSocketDeflateOptions? NegotiateCompression(bool serverContextTakeover, int serverMaxWindowBits)
        {
            if (_context.Request.Headers.SecWebSocketExtensions.Count == 0)
            {
                // Nothing offered by the client: no negotiation and nothing logged.
                return null;
            }

            WebSocketDeflateOptions? negotiated = null;

            // loop over each extension offer, extensions can have multiple offers, we can accept any
            foreach (var offer in _context.Request.Headers.GetCommaSeparatedValues(HeaderNames.SecWebSocketExtensions))
            {
                var candidate = offer.AsSpan().TrimStart();
                if (!candidate.StartsWith("permessage-deflate", StringComparison.Ordinal))
                {
                    continue;
                }

                if (HandshakeHelpers.ParseDeflateOptions(candidate, serverContextTakeover, serverMaxWindowBits, out var parsedOptions, out var response))
                {
                    Log.CompressionAccepted(_logger, response);
                    negotiated = parsedOptions;
                    // If more extension types are added, this would need to be a header append
                    // and we wouldn't want to break out of the loop
                    _context.Response.Headers.SecWebSocketExtensions = response;
                    break;
                }
            }

            if (negotiated is null)
            {
                Log.CompressionNotAccepted(_logger);
            }

            return negotiated;
        }
 
        /// <summary>
        /// Checks whether an HTTP/1.1 request is a supported WebSocket upgrade: a GET with a supported
        /// Sec-WebSocket-Version, an Upgrade header listing the websocket token, a Connection header
        /// listing "Upgrade", and a valid Sec-WebSocket-Key.
        /// </summary>
        /// <param name="method">The HTTP request method.</param>
        /// <param name="requestHeaders">
        /// The request headers. Matching single-valued Upgrade/Connection headers are overwritten with the shared constant strings.
        /// </param>
        /// <returns><see langword="true"/> when every check passes; otherwise <see langword="false"/>.</returns>
        public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
        {
            if (!HttpMethods.IsGet(method) || !CheckWebSocketVersion(requestHeaders))
            {
                return false;
            }

            var upgradeTokens = requestHeaders.GetCommaSeparatedValues(HeaderNames.Upgrade);
            if (!ContainsToken(upgradeTokens, Constants.Headers.UpgradeWebSocket))
            {
                return false;
            }

            // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
            if (upgradeTokens.Length == 1)
            {
                requestHeaders.Upgrade = Constants.Headers.UpgradeWebSocket;
            }

            var connectionTokens = requestHeaders.GetCommaSeparatedValues(HeaderNames.Connection);
            if (!ContainsToken(connectionTokens, HeaderNames.Upgrade))
            {
                return false;
            }

            // Same replacement as above, for the Connection header.
            if (connectionTokens.Length == 1)
            {
                requestHeaders.Connection = HeaderNames.Upgrade;
            }

            return HandshakeHelpers.IsRequestKeyValid(requestHeaders.SecWebSocketKey.ToString());

            // Case-insensitive search for an exact token among the comma-separated header values.
            static bool ContainsToken(string[] tokens, string expected)
            {
                foreach (var token in tokens)
                {
                    if (string.Equals(token, expected, StringComparison.OrdinalIgnoreCase))
                    {
                        return true;
                    }
                }

                return false;
            }
        }
 
        // Example of an HTTP/2 WebSocket request, see https://datatracker.ietf.org/doc/html/rfc8441
        // :method = CONNECT
        // :protocol = websocket
        // :scheme = https
        // :path = /chat
        // :authority = server.example.com
        // sec-websocket-protocol = chat, superchat
        // sec-websocket-extensions = permessage-deflate
        // sec-websocket-version = 13
        // origin = http://www.example.com
        /// <summary>
        /// Checks whether an extended CONNECT request is a supported WebSocket request: the method is
        /// CONNECT, the protocol is the websocket token (ignoring case) and the version header is supported.
        /// </summary>
        /// <param name="method">The HTTP request method.</param>
        /// <param name="protocol">The protocol value of the extended CONNECT request; may be <see langword="null"/>.</param>
        /// <param name="requestHeaders">The request headers.</param>
        /// <returns><see langword="true"/> when every check passes; otherwise <see langword="false"/>.</returns>
        public static bool CheckSupportedWebSocketRequestH2(string method, string? protocol, IHeaderDictionary requestHeaders)
        {
            if (!HttpMethods.IsConnect(method))
            {
                return false;
            }

            if (!string.Equals(protocol, Constants.Headers.UpgradeWebSocket, StringComparison.OrdinalIgnoreCase))
            {
                return false;
            }

            // Checked last because it may rewrite the version header.
            return CheckWebSocketVersion(requestHeaders);
        }
 
        /// <summary>
        /// Checks whether the Sec-WebSocket-Version header lists the supported version.
        /// </summary>
        /// <param name="requestHeaders">
        /// The request headers. When the header holds exactly one value and it matches, the header is
        /// overwritten with the shared constant string.
        /// </param>
        /// <returns><see langword="true"/> when the supported version is among the header values; otherwise <see langword="false"/>.</returns>
        public static bool CheckWebSocketVersion(IHeaderDictionary requestHeaders)
        {
            var offeredVersions = requestHeaders.GetCommaSeparatedValues(HeaderNames.SecWebSocketVersion);

            var supported = false;
            for (var i = 0; i < offeredVersions.Length && !supported; i++)
            {
                supported = string.Equals(offeredVersions[i], Constants.Headers.SupportedVersion, StringComparison.OrdinalIgnoreCase);
            }

            if (!supported)
            {
                return false;
            }

            // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
            if (offeredVersions.Length == 1)
            {
                requestHeaders.SecWebSocketVersion = Constants.Headers.SupportedVersion;
            }

            return true;
        }
    }
 
    /// <summary>
    /// Source-generated log messages used during WebSocket compression negotiation.
    /// </summary>
    private static partial class Log
    {
        /// <summary>Logs (Debug, event 1) the negotiated compression response value.</summary>
        [LoggerMessage(
            EventId = 1,
            Level = LogLevel.Debug,
            Message = "WebSocket compression negotiation accepted with values '{CompressionResponse}'.",
            EventName = "CompressionAccepted")]
        public static partial void CompressionAccepted(ILogger logger, string compressionResponse);

        /// <summary>Logs (Debug, event 2) that no compression offer was accepted.</summary>
        [LoggerMessage(
            EventId = 2,
            Level = LogLevel.Debug,
            Message = "Compression negotiation not accepted by server.",
            EventName = "CompressionNotAccepted")]
        public static partial void CompressionNotAccepted(ILogger logger);
    }
}