// File: AsyncAcceptContext.cs
// Project: src\src\Servers\HttpSys\src\Microsoft.AspNetCore.Server.HttpSys.csproj (Microsoft.AspNetCore.Server.HttpSys)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Threading.Tasks.Sources;
using Microsoft.Extensions.Logging;
 
namespace Microsoft.AspNetCore.Server.HttpSys;
 
internal sealed unsafe partial class AsyncAcceptContext : IValueTaskSource<RequestContext>, IDisposable
{
    // Cached delegate for the static completion callback; being static, it is shared by every instance.
    private static readonly IOCompletionCallback IOCallback = IOWaitCallback;
    // Created once in the constructor (state: this, no pinned data); AllocateNativeRequest
    // allocates each NativeOverlapped from it.
    private readonly PreAllocatedOverlapped _preallocatedOverlapped;
    // Creates the RequestContext used for each receive attempt (see AllocateNativeRequest).
    private readonly IRequestContextFactory _requestContextFactory;
    // Receives set-result failures and (when enabled) expectation-mismatch messages.
    private readonly ILogger _logger;
    // Completion-expectation tracking: forced to 1 by SetExpectCompletion before each receive call and
    // decremented by CancelExpectCompletion / ObserveCompletion; any unexpected value is treated as a mismatch.
    private int _expectedCompletionCount;
 
    // Native overlapped passed to the receive call; freed and re-allocated by AllocateNativeRequest,
    // freed by Dispose.
    private NativeOverlapped* _overlapped;
 
    // Opt-in diagnostics: expectation mismatches are only logged when this AppContext switch is enabled.
    private readonly bool _logExpectationFailures = AppContext.TryGetSwitch(
        "Microsoft.AspNetCore.Server.HttpSys.LogAcceptExpectationFailure", out var enabled) && enabled;
 
    // mutable struct; do not make this readonly
    private ManualResetValueTaskSourceCore<RequestContext> _mrvts = new()
    {
        // We want to run continuations on the IO threads
        RunContinuationsAsynchronously = false
    };
 
    // Context for the current receive attempt; IOCompleted clears it before handing the context to the awaiter.
    private RequestContext? _requestContext;
 
    /// <summary>
    /// Creates an accept context that receives requests for <paramref name="server"/>.
    /// </summary>
    /// <param name="server">Listener whose request queue is used for the receive calls.</param>
    /// <param name="requestContextFactory">Factory that creates the <see cref="RequestContext"/> for each receive attempt.</param>
    /// <param name="logger">Logger used for set-result failures and expectation mismatches.</param>
    internal AsyncAcceptContext(HttpSysListener server, IRequestContextFactory requestContextFactory, ILogger logger)
    {
        _logger = logger;
        _requestContextFactory = requestContextFactory;
        Server = server;

        // This instance is the callback state, which is how IOWaitCallback recovers it from the
        // native overlapped; nothing is pinned through the overlapped itself.
        _preallocatedOverlapped = new PreAllocatedOverlapped(IOCallback, state: this, pinData: null);
    }

    /// <summary>The listener this context accepts requests for.</summary>
    internal HttpSysListener Server { get; }
 
    /// <summary>
    /// Starts a new accept: resets the value-task source, allocates a fresh native request
    /// and queues the receive.
    /// </summary>
    /// <returns>
    /// A task backed by this instance when the receive succeeded or is pending; otherwise a
    /// faulted task carrying an <see cref="HttpSysException"/> for the returned status code.
    /// </returns>
    internal ValueTask<RequestContext> AcceptAsync()
    {
        _mrvts.Reset();
        AllocateNativeRequest();

        var queueStatus = QueueBeginGetContext();

        // Anything other than success / pending is some other bad error; possible(?) values are:
        // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED
        return queueStatus is ErrorCodes.ERROR_SUCCESS or ErrorCodes.ERROR_IO_PENDING
            ? new ValueTask<RequestContext>(this, _mrvts.Version)
            : ValueTask.FromException<RequestContext>(new HttpSysException((int)queueStatus));
    }
 
