// File: TestCertificateManager.cs
// Project: src\src\Shared\test\Shared.Tests\Microsoft.AspNetCore.Shared.Tests.csproj (Microsoft.AspNetCore.Shared.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Microsoft.AspNetCore.Certificates.Generation;
 
namespace Microsoft.AspNetCore.Internal.Tests;
 
/// <summary>
/// <see cref="CertificateManager"/> test double that keeps everything in memory.
/// Certificate stores are lists of serialized certificates, and exported files and created
/// directories are recorded instead of written to disk. Trust and state-correction calls
/// are captured so tests can assert on them.
/// </summary>
internal sealed class TestCertificateManager : CertificateManager
{
    // Stored certificates per (store name, store location). Entries hold serialized bytes,
    // so every read materializes fresh X509Certificate2 instances.
    private readonly Dictionary<StoreKey, List<InMemoryCertificateEntry>> _stores = new();

    // Files produced by ExportCertificateToMemory, keyed case-insensitively by path. The values are private copies.
    private readonly Dictionary<string, byte[]> _createdFiles = new(StringComparer.OrdinalIgnoreCase);

    // Directories passed to CreateDirectoryWithPermissions.
    private readonly HashSet<string> _createdDirectories = new(StringComparer.OrdinalIgnoreCase);

    // Candidates passed to CorrectCertificateState, in call order.
    private readonly List<X509Certificate2> _correctedCertificates = new();

    // Certificates passed to TrustCertificateCore, in call order.
    private readonly List<X509Certificate2> _trustedCertificates = new();

    // The answer IsExportable gives for a thumbprint. It is shared by all stores and is rebuilt
    // from the still-stored entries whenever certificates are removed (see RefreshExportability).
    private readonly Dictionary<string, bool> _exportableByThumbprint = new(StringComparer.OrdinalIgnoreCase);

    /// <summary>
    /// Creates the manager, optionally pre-populating its in-memory stores.
    /// </summary>
    /// <param name="initialStores">
    /// Certificates to seed per store. Each one is added through <c>AddCertificate</c> with default exportability.
    /// </param>
    /// <param name="subject">Subject passed to the base constructor; defaults to <c>LocalhostHttpsDistinguishedName</c>.</param>
    /// <param name="generatedVersion">Version passed to the base constructor; defaults to <c>CurrentAspNetCoreCertificateVersion</c>.</param>
    /// <param name="minimumVersion">Minimum version passed to the base constructor; defaults to <c>CurrentMinimumAspNetCoreCertificateVersion</c>.</param>
    public TestCertificateManager(
        IDictionary<StoreKey, IEnumerable<X509Certificate2>>? initialStores = null,
        string? subject = null,
        int? generatedVersion = null,
        int? minimumVersion = null)
        : base(
            subject ?? LocalhostHttpsDistinguishedName,
            generatedVersion ?? CurrentAspNetCoreCertificateVersion,
            minimumVersion ?? CurrentMinimumAspNetCoreCertificateVersion)
    {
        if (initialStores is null)
        {
            return;
        }

        foreach (var (storeKey, certificates) in initialStores)
        {
            foreach (var certificate in certificates)
            {
                AddCertificate(storeKey.StoreName, storeKey.StoreLocation, certificate);
            }
        }
    }

    /// <summary>Value returned by <c>TrustCertificateCore</c>. Defaults to <see cref="TrustLevel.Full"/>.</summary>
    public TrustLevel TrustResult { get; set; } = TrustLevel.Full;

    /// <summary>
    /// When set, supplies the result of <c>CheckCertificateState</c>.
    /// When unset, every candidate is reported as successful.
    /// </summary>
    public Func<X509Certificate2, CheckCertificateStateResult>? CheckCertificateStateOverride { get; set; }

    /// <summary>Files recorded by <c>ExportCertificateToMemory</c>, keyed by path (case-insensitive).</summary>
    public IReadOnlyDictionary<string, byte[]> CreatedFiles => _createdFiles;

    /// <summary>Directories passed to <c>CreateDirectoryWithPermissions</c>.</summary>
    public IReadOnlyCollection<string> CreatedDirectories => _createdDirectories;

