File: Internal\SniOptionsSelector.cs
Web Access
Project: src\src\Servers\Kestrel\Core\src\Microsoft.AspNetCore.Server.Kestrel.Core.csproj (Microsoft.AspNetCore.Server.Kestrel.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Linq;
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Certificates;
using Microsoft.AspNetCore.Server.Kestrel.Https;
using Microsoft.AspNetCore.Server.Kestrel.Https.Internal;
using Microsoft.Extensions.Logging;
 
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
 
internal sealed class SniOptionsSelector
{
    private const string WildcardHost = "*";
    private const string WildcardPrefix = "*.";
 
    private readonly string _endpointName;
    private readonly ILogger<HttpsConnectionMiddleware> _logger;
 
    private readonly Func<ConnectionContext, string?, X509Certificate2?>? _fallbackServerCertificateSelector;
    private readonly Action<ConnectionContext, SslServerAuthenticationOptions>? _onAuthenticateCallback;
 
    private readonly Dictionary<string, SniOptions> _exactNameOptions = new Dictionary<string, SniOptions>(StringComparer.OrdinalIgnoreCase);
    private readonly SortedList<string, SniOptions> _wildcardPrefixOptions = new SortedList<string, SniOptions>(LongestStringFirstComparer.Instance);
    private readonly SniOptions? _wildcardOptions;
 
    public SniOptionsSelector(
        string endpointName,
        Dictionary<string, SniConfig> sniDictionary,
        ICertificateConfigLoader certifcateConfigLoader,
        HttpsConnectionAdapterOptions fallbackHttpsOptions,
        HttpProtocols fallbackHttpProtocols,
        ILogger<HttpsConnectionMiddleware> logger)
    {
        _endpointName = endpointName;
        _logger = logger;
 
        _fallbackServerCertificateSelector = fallbackHttpsOptions.ServerCertificateSelector;
        _onAuthenticateCallback = fallbackHttpsOptions.OnAuthenticate;
 
        foreach (var (name, sniConfig) in sniDictionary)
        {
            var (serverCert, fullChain) = certifcateConfigLoader.LoadCertificate(sniConfig.Certificate, $"{endpointName}:Sni:{name}");
            var sslOptions = new SslServerAuthenticationOptions
            {
 
                ServerCertificate = serverCert,
                EnabledSslProtocols = sniConfig.SslProtocols ?? fallbackHttpsOptions.SslProtocols,
                CertificateRevocationCheckMode = fallbackHttpsOptions.CheckCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck,
            };
 
            if (sslOptions.ServerCertificate is null)
            {
                if (!fallbackHttpsOptions.HasServerCertificateOrSelector)
                {
                    throw new InvalidOperationException(CoreStrings.NoCertSpecifiedNoDevelopmentCertificateFound);
                }
 
                if (_fallbackServerCertificateSelector is null)
                {
                    // Cache the fallback ServerCertificate since there's no fallback ServerCertificateSelector taking precedence.
                    sslOptions.ServerCertificate = fallbackHttpsOptions.ServerCertificate;
                }
            }
 
            if (sslOptions.ServerCertificate != null)
            {
                // This might be do blocking IO but it'll resolve the certificate chain up front before any connections are
                // made to the server
                sslOptions.ServerCertificateContext = SslStreamCertificateContext.Create((X509Certificate2)sslOptions.ServerCertificate, additionalCertificates: fullChain);
            }
 
            if (!certifcateConfigLoader.IsTestMock && sslOptions.ServerCertificate is X509Certificate2 cert2)
            {
                HttpsConnectionMiddleware.EnsureCertificateIsAllowedForServerAuth(cert2, logger);
            }
 
            var clientCertificateMode = sniConfig.ClientCertificateMode ?? fallbackHttpsOptions.ClientCertificateMode;
 
            if (clientCertificateMode != ClientCertificateMode.NoCertificate)
            {
                sslOptions.ClientCertificateRequired = clientCertificateMode == ClientCertificateMode.AllowCertificate
                    || clientCertificateMode == ClientCertificateMode.RequireCertificate;
                sslOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) =>
                    HttpsConnectionMiddleware.RemoteCertificateValidationCallback(
                        clientCertificateMode, fallbackHttpsOptions.ClientCertificateValidation, certificate, chain, sslPolicyErrors);
            }
 
            var httpProtocols = sniConfig.Protocols ?? fallbackHttpProtocols;
            httpProtocols = HttpsConnectionMiddleware.ValidateAndNormalizeHttpProtocols(httpProtocols, logger);
            HttpsConnectionMiddleware.ConfigureAlpn(sslOptions, httpProtocols);
 
            var sniOptions = new SniOptions(sslOptions, httpProtocols, clientCertificateMode);
 
            if (name.Equals(WildcardHost, StringComparison.Ordinal))
            {
                _wildcardOptions = sniOptions;
            }
            else if (name.StartsWith(WildcardPrefix, StringComparison.Ordinal))
            {
                // Only slice off 1 character, the `*`. We want to match the leading `.` also.
                _wildcardPrefixOptions.Add(name.Substring(1), sniOptions);
            }
            else
            {
                _exactNameOptions.Add(name, sniOptions);
            }
        }
    }
 
    public (SslServerAuthenticationOptions, ClientCertificateMode) GetOptions(ConnectionContext connection, string serverName)
    {
        SniOptions? sniOptions = null;
 
        if (!string.IsNullOrEmpty(serverName) && !_exactNameOptions.TryGetValue(serverName, out sniOptions))
        {
            foreach (var (suffix, options) in _wildcardPrefixOptions)
            {
                if (serverName.EndsWith(suffix, StringComparison.OrdinalIgnoreCase))
                {
                    sniOptions = options;
                    break;
                }
            }
        }
 
        // Fully wildcarded ("*") options can be used even when given an empty server name.
        sniOptions ??= _wildcardOptions;
 
        if (sniOptions is null)
        {
            if (serverName is null)
            {
                // There was no ALPN
                throw new AuthenticationException(CoreStrings.FormatSniNotConfiguredToAllowNoServerName(_endpointName));
            }
            else
            {
                throw new AuthenticationException(CoreStrings.FormatSniNotConfiguredForServerName(serverName, _endpointName));
            }
        }
 
        connection.Features.Set(new HttpProtocolsFeature(sniOptions.HttpProtocols));
 
        var sslOptions = sniOptions.SslOptions;
 
        if (sslOptions.ServerCertificate is null)
        {
            Debug.Assert(_fallbackServerCertificateSelector != null,
                "The cached SniOptions ServerCertificate can only be null if there's a fallback certificate selector.");
 
            // If a ServerCertificateSelector doesn't return a cert, HttpsConnectionMiddleware doesn't fallback to the ServerCertificate.
            sslOptions = CloneSslOptions(sslOptions);
            var fallbackCertificate = _fallbackServerCertificateSelector(connection, serverName);
 
            if (fallbackCertificate != null)
            {
                HttpsConnectionMiddleware.EnsureCertificateIsAllowedForServerAuth(fallbackCertificate, _logger);
            }
 
            sslOptions.ServerCertificate = fallbackCertificate;
        }
 
        if (_onAuthenticateCallback != null)
        {
            // From doc comments: "This is called after all of the other settings have already been applied."
            sslOptions = CloneSslOptions(sslOptions);
            _onAuthenticateCallback(connection, sslOptions);
        }
 
        return (sslOptions, sniOptions.ClientCertificateMode);
    }
 
    public static ValueTask<SslServerAuthenticationOptions> OptionsCallback(TlsHandshakeCallbackContext callbackContext)
    {
        var sniOptionsSelector = (SniOptionsSelector)callbackContext.State!;
        var (options, clientCertificateMode) = sniOptionsSelector.GetOptions(callbackContext.Connection, callbackContext.ClientHelloInfo.ServerName);
        callbackContext.AllowDelayedClientCertificateNegotation = clientCertificateMode == ClientCertificateMode.DelayCertificate;
        return new ValueTask<SslServerAuthenticationOptions>(options);
    }
 
    internal static SslServerAuthenticationOptions CloneSslOptions(SslServerAuthenticationOptions sslOptions) =>
        new SslServerAuthenticationOptions
        {
            AllowRenegotiation = sslOptions.AllowRenegotiation,
            ApplicationProtocols = sslOptions.ApplicationProtocols?.ToList(),
            CertificateRevocationCheckMode = sslOptions.CertificateRevocationCheckMode,
            CipherSuitesPolicy = sslOptions.CipherSuitesPolicy,
            ClientCertificateRequired = sslOptions.ClientCertificateRequired,
            EnabledSslProtocols = sslOptions.EnabledSslProtocols,
            EncryptionPolicy = sslOptions.EncryptionPolicy,
            RemoteCertificateValidationCallback = sslOptions.RemoteCertificateValidationCallback,
            ServerCertificate = sslOptions.ServerCertificate,
            ServerCertificateContext = sslOptions.ServerCertificateContext,
            ServerCertificateSelectionCallback = sslOptions.ServerCertificateSelectionCallback,
            CertificateChainPolicy = sslOptions.CertificateChainPolicy,
            AllowTlsResume = sslOptions.AllowTlsResume,
        };
 
    private sealed class SniOptions
    {
        public SniOptions(SslServerAuthenticationOptions sslOptions, HttpProtocols httpProtocols, ClientCertificateMode clientCertificateMode)
        {
            SslOptions = sslOptions;
            HttpProtocols = httpProtocols;
            ClientCertificateMode = clientCertificateMode;
        }
 
        public SslServerAuthenticationOptions SslOptions { get; }
        public HttpProtocols HttpProtocols { get; }
        public ClientCertificateMode ClientCertificateMode { get; }
    }
 
    private sealed class LongestStringFirstComparer : IComparer<string>
    {
        public static LongestStringFirstComparer Instance { get; } = new LongestStringFirstComparer();
 
        private LongestStringFirstComparer()
        {
        }
 
        public int Compare(string? x, string? y)
        {
            // Flip x and y to put the longest instead of the shortest string first in the SortedList.
            var lengthResult = y!.Length.CompareTo(x!.Length);
            if (lengthResult != 0)
            {
                return lengthResult;
            }
            else
            {
                return string.Compare(x, y, StringComparison.OrdinalIgnoreCase);
            }
        }
    }
}