// File: RequestProcessing\ResponseStreamAsyncResult.cs
// Project: src\src\Servers\HttpSys\src\Microsoft.AspNetCore.Server.HttpSys.csproj (Microsoft.AspNetCore.Server.HttpSys)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Runtime.InteropServices;
using Windows.Win32.Foundation;
using Windows.Win32.Networking.HttpServer;
 
namespace Microsoft.AspNetCore.Server.HttpSys;
 
internal sealed unsafe partial class ResponseStreamAsyncResult : IAsyncResult, IDisposable
{
    // Single shared delegate wrapping Callback; the constructors hand it to
    // AllocateNativeOverlapped so every instance reuses the same completion callback.
    private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(Callback);
 
    // Native overlapped allocated by the data/file constructors; released in Dispose.
    // Stays null when only the (responseStream, cancellationToken) constructor runs.
    private readonly SafeNativeOverlapped? _overlapped;
    // Chunk descriptors exposed through DataChunks/DataChunkCount.
    // Null when the constructor was given nothing to send (zero-length data / zero count).
    private readonly HTTP_DATA_CHUNK[]? _dataChunks;
    // Assigned only by the file-based constructor; disposed together with this instance.
    private readonly FileStream? _fileStream;
    // Response body this operation belongs to; used for logging, cancellation registration and aborts.
    private readonly ResponseBody _responseStream;
    // Backs the Task property and the IAsyncResult members; completed by Complete/FailSilently/Cancel/Fail.
    private readonly TaskCompletionSource<object?> _tcs;
    // Backing store for BytesSent; nothing else in this file reads or writes it.
    private uint _bytesSent;
    // Token supplied by the caller; checked in IOCompleted to tell cancellation apart from other failures.
    private readonly CancellationToken _cancellationToken;
    // Registration obtained from the request context (default value when the token cannot be canceled); disposed in Dispose.
    private readonly CancellationTokenRegistration _cancellationRegistration;
 
    // Shared initialization used by the other constructors: stores the response body,
    // creates the completion source and, when the token can actually be canceled,
    // registers it with the owning request context.
    internal ResponseStreamAsyncResult(ResponseBody responseStream, CancellationToken cancellationToken)
    {
        _responseStream = responseStream;
        _tcs = new TaskCompletionSource<object?>();
        _cancellationToken = cancellationToken;

        // A token that can never fire needs no registration; keep the default value instead.
        _cancellationRegistration = cancellationToken.CanBeCanceled
            ? responseStream.RequestContext.RegisterForCancellation(cancellationToken)
            : default;
    }
 
    // Builds an async result for sending an in-memory buffer.
    //   data     - bytes to send; a zero-length segment produces no data chunks at all.
    //   chunked  - when true, the payload chunk is preceded by a chunk header
    //              (Helpers.GetChunkHeader) and followed by Helpers.CRLF, giving three chunks.
    // The work happens in two phases: first the chunk descriptors are filled in and the
    // managed arrays that need pinning are collected, then the overlapped is allocated
    // (which pins them) and only afterwards are the buffer addresses written into the chunks.
    internal ResponseStreamAsyncResult(ResponseBody responseStream, ArraySegment<byte> data, bool chunked,
        CancellationToken cancellationToken)
        : this(responseStream, cancellationToken)
    {
        var boundHandle = _responseStream.RequestContext.Server.RequestQueue.BoundHandle;
        object[] objectsToPin;
 
        if (data.Count == 0)
        {
            // Nothing to send: no chunk array, and no pin data is passed to the overlapped.
            _dataChunks = null;
            _overlapped = new SafeNativeOverlapped(boundHandle,
                boundHandle.AllocateNativeOverlapped(IOCallback, this, null));
            return;
        }
 
        // One chunk for the payload, plus header and trailing CRLF chunks when chunked.
        _dataChunks = new HTTP_DATA_CHUNK[1 + (chunked ? 2 : 0)];
        // Slot 0 holds the chunk array itself (DataChunks later takes its address);
        // the following slots receive the backing byte arrays via SetDataChunk.
        // When chunked, the final slot stays null because the CRLF chunk adds nothing to pin.
        objectsToPin = new object[_dataChunks.Length + 1];
        objectsToPin[0] = _dataChunks;
        var currentChunk = 0;
        var currentPin = 1;
 
        var chunkHeaderBuffer = new ArraySegment<byte>();
        if (chunked)
        {
            chunkHeaderBuffer = Helpers.GetChunkHeader(data.Count);
            SetDataChunk(_dataChunks, ref currentChunk, objectsToPin, ref currentPin, chunkHeaderBuffer);
        }
 
        SetDataChunk(_dataChunks, ref currentChunk, objectsToPin, ref currentPin, data);
 
        if (chunked)
        {
            // Stores the CRLF address immediately; no entry is added to objectsToPin for it.
            SetDataChunkWithPinnedData(_dataChunks, ref currentChunk, Helpers.CRLF);
        }
 
        // This call will pin needed memory
        _overlapped = new SafeNativeOverlapped(boundHandle,
            boundHandle.AllocateNativeOverlapped(IOCallback, this, objectsToPin));
 
        // Second pass over the chunks, in the same order as above, to fill in the
        // buffer addresses now that the arrays have been pinned.
        currentChunk = 0;
        if (chunked)
        {
            _dataChunks[currentChunk].Anonymous.FromMemory.pBuffer = (void*)Marshal.UnsafeAddrOfPinnedArrayElement(chunkHeaderBuffer.Array!, chunkHeaderBuffer.Offset);
            currentChunk++;
        }
 
        _dataChunks[currentChunk].Anonymous.FromMemory.pBuffer = (void*)Marshal.UnsafeAddrOfPinnedArrayElement(data.Array!, data.Offset);
        currentChunk++;
 
        if (chunked)
        {
            // No need to update this chunk, CRLF, because it was already pinned.
            // However, we increment the counter to ensure we are accounting for all sections.
            currentChunk++;
        }
 
        // Sanity check: every chunk slot was visited in the second pass.
        Debug.Assert(currentChunk == _dataChunks.Length);
    }
 