    /// <summary>Certificates passed to <c>CorrectCertificateState</c>, in call order.</summary>
    public IReadOnlyList<X509Certificate2> CorrectedCertificates => _correctedCertificates;

    /// <summary>Certificates passed to <c>TrustCertificateCore</c>, in call order.</summary>
    public IReadOnlyList<X509Certificate2> TrustedCertificates => _trustedCertificates;

    /// <summary>
    /// Adds a serialized copy of <paramref name="certificate"/> to the given in-memory store.
    /// </summary>
    /// <param name="storeName">Target store name.</param>
    /// <param name="storeLocation">Target store location.</param>
    /// <param name="certificate">Certificate to copy. The instance itself is not retained.</param>
    /// <param name="isExportable">
    /// Whether the stored copy keeps its private key. Defaults to <c>certificate.HasPrivateKey</c>.
    /// A certificate without a private key is never exportable.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="certificate"/> is null.</exception>
    public void AddCertificate(StoreName storeName, StoreLocation storeLocation, X509Certificate2 certificate, bool? isExportable = null)
    {
        ArgumentNullException.ThrowIfNull(certificate);

        var entry = new InMemoryCertificateEntry(certificate, isExportable);
        GetOrCreateStore(storeName, storeLocation).Add(entry);

        // The most recently added copy of a thumbprint decides what IsExportable reports.
        if (!string.IsNullOrEmpty(entry.Thumbprint))
        {
            _exportableByThumbprint[entry.Thumbprint] = entry.Exportable;
        }
    }

    /// <summary>
    /// Returns new certificate instances for every entry in the store.
    /// The result is empty if nothing was ever added there.
    /// </summary>
    public IReadOnlyList<X509Certificate2> GetStoreCertificates(StoreName storeName, StoreLocation storeLocation)
        => SnapshotStore(storeName, storeLocation);

    /// <summary>
    /// Drops the whole store. Exportability records are forgotten only for thumbprints
    /// that no other store still holds.
    /// </summary>
    public void RemoveStoreCertificates(StoreName storeName, StoreLocation storeLocation)
    {
        if (!_stores.Remove(new StoreKey(storeName, storeLocation), out var entries))
        {
            return;
        }

        foreach (var entry in entries)
        {
            RefreshExportability(entry.Thumbprint);
        }
    }

    /// <summary>
    /// Looks up a file recorded by <c>ExportCertificateToMemory</c>.
    /// </summary>
    /// <param name="path">Path the file was recorded under (case-insensitive).</param>
    /// <param name="bytes">A copy of the file contents, or an empty array when not found.</param>
    /// <returns><c>true</c> when a file was recorded at <paramref name="path"/>.</returns>
    public bool TryGetCreatedFile(string path, out byte[] bytes)
    {
        if (_createdFiles.TryGetValue(path, out var value))
        {
            bytes = (byte[])value.Clone();
            return true;
        }

        bytes = Array.Empty<byte>();
        return false;
    }

    /// <summary>
    /// Generates a development certificate stamped with <paramref name="generatedVersion"/>.
    /// The minimum accepted version is lowered first if needed. Both version settings are
    /// restored afterwards, even on failure.
    /// </summary>
    public X509Certificate2 CreateDevelopmentCertificateWithVersion(int generatedVersion, DateTimeOffset notBefore, DateTimeOffset notAfter)
    {
        var previousVersion = AspNetHttpsCertificateVersion;
        var previousMinimumVersion = MinimumAspNetHttpsCertificateVersion;
        if (generatedVersion < MinimumAspNetHttpsCertificateVersion)
        {
            MinimumAspNetHttpsCertificateVersion = generatedVersion;
        }

        AspNetHttpsCertificateVersion = generatedVersion;
        try
        {
            return CreateAspNetCoreHttpsDevelopmentCertificate(notBefore, notAfter);
        }
        finally
        {
            AspNetHttpsCertificateVersion = previousVersion;
            MinimumAspNetHttpsCertificateVersion = previousMinimumVersion;
        }
    }

