File: RequestProcessing\ClientCertLoader.cs
Web Access
Project: src\src\Servers\HttpSys\src\Microsoft.AspNetCore.Server.HttpSys.csproj (Microsoft.AspNetCore.Server.HttpSys)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Contracts;
using System.Runtime.InteropServices;
using System.Security;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Windows.Win32.Networking.HttpServer;
 
namespace Microsoft.AspNetCore.Server.HttpSys;
 
// Loads the client certificate for a request on demand by calling HttpReceiveClientCertificate.
// Because client certs are optional, all failures are handled internally: the task returned by
// LoadClientCertificateAsync always completes successfully, and problems are reported via
// ClientCertException (an exception that occurred) or ClientCertError (the CertFlags value HTTP.SYS reported).
internal sealed unsafe partial class ClientCertLoader : IAsyncResult, IDisposable
{
    // Initial size of the buffer handed to HttpReceiveClientCertificate. If it is too small the
    // call reports ERROR_MORE_DATA and the buffer is regrown (see LoadClientCertificateAsync / IOCompleted).
    private const uint CertBlobSize = 1500;
    private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(WaitCallback);

    private SafeNativeOverlapped? _overlapped;
    // Managed buffer passed as pinData to AllocateNativeOverlapped; _memoryBlob points at its first element.
    private byte[]? _backingBuffer;
    private HTTP_SSL_CLIENT_CERT_INFO* _memoryBlob;
    private uint _size;
    private readonly TaskCompletionSource<object?> _tcs;
    private readonly RequestContext _requestContext;

    private int _clientCertError;
    private X509Certificate2? _clientCert;
    private Exception? _clientCertException;
    private readonly CancellationTokenRegistration _cancellationRegistration;

    internal ClientCertLoader(RequestContext requestContext, CancellationToken cancellationToken)
    {
        _requestContext = requestContext;
        // The task is usually completed from the I/O completion callback; run awaiting continuations
        // asynchronously so caller code does not execute inline on the completion thread.
        _tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
        // Allocate the initial buffer plus the overlapped structure used to issue the async receive.
        Reset(CertBlobSize);

        if (cancellationToken.CanBeCanceled)
        {
            // NOTE(review): RegisterForCancellation lives on RequestContext; presumably it aborts
            // the request when the token fires — confirm against RequestContext.
            _cancellationRegistration = RequestContext.RegisterForCancellation(cancellationToken);
        }
    }

    internal SafeHandle RequestQueueHandle => _requestContext.Server.RequestQueue.Handle;

    // The loaded certificate, or null if none was provided / loading failed. Only valid once the task has completed.
    internal X509Certificate2? ClientCert
    {
        get
        {
            Contract.Assert(Task.IsCompleted);
            return _clientCert;
        }
    }

    // The CertFlags value reported alongside the certificate (0 when no certificate was loaded).
    // Only valid once the task has completed.
    internal int ClientCertError
    {
        get
        {
            Contract.Assert(Task.IsCompleted);
            return _clientCertError;
        }
    }

    // The exception captured while loading, if any. Only valid once the task has completed.
    internal Exception? ClientCertException
    {
        get
        {
            Contract.Assert(Task.IsCompleted);
            return _clientCertException;
        }
    }

    private RequestContext RequestContext => _requestContext;

    private Task Task => _tcs.Task;

    private SafeNativeOverlapped? NativeOverlapped => _overlapped;

    private HTTP_SSL_CLIENT_CERT_INFO* RequestBlob => _memoryBlob;

    // Replaces the receive buffer and overlapped with ones sized for `size` bytes.
    // A size of 0 releases them. Calling with the current size is a no-op.
    private void Reset(uint size)
    {
        if (size == _size)
        {
            return;
        }
        if (_size != 0)
        {
            _overlapped!.Dispose();
        }
        _size = size;
        if (size == 0)
        {
            _overlapped = null;
            _memoryBlob = null;
            _backingBuffer = null;
            return;
        }
        _backingBuffer = new byte[checked((int)size)];
        var boundHandle = RequestContext.Server.RequestQueue.BoundHandle;
        // AllocateNativeOverlapped pins _backingBuffer for the lifetime of the overlapped,
        // which is what makes taking its raw address below valid.
        _overlapped = new SafeNativeOverlapped(boundHandle,
            boundHandle.AllocateNativeOverlapped(IOCallback, this, _backingBuffer));
        _memoryBlob = (HTTP_SSL_CLIENT_CERT_INFO*)Marshal.UnsafeAddrOfPinnedArrayElement(_backingBuffer, 0);
    }