    // Builds an async result for sending a byte range of a file.
    //   fileStream - source file; kept in _fileStream and disposed by Dispose.
    //   offset     - starting offset of the range, stored as ByteRange.StartingOffset.
    //   count      - length of the range; zero produces no data chunks at all.
    //   chunked    - when true, the file chunk is preceded by a chunk header
    //                (Helpers.GetChunkHeader) and followed by Helpers.CRLF, giving three chunks.
    // NOTE(review): offset and count are cast to ulong without a range check here;
    // presumably callers validate that they are non-negative - confirm.
    internal ResponseStreamAsyncResult(ResponseBody responseStream, FileStream fileStream, long offset,
        long count, bool chunked, CancellationToken cancellationToken)
        : this(responseStream, cancellationToken)
    {
        var boundHandle = responseStream.RequestContext.Server.RequestQueue.BoundHandle;
 
        _fileStream = fileStream;
 
        if (count == 0)
        {
            // Nothing to send: no chunk array, and no pin data is passed to the overlapped.
            _dataChunks = null;
            _overlapped = new SafeNativeOverlapped(boundHandle,
                boundHandle.AllocateNativeOverlapped(IOCallback, this, null));
        }
        else
        {
            _dataChunks = new HTTP_DATA_CHUNK[chunked ? 3 : 1];
 
            // The last slot always holds the chunk array itself. When chunked, slot 0 holds
            // the header byte array and slot 1 stays null (the file handle needs no pinning).
            object[] objectsToPin = new object[_dataChunks.Length];
            objectsToPin[_dataChunks.Length - 1] = _dataChunks;
 
            var chunkHeaderBuffer = new ArraySegment<byte>();
            if (chunked)
            {
                // Chunk 0: header, from memory. Only the length is set here; the address is
                // filled in after the overlapped allocation below has pinned the array.
                chunkHeaderBuffer = Helpers.GetChunkHeader(count);
                _dataChunks[0].DataChunkType = HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
                _dataChunks[0].Anonymous.FromMemory.BufferLength = (uint)chunkHeaderBuffer.Count;
                objectsToPin[0] = chunkHeaderBuffer.Array!;
 
                // Chunk 1: the file range, identified by the stream's raw handle.
                _dataChunks[1].DataChunkType = HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromFileHandle;
                _dataChunks[1].Anonymous.FromFileHandle.ByteRange.StartingOffset = (ulong)offset;
                _dataChunks[1].Anonymous.FromFileHandle.ByteRange.Length = (ulong)count;
                _dataChunks[1].Anonymous.FromFileHandle.FileHandle = (HANDLE)_fileStream.SafeFileHandle.DangerousGetHandle();
                // Nothing to pin for the file handle.
 
                // Chunk 2: trailing CRLF.
                // No need to pin the CRLF data
                int currentChunk = 2;
                SetDataChunkWithPinnedData(_dataChunks, ref currentChunk, Helpers.CRLF);
                Debug.Assert(currentChunk == _dataChunks.Length);
            }
            else
            {
                // Single chunk: the file range, identified by the stream's raw handle.
                _dataChunks[0].DataChunkType = HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromFileHandle;
                _dataChunks[0].Anonymous.FromFileHandle.ByteRange.StartingOffset = (ulong)offset;
                _dataChunks[0].Anonymous.FromFileHandle.ByteRange.Length = (ulong)count;
                _dataChunks[0].Anonymous.FromFileHandle.FileHandle = (HANDLE)_fileStream.SafeFileHandle.DangerousGetHandle();
            }
 
            // This call will pin needed memory
            _overlapped = new SafeNativeOverlapped(boundHandle,
                boundHandle.AllocateNativeOverlapped(IOCallback, this, objectsToPin));
 
            if (chunked)
            {
                // This must be set after pinning with Overlapped.
                _dataChunks[0].Anonymous.FromMemory.pBuffer = (void*)Marshal.UnsafeAddrOfPinnedArrayElement(chunkHeaderBuffer.Array!, chunkHeaderBuffer.Offset);
            }
        }
    }
 