    /// <summary>
    /// Exports <paramref name="certificate"/> and records the output in <see cref="CreatedFiles"/>
    /// instead of writing to disk. For PEM with a private key, the key is recorded separately
    /// under <paramref name="path"/> with its extension changed to ".key".
    /// </summary>
    /// <param name="certificate">Certificate to export.</param>
    /// <param name="path">Path to record the certificate under. Its directory is passed to <c>CreateDirectoryWithPermissions</c>.</param>
    /// <param name="includePrivateKey">Whether to export the private key. PFX embeds it; PEM writes a separate ".key" file.</param>
    /// <param name="password">
    /// Password protecting the exported key.
    /// For PEM, null produces an unencrypted "PRIVATE KEY" block.
    /// </param>
    /// <param name="format">PFX or PEM.</param>
    /// <exception cref="InvalidOperationException">Unknown <paramref name="format"/> while exporting a private key.</exception>
    internal void ExportCertificateToMemory(X509Certificate2 certificate, string path, bool includePrivateKey, string? password, CertificateKeyExportFormat format)
    {
        if (Log.IsEnabled())
        {
            Log.ExportCertificateStart(GetDescription(certificate), path, includePrivateKey);
        }

        if (includePrivateKey && password is null)
        {
            Log.NoPasswordForCertificate();
        }

        var targetDirectoryPath = Path.GetDirectoryName(path);
        if (!string.IsNullOrEmpty(targetDirectoryPath))
        {
            Log.CreateExportCertificateDirectory(targetDirectoryPath);
            CreateDirectoryWithPermissions(targetDirectoryPath);
        }

        byte[] bytes;
        byte[] keyBytes;
        byte[]? pemEnvelope = null;
        RSA? key = null;

        try
        {
            if (includePrivateKey)
            {
                switch (format)
                {
                    case CertificateKeyExportFormat.Pfx:
                        bytes = certificate.Export(X509ContentType.Pkcs12, password);
                        break;
                    case CertificateKeyExportFormat.Pem:
                        key = certificate.GetRSAPrivateKey()!;

                        char[] pem;
                        if (password != null)
                        {
                            keyBytes = key.ExportEncryptedPkcs8PrivateKey(password, new PbeParameters(PbeEncryptionAlgorithm.Aes256Cbc, HashAlgorithmName.SHA256, 100000));
                            pem = PemEncoding.Write("ENCRYPTED PRIVATE KEY", keyBytes);
                            pemEnvelope = Encoding.ASCII.GetBytes(pem);
                        }
                        else
                        {
                            // Round-trip through an encrypted export with an empty password, then re-export as plain PKCS#8.
                            // NOTE(review): presumably this works around keys whose provider refuses a direct plaintext
                            // export — confirm.
                            keyBytes = key.ExportEncryptedPkcs8PrivateKey(string.Empty, new PbeParameters(PbeEncryptionAlgorithm.Aes256Cbc, HashAlgorithmName.SHA256, 1));
                            pem = PemEncoding.Write("ENCRYPTED PRIVATE KEY", keyBytes);
                            key.Dispose();
                            key = RSA.Create();
                            key.ImportFromEncryptedPem(pem, string.Empty);
                            Array.Clear(keyBytes, 0, keyBytes.Length);
                            Array.Clear(pem, 0, pem.Length);
                            keyBytes = key.ExportPkcs8PrivateKey();
                            pem = PemEncoding.Write("PRIVATE KEY", keyBytes);
                            pemEnvelope = Encoding.ASCII.GetBytes(pem);
                        }

                        // Wipe intermediate key material; only pemEnvelope is kept for the ".key" file.
                        Array.Clear(keyBytes, 0, keyBytes.Length);
                        Array.Clear(pem, 0, pem.Length);

                        bytes = Encoding.ASCII.GetBytes(PemEncoding.Write("CERTIFICATE", certificate.Export(X509ContentType.Cert)));
                        break;
                    default:
                        throw new InvalidOperationException("Unknown format.");
                }
            }
            else
            {
                if (format == CertificateKeyExportFormat.Pem)
                {
                    bytes = Encoding.ASCII.GetBytes(PemEncoding.Write("CERTIFICATE", certificate.Export(X509ContentType.Cert)));
                }
                else
                {
                    bytes = certificate.Export(X509ContentType.Cert);
                }
            }
        }
        catch (Exception e) when (Log.IsEnabled())
        {
            Log.ExportCertificateError(e.ToString());
            throw;
        }
        finally
        {
            key?.Dispose();
        }

        try
        {
            RecordCertificateFile(path, bytes);

            if (includePrivateKey && format == CertificateKeyExportFormat.Pem)
            {
                if (pemEnvelope is null)
                {
                    throw new InvalidOperationException("Missing PEM key envelope.");
                }

                RecordPemKeyFile(path, pemEnvelope);
            }
        }
        finally
        {
            // Always wipe the private-key PEM, including when recording the certificate file failed.
            if (pemEnvelope is not null)
            {
                Array.Clear(pemEnvelope, 0, pemEnvelope.Length);
            }
        }
    }