    /// <summary>
    /// Handles completion of a receive, whether signalled by the IO callback or invoked
    /// directly after a synchronous success.
    /// </summary>
    /// <param name="errorCode">Completion status of the receive.</param>
    /// <param name="numBytes">Byte count reported with the completion; used as the new buffer size on ERROR_MORE_DATA.</param>
    /// <param name="managed"><c>true</c> when invoked directly from managed code, <c>false</c> when invoked from the IO callback.</param>
    private void IOCompleted(uint errorCode, uint numBytes, bool managed)
    {
        try
        {
            ObserveCompletion(managed); // expectation tracking

            switch (errorCode)
            {
                case ErrorCodes.ERROR_SUCCESS:
                    Debug.Assert(_requestContext != null);

                    // Clear the field *before* publishing the result: this accept context is
                    // reused for future accepts, so it must no longer own the completed request.
                    var completedContext = _requestContext;
                    _requestContext = null;

                    try
                    {
                        _mrvts.SetResult(completedContext);
                    }
                    catch (Exception ex)
                    {
                        Log.AcceptSetResultFailed(_logger, ex);
                    }
                    break;

                case ErrorCodes.ERROR_MORE_DATA:
                    Debug.Assert(_requestContext != null);

                    // Re-allocate using the reported size, keeping the request id.
                    //  (uint)backingBuffer.Length - AlignmentPadding
                    AllocateNativeRequest(numBytes, _requestContext.RequestId);

                    // We need to issue a new request, either because auth failed, or because
                    // our buffer was too small the first time.
                    var requeueStatus = QueueBeginGetContext();
                    if (requeueStatus is not (ErrorCodes.ERROR_SUCCESS or ErrorCodes.ERROR_IO_PENDING))
                    {
                        // some other bad error, possible(?) return values are:
                        // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED
                        // (keep all the error handling in one place)
                        throw new HttpSysException((int)requeueStatus);
                    }
                    break;

                default:
                    // (keep all the error handling in one place)
                    throw new HttpSysException((int)errorCode);
            }
        }
        catch (Exception exception)
        {
            // Single place where failures are surfaced to the awaiter.
            try
            {
                _mrvts.SetException(exception);
            }
            catch (Exception ex)
            {
                Log.AcceptSetResultFailed(_logger, ex);
            }
        }
    }
 