    // Records the segment's backing array in the next pin slot and marks the next chunk as a
    // from-memory chunk of the segment's length. Both indices are advanced by one.
    // The chunk's buffer address is deliberately left unset: the caller writes it once the
    // array has been pinned by the overlapped allocation.
    private static void SetDataChunk(HTTP_DATA_CHUNK[] chunks, ref int chunkIndex, object[] objectsToPin, ref int pinIndex, ArraySegment<byte> segment)
    {
        var pinSlot = pinIndex;
        objectsToPin[pinSlot] = segment.Array!;
        pinIndex = pinSlot + 1;

        var chunkSlot = chunkIndex++;
        chunks[chunkSlot].DataChunkType = HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
        chunks[chunkSlot].Anonymous.FromMemory.BufferLength = (uint)segment.Count;
    }
 
    // Fills the next chunk slot with a from-memory chunk whose address and length are both
    // written immediately, then advances chunkIndex by one.
    // The pointer is obtained inside a fixed statement but remains stored in the chunk after
    // that statement ends, so this is only valid for bytes whose storage does not move.
    private static void SetDataChunkWithPinnedData(HTTP_DATA_CHUNK[] chunks, ref int chunkIndex, ReadOnlySpan<byte> bytes)
    {
        var slot = chunkIndex++;
        chunks[slot].DataChunkType = HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
        fixed (byte* start = bytes)
        {
            chunks[slot].Anonymous.FromMemory.pBuffer = start;
        }
        chunks[slot].Anonymous.FromMemory.BufferLength = (uint)bytes.Length;
    }
 
    // The overlapped allocated by the constructor, or null if none was allocated.
    internal SafeNativeOverlapped? NativeOverlapped => _overlapped;
 
    // Task that finishes when Complete, FailSilently, Cancel or Fail is called.
    internal Task Task => _tcs.Task;
 
    // Byte count slot for the operation; this class only stores the value, it never computes it.
    internal uint BytesSent
    {
        get => _bytesSent;
        set => _bytesSent = value;
    }
 
    // Number of chunk descriptors, or 0 when no chunk array was created.
    internal ushort DataChunkCount => (ushort)(_dataChunks?.Length ?? 0);
 
    // Address of the first chunk descriptor, or null when no chunk array was created.
    // Relies on the chunk array having been pinned by the constructor's overlapped allocation.
    internal HTTP_DATA_CHUNK* DataChunks
    {
        get
        {
            var chunks = _dataChunks;
            if (chunks is not null)
            {
                return (HTTP_DATA_CHUNK*)Marshal.UnsafeAddrOfPinnedArrayElement(chunks, 0);
            }

            return null;
        }
    }
 
    // Plain flag with no readers or writers in this file.
    // NOTE(review): presumably set by the caller once the operation's End step has run - confirm.
    internal bool EndCalled { get; set; }
 
    // Instance entry point: runs the shared completion logic for this operation
    // with the given error code.
    internal void IOCompleted(uint errorCode) => IOCompleted(this, errorCode);
 