    // When you use netsh to configure HTTP.SYS with clientcertnegotiation = enable,
    // HTTP.SYS requests the client certificate as part of the initial SSL handshake.
    //
    // Some apps may not want to negotiate the client cert at the beginning,
    // e.g. when serving default.htm. In that case HTTP.SYS is configured
    // with clientcertnegotiation = disabled, which makes the client certificate
    // optional: when SSL is established HTTP.SYS won't ask for a client
    // certificate. That works fine for default.htm in the case above;
    // however, if the app wants to demand a client certificate later
    // (e.g. when showing a "YOUR ORDERS" page), the server must negotiate
    // client certs. This in turn makes HTTP.SYS perform the
    // SEC_I_RENEGOTIATE exchange through which the client cert is demanded.
    //
    // NOTE: When calling HttpReceiveClientCertificate you can get
    // ERROR_NOT_FOUND, which means the client did not provide a cert.
    // If this is important, the server should respond with 403 Forbidden;
    // HTTP.SYS will not do this for you automatically.
    //
    // The returned task always completes successfully; inspect ClientCert,
    // ClientCertError and ClientCertException afterwards.
    internal Task LoadClientCertificateAsync()
    {
        var size = CertBlobSize;
        bool retry;
        do
        {
            retry = false;
            uint bytesReceived = 0;

            var statusCode =
                HttpApi.HttpReceiveClientCertificate(
                    RequestQueueHandle,
                    RequestContext.Request.UConnectionId,
                    0u,
                    RequestBlob,
                    size,
                    &bytesReceived,
                    NativeOverlapped!);

            if (statusCode == ErrorCodes.ERROR_MORE_DATA)
            {
                // bytesReceived only covers the fixed-size structure; add the encoded cert size
                // (see the http.sys note in IOCompleted) and retry with a larger buffer.
                var pClientCertInfo = RequestBlob;
                size = bytesReceived + pClientCertInfo->CertEncodedSize;
                Reset(size);
                retry = true;
            }
            else if (statusCode == ErrorCodes.ERROR_NOT_FOUND)
            {
                // The client did not send a cert.
                Complete(0, null);
            }
            else if (statusCode == ErrorCodes.ERROR_SUCCESS &&
                HttpSysListener.SkipIOCPCallbackOnSuccess)
            {
                // No completion callback will be delivered for a synchronous success, so process it here.
                IOCompleted(statusCode, bytesReceived);
            }
            else if (statusCode != ErrorCodes.ERROR_SUCCESS &&
                statusCode != ErrorCodes.ERROR_IO_PENDING)
            {
                // Some other bad error, possible(?) return values are:
                // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED
                // Also ERROR_BAD_DATA if we got it twice or it reported smaller size buffer required.
                Fail(new HttpSysException((int)statusCode));
            }
            // Otherwise (ERROR_IO_PENDING, or ERROR_SUCCESS with callbacks enabled) WaitCallback finishes the work.
        }
        while (retry);

        return Task;
    }

    // Records the outcome, releases native resources, then completes the task.
    private void Complete(int certErrors, X509Certificate2? cert)
    {
        // May be null
        _clientCert = cert;
        _clientCertError = certErrors;
        Dispose();
        _tcs.TrySetResult(null);
    }

    // Records the failure, releases native resources, then completes the task (successfully, by design).
    private void Fail(Exception ex)
    {
        // TODO: Log
        _clientCertException = ex;
        Dispose();
        _tcs.TrySetResult(null);
    }

    private unsafe void IOCompleted(uint errorCode, uint numBytes)
    {
        IOCompleted(this, errorCode, numBytes);
    }

