// File: System\Net\Security\SslSessionsCache.cs
// Web Access
// Project: src\src\libraries\System.Net.Security\src\System.Net.Security.csproj (System.Net.Security)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
 
namespace System.Net.Security
{
    // Implements SSL session caching mechanism based on a static table of SSL credentials.
    internal static class SslSessionsCache
    {
        // The cache is scavenged whenever its entry count is a multiple of this value (see ShrinkCredentialCache).
        private const int CheckExpiredModulo = 32;
        // Static table of cached credential references keyed by SslCredKey. The same object is also used
        // as the lock target when CacheCredential stores a new entry.
        private static readonly ConcurrentDictionary<SslCredKey, SafeCredentialReference> s_cachedCreds =
            new ConcurrentDictionary<SslCredKey, SafeCredentialReference>();
 
        //
        // Cache key describing one set of SSL credential parameters.
        // Two keys are equal when every option matches and the certificate thumb-prints are byte-wise equal.
        //
        private readonly struct SslCredKey : IEquatable<SslCredKey>
        {
            private readonly byte[] _thumbPrint;
            private readonly int _allowedProtocols;
            private readonly EncryptionPolicy _encryptionPolicy;
            private readonly bool _isServerMode;
            private readonly bool _sendTrustList;
            private readonly bool _checkRevocation;
            private readonly bool _allowTlsResume;

            //
            // SECURITY: X509Certificate.GetCertHash() is virtual hence before going here,
            //           the caller of this ctor has to ensure that a user cert object was inspected and
            //           optionally cloned.
            //
            internal SslCredKey(
                byte[]? thumbPrint,
                int allowedProtocols,
                bool isServerMode,
                EncryptionPolicy encryptionPolicy,
                bool sendTrustList,
                bool checkRevocation,
                bool allowTlsResume)
            {
                // A missing thumb-print is normalized to an empty array so hashing and comparison never see null.
                _thumbPrint = thumbPrint ?? Array.Empty<byte>();
                _allowedProtocols = allowedProtocols;
                _encryptionPolicy = encryptionPolicy;
                _isServerMode = isServerMode;
                _checkRevocation = checkRevocation;
                _sendTrustList = sendTrustList;
                _allowTlsResume = allowTlsResume;
            }

            //
            // Combines every option that participates in Equals with the first four thumb-print bytes.
            // Thumb-prints shorter than four bytes contribute 0.
            //
            public override int GetHashCode()
            {
                int thumbPrintHash = 0;
                if (_thumbPrint.Length > 3)
                {
                    thumbPrintHash = _thumbPrint[0] | (_thumbPrint[1] << 8) | (_thumbPrint[2] << 16) | (_thumbPrint[3] << 24);
                }

                // _allowTlsResume is compared by Equals, so it has to be hashed as well; otherwise keys that
                // differ only in that flag always share a hash bucket.
                return HashCode.Combine(_allowedProtocols,
                                        (int)_encryptionPolicy,
                                        _isServerMode,
                                        _sendTrustList,
                                        _checkRevocation,
                                        _allowTlsResume,
                                        thumbPrintHash);
            }

            public override bool Equals([NotNullWhen(true)] object? obj) =>
                obj is SslCredKey other && Equals(other);

            //
            // Compares the cheap scalar fields first and the thumb-print contents last.
            //
            public bool Equals(SslCredKey other)
            {
                byte[] thumbPrint = _thumbPrint;
                byte[] otherThumbPrint = other._thumbPrint;

                return
                    thumbPrint.Length == otherThumbPrint.Length &&
                    _encryptionPolicy == other._encryptionPolicy &&
                    _allowedProtocols == other._allowedProtocols &&
                    _isServerMode == other._isServerMode &&
                    _sendTrustList == other._sendTrustList &&
                    _checkRevocation == other._checkRevocation &&
                    _allowTlsResume == other._allowTlsResume &&
                    thumbPrint.AsSpan().SequenceEqual(otherThumbPrint);
            }
        }
 
        //
        // Looks up a previously cached cred handle for the given parameters.
        // Returns null when the cache is empty, when no entry exists for the key, or when the
        // cached handle is closed, invalid or already past its Expiry.
        //
        // ATTN: The returned handle can be invalid, the callers of InitializeSecurityContext and AcceptSecurityContext
        // must be prepared to execute a back-out code if the call fails.
        //
        internal static SafeFreeCredentials? TryCachedCredential(
            byte[]? thumbPrint,
            SslProtocols sslProtocols,
            bool isServer,
            EncryptionPolicy encryptionPolicy,
            bool checkRevocation,
            bool allowTlsResume,
            bool sendTrustList)
        {
            // Nothing cached at all: skip building a key.
            if (s_cachedCreds.IsEmpty)
            {
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Info(null, $"Not found, Current Cache Count = {s_cachedCreds.Count}");
                }

                return null;
            }

            SslCredKey lookupKey = new SslCredKey(thumbPrint, (int)sslProtocols, isServer, encryptionPolicy, sendTrustList, checkRevocation, allowTlsResume);
            SafeFreeCredentials? candidate = GetCachedCredential(lookupKey);