    // Maps the completion error code onto the operation's final state:
    //   - ERROR_SUCCESS / ERROR_HANDLE_EOF        -> Complete
    //   - other code, cancellation was requested  -> Cancel (logged as cancelled)
    //   - other code, ThrowWriteExceptions is off -> FailSilently (logged as ignored)
    //   - other code, ThrowWriteExceptions is on  -> Fail with an IOException wrapping HttpSysException
    // Any exception thrown while doing so is logged and turned into Fail.
    private static void IOCompleted(ResponseStreamAsyncResult asyncResult, uint errorCode)
    {
        var responseBody = asyncResult._responseStream;
        var logger = responseBody.RequestContext.Logger;
        try
        {
            var succeeded = errorCode == ErrorCodes.ERROR_SUCCESS || errorCode == ErrorCodes.ERROR_HANDLE_EOF;
            if (succeeded)
            {
                // TODO: Verbose log the data written (per chunk when data chunks are present).
                asyncResult.Complete();
                return;
            }

            if (asyncResult._cancellationToken.IsCancellationRequested)
            {
                Log.WriteCancelled(logger, errorCode);
                asyncResult.Cancel(responseBody.ThrowWriteExceptions);
                return;
            }

            if (!responseBody.ThrowWriteExceptions)
            {
                Log.WriteErrorIgnored(logger, errorCode);
                asyncResult.FailSilently();
                return;
            }

            var writeFailure = new IOException(string.Empty, new HttpSysException((int)errorCode));
            Log.WriteError(logger, writeFailure);
            asyncResult.Fail(writeFailure);
        }
        catch (Exception unexpected)
        {
            Log.WriteError(logger, unexpected);
            asyncResult.Fail(unexpected);
        }
    }
 
    // I/O completion callback: recovers the ResponseStreamAsyncResult that was supplied as
    // state when the overlapped was allocated and forwards the error code to IOCompleted.
    // The second argument is not used.
    private static unsafe void Callback(uint errorCode, uint _, NativeOverlapped* nativeOverlapped)
    {
        var state = ThreadPoolBoundHandle.GetNativeOverlappedState(nativeOverlapped);
        IOCompleted((ResponseStreamAsyncResult)state!, errorCode);
    }
 
    // Successful completion: releases this operation's resources first, then
    // marks the task as completed (no-op if the task was already completed).
    internal void Complete()
    {
        Dispose();
        _tcs.TrySetResult(null);
    }
 
    // Failure that is hidden from the caller: resources are released and the response
    // is aborted, but the task still completes successfully rather than faulting.
    internal void FailSilently()
    {
        Dispose();
        // Abort the request but do not close the stream, let future writes complete silently
        _responseStream.Abort(dispose: false);
        _tcs.TrySetResult(null);
    }
 
    // Cancellation: releases resources, aborts the response and moves the task to the
    // canceled state. The dispose flag is passed straight through to ResponseBody.Abort.
    internal void Cancel(bool dispose)
    {
        Dispose();
        _responseStream.Abort(dispose);
        _tcs.TrySetCanceled();
    }
 
    // Visible failure: releases resources, aborts the response (with Abort's default
    // arguments) and faults the task with the supplied exception.
    internal void Fail(Exception ex)
    {
        Dispose();
        _responseStream.Abort();
        _tcs.TrySetException(ex);
    }
 
    // IAsyncResult.AsyncState, forwarded from the underlying task.
    public object? AsyncState => _tcs.Task.AsyncState;
 
    // IAsyncResult.AsyncWaitHandle, forwarded from the underlying task's explicit implementation.
    public WaitHandle AsyncWaitHandle => ((IAsyncResult)_tcs.Task).AsyncWaitHandle;
 
    // IAsyncResult.CompletedSynchronously, forwarded from the underlying task's explicit implementation.
    public bool CompletedSynchronously => ((IAsyncResult)_tcs.Task).CompletedSynchronously;
 
    // True once the underlying task has reached any final state (completed, faulted or canceled).
    public bool IsCompleted => _tcs.Task.IsCompleted;
 
    // Releases, in this order: the native overlapped (if allocated), the file stream
    // (if this is a file send) and the cancellation registration.
    public void Dispose()
    {
        if (_overlapped is { } overlapped)
        {
            overlapped.Dispose();
        }

        if (_fileStream is { } file)
        {
            file.Dispose();
        }

        _cancellationRegistration.Dispose();
    }
}