    /// <summary>
    /// Appends new instances of every entry stored under the store's name and location.
    /// <paramref name="requireExportable"/> is not applied here.
    /// NOTE(review): presumably the caller filters on <c>IsExportable</c> — confirm.
    /// </summary>
    protected override void PopulateCertificatesFromStore(X509Store store, List<X509Certificate2> certificates, bool requireExportable)
    {
        if (!TryParseStoreName(store.Name, out var storeName) ||
            !_stores.TryGetValue(new StoreKey(storeName, store.Location), out var entries))
        {
            return;
        }

        foreach (var entry in entries)
        {
            certificates.Add(entry.CreateCertificate());
        }
    }

    /// <summary>Stores the certificate as exportable and returns the same instance.</summary>
    protected override X509Certificate2 SaveCertificateCore(X509Certificate2 certificate, StoreName storeName, StoreLocation storeLocation)
    {
        AddCertificate(storeName, storeLocation, certificate, isExportable: true);
        return certificate;
    }

    /// <summary>
    /// Records the certificate and adds it to Root/CurrentUser.
    /// Returns <see cref="TrustResult"/>.
    /// </summary>
    protected override TrustLevel TrustCertificateCore(X509Certificate2 certificate)
    {
        _trustedCertificates.Add(certificate);
        AddCertificate(StoreName.Root, StoreLocation.CurrentUser, certificate, isExportable: true);
        return TrustResult;
    }

    /// <summary>
    /// Returns <see cref="TrustLevel.Full"/> when a matching certificate is in any Root store (any location),
    /// otherwise <see cref="TrustLevel.None"/>. Matching is by serial number or raw bytes.
    /// <see cref="TrustResult"/> is not consulted.
    /// </summary>
    public override TrustLevel GetTrustLevel(X509Certificate2 certificate)
    {
        return IsCertificateInStore(StoreName.Root, certificate) ? TrustLevel.Full : TrustLevel.None;
    }

    /// <summary>
    /// Returns the recorded exportability for the certificate's thumbprint.
    /// Falls back to <c>HasPrivateKey</c> for unknown thumbprints.
    /// </summary>
    internal override bool IsExportable(X509Certificate2 c)
    {
        if (!string.IsNullOrEmpty(c.Thumbprint) && _exportableByThumbprint.TryGetValue(c.Thumbprint, out var exportable))
        {
            return exportable;
        }

        return c.HasPrivateKey;
    }

    /// <summary>Removes matching certificates from the Root store at every location.</summary>
    protected override void RemoveCertificateFromTrustedRoots(X509Certificate2 certificate)
    {
        RemoveFromStores(StoreName.Root, certificate);
    }

    /// <summary>
    /// Returns new instances of every entry in the given store.
    /// The result is empty if the store does not exist.
    /// </summary>
    protected override IList<X509Certificate2> GetCertificatesToRemove(StoreName storeName, StoreLocation storeLocation)
        => SnapshotStore(storeName, storeLocation);