    /// <summary>
    /// Static IO completion callback: recovers the owning <see cref="AsyncAcceptContext"/> from the
    /// overlapped state and forwards the completion to it as an unmanaged (callback) completion.
    /// </summary>
    private static unsafe void IOWaitCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped)
    {
        var state = ThreadPoolBoundHandle.GetNativeOverlappedState(nativeOverlapped);
        var owner = (AsyncAcceptContext)state!;
        owner.IOCompleted(errorCode, numBytes, managed: false);
    }
 
    /// <summary>
    /// Records that a completion *might* occur; called right before a receive is issued.
    /// </summary>
    private void SetExpectCompletion()
    {
        // Intentionally "reset and check" rather than Increment, so that a single glitch does not
        // spam the logs forever: the counter is forced to 1 and should previously have been 0.
        var previous = Interlocked.Exchange(ref _expectedCompletionCount, 1);
        if (previous == 0)
        {
            return;
        }

        if (_logExpectationFailures)
        {
            Log.AcceptSetExpectationMismatch(_logger, previous);
        }

        Debug.Assert(false, nameof(SetExpectCompletion)); // fail hard in debug
    }
    /// <summary>
    /// Withdraws a previously recorded expectation because (due to error code etc.) no
    /// completion is anticipated any more.
    /// </summary>
    private void CancelExpectCompletion()
    {
        // The counter should have been 1, so it should now be 0.
        var remaining = Interlocked.Decrement(ref _expectedCompletionCount);
        if (remaining == 0)
        {
            return;
        }

        if (_logExpectationFailures)
        {
            Log.AcceptCancelExpectationMismatch(_logger, remaining);
        }

        Debug.Assert(false, nameof(CancelExpectCompletion)); // fail hard in debug
    }
    /// <summary>
    /// Records that a completion was actually invoked and checks that it was expected.
    /// </summary>
    /// <param name="managed">Whether the completion was invoked from managed code rather than the IO callback; only used for the log message.</param>
    private void ObserveCompletion(bool managed)
    {
        // The counter should have been 1, so it should now be 0.
        var remaining = Interlocked.Decrement(ref _expectedCompletionCount);
        if (remaining == 0)
        {
            return;
        }

        if (_logExpectationFailures)
        {
            var origin = managed ? "managed" : "unmanaged";
            Log.AcceptObserveExpectationMismatch(_logger, origin, remaining);
        }

        Debug.Assert(false, nameof(ObserveCompletion)); // fail hard in debug
    }
 
    /// <summary>
    /// Issues the receive for the current request context, retrying in place when the request id
    /// has become invalid or the buffer was too small.
    /// </summary>
    /// <returns>The status code of the last receive call that was not retried.</returns>
    private uint QueueBeginGetContext()
    {
        // Each pass issues one receive; "continue" means retry, "return" hands the status to the caller.
        while (true)
        {
            Debug.Assert(_requestContext != null);

            uint bytesTransferred = 0;

            // Track the expectation *before* the call, because of timing vs IOCP
            // (the completion could even be effectively synchronous).
            SetExpectCompletion();
            uint statusCode = HttpApi.HttpReceiveHttpRequest(
                Server.RequestQueue.Handle,
                _requestContext.RequestId,
                // Small perf impact by not using HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY
                // if the request sends header+body in a single TCP packet
                0u,
                _requestContext.NativeRequest,
                _requestContext.Size,
                &bytesTransferred,
                _overlapped);

            switch (statusCode)
            {
                case (ErrorCodes.ERROR_CONNECTION_INVALID or ErrorCodes.ERROR_INVALID_PARAMETER) when _requestContext.RequestId != 0:
                    // ERROR_CONNECTION_INVALID:
                    // The client reset the connection between the time we got the MORE_DATA error and
                    // when we called HttpReceiveHttpRequest with the new buffer. We can clear the
                    // request id and move on to the next request.
                    //
                    // ERROR_INVALID_PARAMETER: Historical check from HttpListener.
                    // https://referencesource.microsoft.com/#System/net/System/Net/_ListenerAsyncResult.cs,137
                    // we might get this if somebody stole our RequestId,
                    // set RequestId to 0 and start all over again with the buffer we just allocated
                    // BUGBUG: how can someone steal our request ID?  seems really bad and in need of fix.
                    CancelExpectCompletion();
                    _requestContext.RequestId = 0;
                    continue;

                case ErrorCodes.ERROR_MORE_DATA:
                    // The buffer was not big enough to fit the headers: allocate a new buffer of
                    // the reported size and try again.
                    //  (uint)backingBuffer.Length - AlignmentPadding
                    CancelExpectCompletion(); // we'll "expect" again when we retry
                    AllocateNativeRequest(bytesTransferred);
                    continue;

                case ErrorCodes.ERROR_SUCCESS:
                    if (HttpSysListener.SkipIOCPCallbackOnSuccess)
                    {
                        // IO operation completed synchronously - callback won't be called to signal completion.
                        IOCompleted(statusCode, bytesTransferred, true); // marks completion
                    }
                    // else: callback fired by IOCP (at some point), which marks completion
                    return statusCode;

                case ErrorCodes.ERROR_IO_PENDING:
                    // no change to state - callback will occur at some point
                    return statusCode;

                default:
                    // fault code, not expecting an IOCP callback
                    CancelExpectCompletion();
                    return statusCode;
            }
        }
    }
 
    /// <summary>
    /// Releases the current request context and native overlapped (if any) and allocates a fresh
    /// pair for the next receive call.
    /// </summary>
    /// <param name="size">Buffer size passed to the request-context factory; <c>null</c> is passed through unchanged.</param>
    /// <param name="requestId">Request id passed to the request-context factory.</param>
    private void AllocateNativeRequest(uint? size = null, ulong requestId = 0)
    {
        // Clear each field as soon as its resource is released. If one of the allocations below
        // throws, a later Dispose() must not see the stale values and free the same native
        // overlapped (or dispose the same context) a second time.
        if (_requestContext != null)
        {
            _requestContext.ReleasePins();
            _requestContext.Dispose();
            _requestContext = null;
        }

        var boundHandle = Server.RequestQueue.BoundHandle;
        if (_overlapped != null)
        {
            boundHandle.FreeNativeOverlapped(_overlapped);
            _overlapped = null;
        }

        _requestContext = _requestContextFactory.CreateRequestContext(size, requestId);
        _overlapped = boundHandle.AllocateNativeOverlapped(_preallocatedOverlapped);
    }
 
    /// <summary>
    /// Releases the current request context and native overlapped held by this instance.
    /// </summary>
    public void Dispose() => Dispose(disposing: true);
 
    /// <summary>
    /// Releases the request context and the native overlapped, clearing both fields so that a
    /// repeated call finds nothing left to release.
    /// </summary>
    /// <param name="disposing">When <c>false</c> nothing is released.</param>
    private void Dispose(bool disposing)
    {
        if (!disposing)
        {
            return;
        }

        if (_requestContext is { } context)
        {
            context.ReleasePins();
            context.Dispose();
            _requestContext = null;
        }

        // NOTE(review): _preallocatedOverlapped (an IDisposable) is not disposed here - confirm
        // whether that is intentional before changing it.
        var boundHandle = Server.RequestQueue.BoundHandle;
        if (_overlapped == null)
        {
            return;
        }

        boundHandle.FreeNativeOverlapped(_overlapped);
        _overlapped = null;
    }
 
    /// <summary>
    /// <see cref="IValueTaskSource{TResult}"/> member; forwards to the underlying value-task source core.
    /// </summary>
    public RequestContext GetResult(short token) => _mrvts.GetResult(token);
 
    /// <summary>
    /// <see cref="IValueTaskSource{TResult}"/> member; forwards to the underlying value-task source core.
    /// </summary>
    public ValueTaskSourceStatus GetStatus(short token) => _mrvts.GetStatus(token);
 
    /// <summary>
    /// <see cref="IValueTaskSource{TResult}"/> member; registers the continuation with the underlying value-task source core.
    /// </summary>
    public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
        => _mrvts.OnCompleted(continuation, state, token, flags);
}