    // Handles the result of a receive, either from WaitCallback or from a synchronous success.
    // Every path that does not re-issue a pending receive ends in Complete or Fail, so the task always finishes.
    [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirected to callback")]
    private static unsafe void IOCompleted(ClientCertLoader asyncResult, uint errorCode, uint numBytes)
    {
        var requestContext = asyncResult.RequestContext;
        try
        {
            if (errorCode == ErrorCodes.ERROR_MORE_DATA)
            {
                // There is a bug that has existed in http.sys since w2k3.  Bytesreceived will only
                // return the size of the initial cert structure.  To get the full size,
                // we need to add the certificate encoding size as well.

                var pClientCertInfo = asyncResult.RequestBlob;
                asyncResult.Reset(numBytes + pClientCertInfo->CertEncodedSize);

                uint bytesReceived = 0;
                errorCode =
                    HttpApi.HttpReceiveClientCertificate(
                        requestContext.Server.RequestQueue.Handle,
                        requestContext.Request.UConnectionId,
                        0u,
                        asyncResult._memoryBlob,
                        asyncResult._size,
                        &bytesReceived,
                        asyncResult._overlapped!);

                // A pending receive, or a success that will still be delivered through the
                // completion callback, is finished by the next WaitCallback invocation.
                if (errorCode == ErrorCodes.ERROR_IO_PENDING ||
                   (errorCode == ErrorCodes.ERROR_SUCCESS && !HttpSysListener.SkipIOCPCallbackOnSuccess))
                {
                    return;
                }
            }

            if (errorCode == ErrorCodes.ERROR_NOT_FOUND)
            {
                // The client did not send a cert.
                asyncResult.Complete(0, null);
            }
            else if (errorCode != ErrorCodes.ERROR_SUCCESS)
            {
                asyncResult.Fail(new HttpSysException((int)errorCode));
            }
            else
            {
                var pClientCertInfo = asyncResult._memoryBlob;
                if (pClientCertInfo == null || pClientCertInfo->pCertEncoded == null)
                {
                    // No certificate bytes to load. Completing here is required: without it the
                    // task would never finish and Dispose would never run.
                    asyncResult.Complete(0, null);
                }
                else
                {
                    try
                    {
                        // Copy the encoded cert out of the pinned buffer before Complete disposes it.
                        var certEncoded = new byte[pClientCertInfo->CertEncodedSize];
                        Marshal.Copy((IntPtr)pClientCertInfo->pCertEncoded, certEncoded, 0, certEncoded.Length);
                        asyncResult.Complete((int)pClientCertInfo->CertFlags, new X509Certificate2(certEncoded));
                    }
                    catch (CryptographicException exception)
                    {
                        // TODO: Log
                        asyncResult.Fail(exception);
                    }
                    catch (SecurityException exception)
                    {
                        // TODO: Log
                        asyncResult.Fail(exception);
                    }
                }
            }
        }
        catch (Exception exception)
        {
            asyncResult.Fail(exception);
        }
    }

    // Completion-port callback: recovers the loader from the overlapped state and processes the result.
    private static unsafe void WaitCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped)
    {
        var asyncResult = (ClientCertLoader)ThreadPoolBoundHandle.GetNativeOverlappedState(nativeOverlapped)!;
        IOCompleted(asyncResult, errorCode, numBytes);
    }

    // Releases the cancellation registration and the overlapped structure. Invoked automatically by Complete/Fail.
    public void Dispose()
    {
        Dispose(true);
    }

    private void Dispose(bool disposing)
    {
        if (disposing)
        {
            _cancellationRegistration.Dispose();
            if (_overlapped != null)
            {
                // Drop the raw pointer first so nothing reads the buffer after the overlapped is released.
                _memoryBlob = null;
                _overlapped.Dispose();
            }
        }
    }

    // IAsyncResult members forward to the underlying task.
    public object? AsyncState => _tcs.Task.AsyncState;

    public WaitHandle AsyncWaitHandle => ((IAsyncResult)_tcs.Task).AsyncWaitHandle;

    public bool CompletedSynchronously => ((IAsyncResult)_tcs.Task).CompletedSynchronously;

    public bool IsCompleted => _tcs.Task.IsCompleted;
}