File: Core\IO\AsyncIOOperation.cs
Web Access
Project: src\src\Servers\IIS\IIS\src\Microsoft.AspNetCore.Server.IIS.csproj (Microsoft.AspNetCore.Server.IIS)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Threading.Tasks.Sources;
using Microsoft.AspNetCore.Connections;
 
namespace Microsoft.AspNetCore.Server.IIS.Core.IO;
 
internal abstract class AsyncIOOperation : IValueTaskSource<int>, IValueTaskSource
{
    private static readonly Action<object?> CallbackCompleted = _ => { Debug.Assert(false, "Should not be invoked"); };
 
    private Action<object?>? _continuation;
    private object? _state;
    private int _result;
 
    private Exception? _exception;
 
    public ValueTaskSourceStatus GetStatus(short token)
    {
        if (ReferenceEquals(Volatile.Read(ref _continuation), null))
        {
            return ValueTaskSourceStatus.Pending;
        }
 
        return _exception != null ? ValueTaskSourceStatus.Succeeded : ValueTaskSourceStatus.Faulted;
    }
 
    public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
    {
        if (_state != null)
        {
            ThrowMultipleContinuations();
        }
 
        _state = state;
 
        var previousContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null);
 
        if (previousContinuation != null)
        {
            if (!ReferenceEquals(previousContinuation, CallbackCompleted))
            {
                ThrowMultipleContinuations();
            }
 
            new AsyncContinuation(continuation, state).Invoke();
        }
    }
 
    private static void ThrowMultipleContinuations()
    {
        throw new InvalidOperationException("Multiple awaiters are not allowed");
    }
 
    void IValueTaskSource.GetResult(short token)
    {
        var exception = _exception;
 
        ResetOperation();
 
        if (exception != null)
        {
            throw exception;
        }
    }
 
    public int GetResult(short token)
    {
        var exception = _exception;
        var result = _result;
 
        ResetOperation();
 
        if (exception != null)
        {
            throw exception;
        }
 
        return result;
    }
 
    public AsyncContinuation? Invoke()
    {
        if (InvokeOperation(out var hr, out var bytes))
        {
            return Complete(hr, bytes);
        }
        return null;
    }
 
    protected abstract bool InvokeOperation(out int hr, out int bytes);
 
    public AsyncContinuation Complete(int hr, int bytes)
    {
        if (hr != NativeMethods.ERROR_OPERATION_ABORTED)
        {
            _result = bytes;
            if (hr != NativeMethods.HR_OK && !IsSuccessfulResult(hr))
            {
                // Treat all errors as the client disconnect
                _exception = new ConnectionResetException("The client has disconnected", Marshal.GetExceptionForHR(hr)!);
            }
        }
        else
        {
            _result = -1;
            _exception = null;
        }
 
        AsyncContinuation asyncContinuation = default;
        var continuation = Interlocked.CompareExchange(ref _continuation, CallbackCompleted, null);
        if (continuation != null)
        {
            asyncContinuation = new AsyncContinuation(continuation, _state);
        }
 
        FreeOperationResources(hr, bytes);
 
        return asyncContinuation;
    }
 
    protected virtual bool IsSuccessfulResult(int hr) => false;
 
    public virtual void FreeOperationResources(int hr, int bytes) { }
 
    protected virtual void ResetOperation()
    {
        _exception = null;
        _result = int.MinValue;
        _state = null;
        _continuation = null;
    }
 
    public bool InUse() => _continuation is not null;
 
    public readonly struct AsyncContinuation
    {
        public Action<object?> Continuation { get; }
        public object? State { get; }
 
        public AsyncContinuation(Action<object?> continuation, object? state)
        {
            Continuation = continuation;
            State = state;
        }
 
        public void Invoke()
        {
            if (Continuation != null)
            {
                ThreadPool.UnsafeQueueUserWorkItem(Continuation, State, preferLocal: false);
            }
        }
    }
}