    /// <summary>Records the directory instead of creating it.</summary>
    protected override void CreateDirectoryWithPermissions(string directoryPath)
    {
        _createdDirectories.Add(directoryPath);
    }

    /// <summary>Removes matching certificates from My/CurrentUser.</summary>
    protected override void RemoveCertificateFromUserStoreCore(X509Certificate2 certificate)
    {
        RemoveFromStore(StoreName.My, StoreLocation.CurrentUser, certificate);
    }

    /// <summary>
    /// Delegates to <see cref="CheckCertificateStateOverride"/> when set.
    /// Otherwise reports success with no failure message.
    /// </summary>
    internal override CheckCertificateStateResult CheckCertificateState(X509Certificate2 candidate)
    {
        return CheckCertificateStateOverride?.Invoke(candidate) ?? new CheckCertificateStateResult(success: true, failureMessage: null);
    }

    /// <summary>Records the candidate without changing anything.</summary>
    internal override void CorrectCertificateState(X509Certificate2 candidate)
    {
        _correctedCertificates.Add(candidate);
    }

    // Returns the entry list for the store, creating an empty one on first use.
    private List<InMemoryCertificateEntry> GetOrCreateStore(StoreName storeName, StoreLocation storeLocation)
    {
        var key = new StoreKey(storeName, storeLocation);
        if (!_stores.TryGetValue(key, out var entries))
        {
            entries = new List<InMemoryCertificateEntry>();
            _stores[key] = entries;
        }

        return entries;
    }

    // New certificate instances for every entry of the store, or an empty array when the store was never created.
    private X509Certificate2[] SnapshotStore(StoreName storeName, StoreLocation storeLocation)
    {
        if (!_stores.TryGetValue(new StoreKey(storeName, storeLocation), out var entries))
        {
            return Array.Empty<X509Certificate2>();
        }

        return entries.Select(entry => entry.CreateCertificate()).ToArray();
    }

    // Maps X509Store.Name back to StoreName.
    // X509Store names the CertificateAuthority store "CA", which is not an enum member name,
    // so that alias is handled explicitly.
    private static bool TryParseStoreName(string? name, out StoreName storeName)
    {
        if (string.Equals(name, "CA", StringComparison.OrdinalIgnoreCase))
        {
            storeName = StoreName.CertificateAuthority;
            return true;
        }

        return Enum.TryParse<StoreName>(name, ignoreCase: true, out storeName);
    }

    // Stores a private copy of the bytes under the path (last write wins).
    private void AddCreatedFile(string path, byte[] bytes)
    {
        _createdFiles[path] = (byte[])bytes.Clone();
    }

    // Records the certificate bytes at path, logging failures. The caller's buffer is always wiped afterwards.
    private void RecordCertificateFile(string path, byte[] bytes)
    {
        try
        {
            Log.WriteCertificateToDisk(path);
            AddCreatedFile(path, bytes);
        }
        catch (Exception ex) when (Log.IsEnabled())
        {
            Log.WriteCertificateToDiskError(ex.ToString());
            throw;
        }
        finally
        {
            Array.Clear(bytes, 0, bytes.Length);
        }
    }

    // Records the PEM private key under certificatePath with its extension changed to ".key", logging failures.
    private void RecordPemKeyFile(string certificatePath, byte[] pemEnvelope)
    {
        try
        {
            var keyPath = Path.ChangeExtension(certificatePath, ".key");
            Log.WritePemKeyToDisk(keyPath);
            AddCreatedFile(keyPath, pemEnvelope);
        }
        catch (Exception ex) when (Log.IsEnabled())
        {
            Log.WritePemKeyToDiskError(ex.ToString());
            throw;
        }
    }

    // Removes matching certificates from every location of the given store name.
    private void RemoveFromStores(StoreName storeName, X509Certificate2 certificate)
    {
        foreach (var key in _stores.Keys.Where(key => key.StoreName == storeName).ToArray())
        {
            RemoveFromStore(key.StoreName, key.StoreLocation, certificate);
        }
    }

