File: src\Servers\Kestrel\shared\test\RevocationResponder.cs
Web Access
Project: src\src\Servers\Kestrel\test\InMemory.FunctionalTests\InMemory.FunctionalTests.csproj (InMemory.FunctionalTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Formats.Asn1;
using System.Net;
using System.Security.Cryptography;
using System.Web;
 
namespace Microsoft.AspNetCore.InternalTesting;
 
// Copied from https://github.com/dotnet/runtime/blob/main/src/libraries/Common/tests/System/Security/Cryptography/X509Certificates/RevocationResponder.cs
internal sealed class RevocationResponder : IDisposable
{
    private static readonly bool s_traceEnabled =
        Environment.GetEnvironmentVariable("TRACE_REVOCATION_RESPONSE") != null;
 
    private readonly HttpListener _listener;
 
    private readonly Dictionary<string, CertificateAuthority> _aiaPaths =
        new Dictionary<string, CertificateAuthority>();
 
    private readonly Dictionary<string, CertificateAuthority> _crlPaths
        = new Dictionary<string, CertificateAuthority>();
 
    private readonly List<(string, CertificateAuthority)> _ocspAuthorities =
        new List<(string, CertificateAuthority)>();
 
    public string UriPrefix { get; }
 
    public bool RespondEmpty { get; set; }
 
    public TimeSpan ResponseDelay { get; set; }
    public DelayedActionsFlag DelayedActions { get; set; }
 
    private RevocationResponder(HttpListener listener, string uriPrefix)
    {
        _listener = listener;
        UriPrefix = uriPrefix;
    }
 
    public void Dispose()
    {
        _listener.Close();
    }
 
    internal void AddCertificateAuthority(CertificateAuthority authority)
    {
        if (authority.AiaHttpUri != null && authority.AiaHttpUri.StartsWith(UriPrefix, StringComparison.OrdinalIgnoreCase))
        {
            string path = authority.AiaHttpUri.Substring(UriPrefix.Length - 1);
            Trace($"Adding AIA path : {path}");
            _aiaPaths.Add(path, authority);
        }
 
        if (authority.CdpUri != null && authority.CdpUri.StartsWith(UriPrefix, StringComparison.OrdinalIgnoreCase))
        {
            string path = authority.CdpUri.Substring(UriPrefix.Length - 1);
            Trace($"Adding CRL path : {path}");
            _crlPaths.Add(path, authority);
        }
 
        if (authority.OcspUri != null && authority.OcspUri.StartsWith(UriPrefix, StringComparison.OrdinalIgnoreCase))
        {
            string path = authority.OcspUri.Substring(UriPrefix.Length - 1);
            Trace($"Adding OCSP path : {path}");
            _ocspAuthorities.Add((path, authority));
        }
    }
 
    private void HandleRequests()
    {
        ThreadPool.QueueUserWorkItem(
            state =>
            {
                while (state._listener.IsListening)
                {
                    state.HandleRequest();
                }
            },
            this,
            true);
    }
 
    internal void HandleRequest()
    {
        HttpListenerContext context = null;
 
        try
        {
            context = _listener.GetContext();
        }
        catch (Exception)
        {
        }
 
        if (context != null)
        {
            ThreadPool.QueueUserWorkItem(
                state => HandleRequest(state),
                context,
                true);
        }
    }
 
    internal async Task HandleRequestAsync()
    {
        HttpListenerContext context = null;
 
        try
        {
            context = await _listener.GetContextAsync();
        }
        catch (Exception)
        {
        }
 
        if (context != null)
        {
            ThreadPool.QueueUserWorkItem(
                state => HandleRequest(state),
                context,
                true);
        }
    }
 
    internal void HandleRequest(HttpListenerContext context)
    {
        bool responded = false;
        try
        {
            Trace($"{context.Request.HttpMethod} {context.Request.RawUrl} (HTTP {context.Request.ProtocolVersion})");
            HandleRequest(context, ref responded);
        }
        catch (Exception e)
        {
            try
            {
                if (!responded && context != null)
                {
                    context.Response.StatusCode = 500;
                    context.Response.StatusDescription = "Internal Server Error";
                    context.Response.Close();
 
                    Trace($"Sent 500 due to exception on {context.Request.HttpMethod} {context.Request.RawUrl}");
                    Trace(e.ToString());
                }
            }
            catch (Exception)
            {
            }
 
            return;
        }
 
        if (!responded)
        {
            Trace($"404 for {context.Request.HttpMethod} {context.Request.RawUrl}");
 
            try
            {
                context.Response.StatusCode = 404;
                context.Response.Close();
            }
            catch (Exception)
            {
            }
        }
    }
 
    private void HandleRequest(HttpListenerContext context, ref bool responded)
    {
        CertificateAuthority authority;
        string url = context.Request.RawUrl;
 
        if (_aiaPaths.TryGetValue(url, out authority))
        {
            if (DelayedActions.HasFlag(DelayedActionsFlag.Aia))
            {
                Trace($"Delaying response by {ResponseDelay}.");
                Thread.Sleep(ResponseDelay);
            }
 
            byte[] certData = RespondEmpty ? Array.Empty<byte>() : authority.GetCertData();
 
            responded = true;
            context.Response.StatusCode = 200;
            context.Response.ContentType = "application/pkix-cert";
            context.Response.Close(certData, willBlock: true);
            Trace($"Responded with {certData.Length}-byte certificate from {authority.SubjectName}.");
            return;
        }
 
        if (_crlPaths.TryGetValue(url, out authority))
        {
            if (DelayedActions.HasFlag(DelayedActionsFlag.Crl))
            {
                Trace($"Delaying response by {ResponseDelay}.");
                Thread.Sleep(ResponseDelay);
            }
 
            byte[] crl = RespondEmpty ? Array.Empty<byte>() : authority.GetCrl();
 
            responded = true;
            context.Response.StatusCode = 200;
            context.Response.ContentType = "application/pkix-crl";
            context.Response.Close(crl, willBlock: true);
            Trace($"Responded with {crl.Length}-byte CRL from {authority.SubjectName}.");
            return;
        }
 
        string prefix;
 
        foreach (var tuple in _ocspAuthorities)
        {
            (prefix, authority) = tuple;
 
            if (url.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
            {
                byte[] reqBytes;
                if (TryGetOcspRequestBytes(context.Request, prefix, out reqBytes))
                {
                    ReadOnlyMemory<byte> certId;
                    ReadOnlyMemory<byte> nonce;
                    try
                    {
                        DecodeOcspRequest(reqBytes, out certId, out nonce);
                    }
                    catch (Exception e)
                    {
                        Trace($"OcspRequest Decode failed ({url}) - {e}");
                        context.Response.StatusCode = 400;
                        context.Response.Close();
                        return;
                    }
 
                    byte[] ocspResponse = RespondEmpty ? Array.Empty<byte>() : authority.BuildOcspResponse(certId, nonce);
 
                    if (DelayedActions.HasFlag(DelayedActionsFlag.Ocsp))
                    {
                        Trace($"Delaying response by {ResponseDelay}.");
                        Thread.Sleep(ResponseDelay);
                    }
 
                    responded = true;
                    context.Response.StatusCode = 200;
                    context.Response.StatusDescription = "OK";
                    context.Response.ContentType = "application/ocsp-response";
                    context.Response.Close(ocspResponse, willBlock: true);
 
                    if (authority.HasOcspDelegation)
                    {
                        Trace($"OCSP Response: {ocspResponse.Length} bytes from {authority.SubjectName} delegated to {authority.OcspResponderSubjectName}");
                    }
                    else
                    {
                        Trace($"OCSP Response: {ocspResponse.Length} bytes from {authority.SubjectName}");
                    }
 
                    return;
                }
            }
        }
    }
 
    internal static RevocationResponder CreateAndListen()
    {
        HttpListener listener = OpenListener(out string uriPrefix);
 
        RevocationResponder responder = new RevocationResponder(listener, uriPrefix);
        responder.HandleRequests();
        return responder;
    }
 
    private static HttpListener OpenListener(out string uriPrefix)
    {
        while (true)
        {
            int port = RandomNumberGenerator.GetInt32(41000, 42000);
            uriPrefix = $"http://127.0.0.1:{port}/";
 
            HttpListener listener = new HttpListener();
            listener.Prefixes.Add(uriPrefix);
            listener.IgnoreWriteExceptions = true;
 
            try
            {
                listener.Start();
                Trace($"Listening at {uriPrefix}");
                return listener;
            }
            catch
            {
            }
        }
    }
 
    private static bool TryGetOcspRequestBytes(HttpListenerRequest request, string prefix, out byte[] requestBytes)
    {
        requestBytes = null;
        try
        {
            if (request.HttpMethod == "GET")
            {
                string base64 = HttpUtility.UrlDecode(request.RawUrl.Substring(prefix.Length + 1));
                requestBytes = Convert.FromBase64String(base64);
                return true;
            }
            else if (request.HttpMethod == "POST" && request.ContentType == "application/ocsp-request")
            {
                using (System.IO.Stream stream = request.InputStream)
                {
                    requestBytes = new byte[request.ContentLength64];
                    int read = stream.Read(requestBytes, 0, requestBytes.Length);
                    System.Diagnostics.Debug.Assert(read == requestBytes.Length);
                    return true;
                }
            }
        }
        catch (Exception e)
        {
            Trace($"Failed to get OCSP request bytes ({request.RawUrl}) - {e}");
        }
 
        return false;
    }
 
    private static void DecodeOcspRequest(
        byte[] requestBytes,
        out ReadOnlyMemory<byte> certId,
        out ReadOnlyMemory<byte> nonceExtension)
    {
        Asn1Tag context0 = new Asn1Tag(TagClass.ContextSpecific, 0);
        Asn1Tag context1 = new Asn1Tag(TagClass.ContextSpecific, 1);
 
        AsnReader reader = new AsnReader(requestBytes, AsnEncodingRules.DER);
        AsnReader request = reader.ReadSequence();
        reader.ThrowIfNotEmpty();
 
        AsnReader tbsRequest = request.ReadSequence();
 
        if (request.HasData)
        {
            // Optional signature
            request.ReadEncodedValue();
            request.ThrowIfNotEmpty();
        }
 
        // Only v1(0) is supported, and it shouldn't be written per DER.
        // But Apple writes it anyways, so let's go ahead and be lenient.
        if (tbsRequest.PeekTag().HasSameClassAndValue(context0))
        {
            AsnReader versionReader = tbsRequest.ReadSequence(context0);
 
            if (!versionReader.TryReadInt32(out int version) || version != 0)
            {
                throw new CryptographicException("ASN1 corrupted data");
            }
 
            versionReader.ThrowIfNotEmpty();
        }
 
        if (tbsRequest.PeekTag().HasSameClassAndValue(context1))
        {
            tbsRequest.ReadEncodedValue();
        }
 
        AsnReader requestList = tbsRequest.ReadSequence();
        AsnReader requestExtensions = null;
 
        if (tbsRequest.HasData)
        {
            AsnReader requestExtensionsWrapper = tbsRequest.ReadSequence(new Asn1Tag(TagClass.ContextSpecific, 2));
            requestExtensions = requestExtensionsWrapper.ReadSequence();
            requestExtensionsWrapper.ThrowIfNotEmpty();
        }
 
        tbsRequest.ThrowIfNotEmpty();
 
        AsnReader firstRequest = requestList.ReadSequence();
        requestList.ThrowIfNotEmpty();
 
        certId = firstRequest.ReadEncodedValue();
 
        if (firstRequest.HasData)
        {
            firstRequest.ReadSequence(context0);
        }
 
        firstRequest.ThrowIfNotEmpty();
 
        nonceExtension = default;
 
        if (requestExtensions != null)
        {
            while (requestExtensions.HasData)
            {
                ReadOnlyMemory<byte> wholeExtension = requestExtensions.PeekEncodedValue();
                AsnReader extension = requestExtensions.ReadSequence();
 
                if (extension.ReadObjectIdentifier() == "1.3.6.1.5.5.7.48.1.2")
                {
                    nonceExtension = wholeExtension;
                }
            }
        }
    }
 
    internal void Stop() => _listener.Stop();
 
    private static void Trace(string trace)
    {
        if (s_traceEnabled)
        {
            Console.WriteLine(trace);
        }
    }
}
 
public enum DelayedActionsFlag : byte
{
    None = 0,
    Ocsp = 0b1,
    Crl = 0b10,
    Aia = 0b100,
    All = 0b11111111
}