            // A handle is handed out only while it is open, valid and not yet expired.
            bool usable =
                candidate != null &&
                !candidate.IsClosed &&
                !candidate.IsInvalid &&
                !(candidate.Expiry < DateTime.UtcNow);

            if (!usable)
            {
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Info(null, $"Not found or invalid, Current Cache Count = {s_cachedCreds.Count}");
                }

                return null;
            }

            if (NetEventSource.Log.IsEnabled())
            {
                NetEventSource.Info(null, $"Found a cached Handle = {candidate}");
            }

            return candidate;
        }
 
        // Returns the Target of the cache entry stored under the given key (the Target itself may be null),
        // or null when the cache holds no entry for that key.
        private static SafeFreeCredentials? GetCachedCredential(SslCredKey key)
        {
            if (s_cachedCreds.TryGetValue(key, out SafeCredentialReference? entry))
            {
                return entry.Target;
            }

            return null;
        }
 
        //
        // Called by the app after an SSL handshake has been started, to publish the cred handle in the cache.
        // The incoming handle is stored only when no usable handle (non-null, open, valid and unexpired)
        // is already cached under the same key; otherwise the incoming handle is ignored.
        //
        // ATTN: The thumbPrint must come from an inspected and possibly cloned user cert object,
        //       otherwise we get a security hole in the SslCredKey ctor.
        //
        internal static void CacheCredential(
            SafeFreeCredentials creds,
            byte[]? thumbPrint,
            SslProtocols sslProtocols,
            bool isServer,
            EncryptionPolicy encryptionPolicy,
            bool checkRevocation,
            bool allowTlsResume,
            bool sendTrustList)
        {
            Debug.Assert(creds != null, "creds == null");

            if (creds.IsInvalid)
            {
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Info(null, $"Refused to cache an Invalid Handle {creds}, Current Cache Count = {s_cachedCreds.Count}");
                }

                return;
            }

            var cacheKey = new SslCredKey(thumbPrint, (int)sslProtocols, isServer, encryptionPolicy, sendTrustList, checkRevocation, allowTlsResume);

            SafeFreeCredentials? existing = GetCachedCredential(cacheKey);
            DateTime timestamp = DateTime.UtcNow;

            // Fast path without the lock: a usable handle is already cached, keep it.
            if (!NeedsReplacement(existing, timestamp))
            {
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Info(null, $"CacheCredential() Ignoring incoming handle = {creds} since found already cached Handle = {existing}");
                }

                return;
            }

            lock (s_cachedCreds)
            {
                // Look again under the lock; the entry may have changed since the unlocked check.
                existing = GetCachedCredential(cacheKey);
                if (!NeedsReplacement(existing, timestamp))
                {
                    if (NetEventSource.Log.IsEnabled())
                    {
                        NetEventSource.Info(null, $"CacheCredential() (locked retry) Found already cached Handle = {existing}");
                    }

                    return;
                }

                SafeCredentialReference? reference = SafeCredentialReference.CreateReference(creds);
                if (reference == null)
                {
                    // Means the handle got closed in between, return it back and let caller deal with the issue.
                    return;
                }

                s_cachedCreds[cacheKey] = reference;
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Info(null, $"Caching New Handle = {creds}, Current Cache Count = {s_cachedCreds.Count}");
                }

                ShrinkCredentialCache();
            }

            // True when there is no cached handle, or the cached one is closed, invalid or expired as of 'asOf'.
            static bool NeedsReplacement(SafeFreeCredentials? candidate, DateTime asOf) =>
                candidate == null || candidate.IsClosed || candidate.IsInvalid || candidate.Expiry < asOf;

            static void ShrinkCredentialCache()
            {
                //
                // A simple way of preventing unbounded cache growth.
                //
                // Security relief (DoS):
                //     The number of active creds is never greater than the number of _outstanding_
                //     security sessions, i.e. SSL connections.
                //     So once in a while we try to shrink the cache to the number of active creds.
                //
                //     The cache is not shrunk when NO new handles are coming to it.
                //
                if ((s_cachedCreds.Count % CheckExpiredModulo) != 0)
                {
                    return;
                }

                // Work on a snapshot so the table can be modified while walking the entries.
                foreach (KeyValuePair<SslCredKey, SafeCredentialReference> entry in s_cachedCreds.ToArray())
                {
                    SslCredKey entryKey = entry.Key;
                    SafeCredentialReference? reference = entry.Value;
                    SafeFreeCredentials? target = reference.Target;

                    if (target == null)
                    {
                        // The entry no longer references a handle: drop it.
                        s_cachedCreds.TryRemove(entryKey, out _);
                        continue;
                    }

                    // Release the cached reference and try to take a fresh one on the same handle;
                    // the entry is removed when a new reference cannot be created.
                    reference.Dispose();
                    reference = SafeCredentialReference.CreateReference(target);
                    if (reference == null)
                    {
                        s_cachedCreds.TryRemove(entryKey, out _);
                    }
                    else
                    {
                        s_cachedCreds[entryKey] = reference;
                    }
                }

                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Info(null, $"Scavenged cache, New Cache Count = {s_cachedCreds.Count}");
                }
            }
        }
    }
}