    // True when any location of the given store name holds an entry matching the certificate.
    private bool IsCertificateInStore(StoreName storeName, X509Certificate2 certificate)
    {
        return _stores.Any(pair => pair.Key.StoreName == storeName && pair.Value.Any(entry => entry.Matches(certificate)));
    }

    // Removes every matching entry from one store.
    // The exportability records of the removed thumbprints are then rebuilt.
    private void RemoveFromStore(StoreName storeName, StoreLocation storeLocation, X509Certificate2 certificate)
    {
        if (!_stores.TryGetValue(new StoreKey(storeName, storeLocation), out var entries))
        {
            return;
        }

        var removedThumbprints = new List<string>();
        for (var i = entries.Count - 1; i >= 0; i--)
        {
            var entry = entries[i];
            if (entry.Matches(certificate))
            {
                entries.RemoveAt(i);
                removedThumbprints.Add(entry.Thumbprint);
            }
        }

        // Refresh only after all removals, so duplicates removed in this call are not seen as survivors.
        foreach (var thumbprint in removedThumbprints)
        {
            RefreshExportability(thumbprint);
        }
    }

    // Re-derives the exportability record for a thumbprint after entries were removed.
    // If no store still holds that thumbprint, the record is dropped. Otherwise it is taken from
    // the last surviving entry found, instead of being forgotten while a copy remains stored.
    private void RefreshExportability(string thumbprint)
    {
        if (string.IsNullOrEmpty(thumbprint))
        {
            return;
        }

        InMemoryCertificateEntry? survivor = null;
        foreach (var entries in _stores.Values)
        {
            foreach (var candidate in entries)
            {
                if (string.Equals(candidate.Thumbprint, thumbprint, StringComparison.OrdinalIgnoreCase))
                {
                    survivor = candidate;
                }
            }
        }

        if (survivor is null)
        {
            _exportableByThumbprint.Remove(thumbprint);
        }
        else
        {
            _exportableByThumbprint[thumbprint] = survivor.Exportable;
        }
    }

    /// <summary>Identifies one in-memory store by name and location.</summary>
    internal readonly record struct StoreKey(StoreName StoreName, StoreLocation StoreLocation);

    // Serialized snapshot of a stored certificate.
    // The private key is kept (as PFX bytes) only when the entry is exportable.
    private sealed class InMemoryCertificateEntry
    {
        public InMemoryCertificateEntry(X509Certificate2 certificate, bool? isExportable)
        {
            SerialNumber = certificate.SerialNumber;
            Thumbprint = certificate.Thumbprint ?? string.Empty;
            // Exportable requires a private key; isExportable can only narrow that, never widen it.
            Exportable = (isExportable ?? certificate.HasPrivateKey) && certificate.HasPrivateKey;
            CertBytes = certificate.Export(X509ContentType.Cert);
            if (Exportable)
            {
                PfxBytes = certificate.Export(X509ContentType.Pkcs12);
            }
        }

        public string SerialNumber { get; }

        public string Thumbprint { get; }

        public bool Exportable { get; }

        public byte[] CertBytes { get; }

        public byte[]? PfxBytes { get; }

        // True when the certificate has the same serial number (case-insensitive) or identical raw bytes.
        public bool Matches(X509Certificate2 certificate)
        {
            return string.Equals(SerialNumber, certificate.SerialNumber, StringComparison.OrdinalIgnoreCase) ||
                certificate.RawDataMemory.Span.SequenceEqual(CertBytes);
        }

        // Materializes a new certificate: with an exportable private key when PFX bytes exist, public-only otherwise.
        public X509Certificate2 CreateCertificate()
        {
            if (PfxBytes != null)
            {
                try
                {
                    return new X509Certificate2(PfxBytes, (string?)null, X509KeyStorageFlags.Exportable | X509KeyStorageFlags.EphemeralKeySet);
                }
                catch (PlatformNotSupportedException)
                {
                    // Some platforms (e.g. macOS) reject EphemeralKeySet; retry without it.
                    return new X509Certificate2(PfxBytes, (string?)null, X509KeyStorageFlags.Exportable);
                }
            }

            return new X509Certificate2(CertBytes);
        }
    }
}