File: System\Net\Quic\Internal\ResettableValueTaskSource.cs
Web Access
Project: src\src\libraries\System.Net.Quic\src\System.Net.Quic.csproj (System.Net.Quic)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;
 
namespace System.Net.Quic;
 
/// <summary>
/// An <see cref="IValueTaskSource"/> that can be completed repeatedly (non-final completions, after which it resets
/// once the result is consumed) and exactly once finally, after which it stays completed.
/// All state transitions are performed under <c>lock (this)</c>.
/// </summary>
internal sealed class ResettableValueTaskSource : IValueTaskSource
{
    // None -> [TryGetValueTask] -> Awaiting -> [TrySetResult|TrySetException(final: false)] -> Ready -> [GetResult] -> None
    // None -> [TrySetResult|TrySetException(final: false)] -> Ready -> [TryGetValueTask] -> [GetResult] -> None
    // None|Awaiting -> [TrySetResult|TrySetException(final: true)] -> Completed(never leaves this state)
    // Ready -> [GetResult: TrySet*(final: true) was called] -> Completed(never leaves this state)
    private enum State
    {
        None,
        Awaiting,
        Ready,
        Completed
    }

    // Current position in the state machine described above.
    private State _state;
    // True while a ValueTask for the current version has been handed out and its result has not been read yet.
    private bool _hasWaiter;
    private ManualResetValueTaskSourceCore<bool> _valueTaskSource;
    private CancellationTokenRegistration _cancellationRegistration;
    // The token that fired the cancellation callback (if any); checked and cleared in GetResult.
    private CancellationToken _cancelledToken;
    private Action<object?>? _cancellationAction;
    // Roots the keepAlive object passed to TryGetValueTask until the next TryComplete call.
    private GCHandle _keepAlive;
    private FinalTaskSource _finalTaskSource;

    public ResettableValueTaskSource()
    {
        _state = State.None;
        _hasWaiter = false;
        _valueTaskSource = new ManualResetValueTaskSourceCore<bool>() { RunContinuationsAsynchronously = true };
        _cancellationRegistration = default;
        _cancelledToken = default;
        _keepAlive = default;
        _finalTaskSource = new FinalTaskSource();
    }

    /// <summary>
    /// Allows setting additional cancellation action to be called if token passed to <see cref="TryGetValueTask(out ValueTask, object?, CancellationToken)"/> fires off.
    /// The argument for the action is the <c>keepAlive</c> object from the same <see cref="TryGetValueTask(out ValueTask, object?, CancellationToken)"/> call.
    /// </summary>
    public Action<object?> CancellationAction { init { _cancellationAction = value; } }

    /// <summary>
    /// Returns <c>true</c> if this task source has entered its final state, i.e. <see cref="TrySetResult(bool)"/> or <see cref="TrySetException(Exception, bool)"/>
    /// was called with <c>final</c> set to <c>true</c> and the result was propagated.
    /// </summary>
    // The reinterpreted width must match the enum's underlying type (int): reading only a single byte would pick up
    // the most significant byte on big-endian platforms and therefore never observe State.Completed.
    public bool IsCompleted => (State)Volatile.Read(ref Unsafe.As<State, int>(ref _state)) == State.Completed;

    /// <summary>
    /// Tries to get a value task representing this task source. If this task source is <see cref="State.None"/>, it'll also transition it into <see cref="State.Awaiting"/> state.
    /// It prevents concurrent operations from being invoked since it'll return <c>false</c> if the task source was already in <see cref="State.Awaiting"/> state.
    /// In other states, it'll return a value task representing this task source without any other work. So to determine whether to invoke a P/Invoke operation or not,
    /// the state of <paramref name="valueTask"/> must also be checked.
    /// </summary>
    /// <param name="valueTask">A value task representing the result. Only meaningful in case this method returns <c>true</c>. Might already be completed.</param>
    /// <param name="keepAlive">An object to hold during a P/Invoke call. It'll get released with setting the result/exception.</param>
    /// <param name="cancellationToken">A cancellation token which might cancel the value task.</param>
    /// <returns><c>true</c> if this is not an overlapping call (task source transitioned or was already set); otherwise, <c>false</c>.</returns>
    public bool TryGetValueTask(out ValueTask valueTask, object? keepAlive = null, CancellationToken cancellationToken = default)
    {
        lock (this)
        {
            // Cancellation might kick off synchronously, re-entering the lock and changing the state to completed.
            // That is why the registration happens first and the state is (re-)read only afterwards.
            if (_state == State.None)
            {
                // Register cancellation if the token can be cancelled and the task is not completed yet.
                if (cancellationToken.CanBeCanceled)
                {
                    _cancellationRegistration = cancellationToken.UnsafeRegister(static (obj, cancellationToken) =>
                    {
                        (ResettableValueTaskSource thisRef, object? target) = ((ResettableValueTaskSource, object?))obj!;
                        lock (thisRef)
                        {
                            // Remember the token so that GetResult throws for it.
                            thisRef._cancelledToken = cancellationToken;
                        }
                        // The user supplied action is invoked outside of the lock.
                        thisRef._cancellationAction?.Invoke(target);
                    }, (this, keepAlive));
                }
            }

            State state = _state;

            // None: prepare for the actual operation happening and transition to Awaiting.
            if (state == State.None)
            {
                // Keep alive the caller object until the result is read from the task.
                // Used for keeping caller alive during async interop calls.
                if (keepAlive is not null)
                {
                    Debug.Assert(!_keepAlive.IsAllocated);
                    _keepAlive = GCHandle.Alloc(keepAlive);
                }

                _state = State.Awaiting;
            }
            // None, Ready, Completed: return the current task.
            if (state is State.None or State.Ready or State.Completed)
            {
                // Remember that the value task with the current version is being given out.
                _hasWaiter = true;
                valueTask = new ValueTask(this, _valueTaskSource.Version);
                return true;
            }

            // Awaiting: forbidden concurrent call.
            valueTask = default;
            return false;
        }
    }

    /// <summary>
    /// Gets a <see cref="Task"/> that will transition to a completed state with the last transition of this source, i.e. into <see cref="State.Completed"/>.
    /// </summary>
    /// <param name="keepAlive">An object that gets rooted until the returned task completes, in case the final result has not been set yet.</param>
    /// <returns>The <see cref="Task"/> that will transition to a completed state with the last transition of this source.</returns>
    public Task GetFinalTask(object? keepAlive)
    {
        lock (this)
        {
            return _finalTaskSource.GetTask(keepAlive);
        }
    }

    /// <summary>
    /// Shared implementation of <see cref="TrySetResult(bool)"/> and <see cref="TrySetException(Exception, bool)"/>.
    /// </summary>
    /// <param name="exception">The exception to complete with, or <c>null</c> for a successful completion.</param>
    /// <param name="final">Whether this is the final completion.</param>
    /// <returns>
    /// For a final completion, <c>true</c> if this call was the first one to store a final result; otherwise, <c>false</c>.
    /// For a non-final completion, <c>true</c> if the value task source was set by this call; otherwise, <c>false</c>.
    /// </returns>
    private bool TryComplete(Exception? exception, bool final)
    {
        // Dispose the cancellation registration before completing the task, so that it cannot run after the awaiting method returned.
        // Dispose must be done outside of lock since it will wait on pending cancellation callbacks that can hold the lock from another thread.
        CancellationTokenRegistration cancellationRegistration = default;
        lock (this)
        {
            cancellationRegistration = _cancellationRegistration;
            _cancellationRegistration = default;
        }
        cancellationRegistration.Dispose();

        lock (this)
        {
            try
            {
                State state = _state;

                // Completed: nothing to do.
                if (state == State.Completed)
                {
                    return false;
                }

                // The task was non-finally completed without having anyone awaiting on it.
                // In such case, discard the temporary result and replace it with this final completion.
                if (state == State.Ready && !_hasWaiter && final)
                {
                    _valueTaskSource.Reset();
                    state = State.None;
                }

                // If the _valueTaskSource has already been set, we don't want to lose the result by overwriting it.
                // So keep it as is and store the result in _finalTaskSource.
                if (state is State.None or State.Awaiting)
                {
                    _state = final ? State.Completed : State.Ready;
                }

                // Unblock the current task source and in case of a final also the final task source.
                if (exception is not null)
                {
                    // Set up the exception stack trace for the caller.
                    exception = exception.StackTrace is null ? ExceptionDispatchInfo.SetCurrentStackTrace(exception) : exception;
                    if (state is State.None or State.Awaiting)
                    {
                        _valueTaskSource.SetException(exception);
                    }
                }
                else
                {
                    if (state is State.None or State.Awaiting)
                    {
                        _valueTaskSource.SetResult(final);
                    }
                }
                if (final)
                {
                    if (_finalTaskSource.TryComplete(exception))
                    {
                        // Signal the final task only if we don't have another result in the value task source.
                        // In that case, the final task will be signalled after the value task result is retrieved.
                        if (state != State.Ready)
                        {
                            _finalTaskSource.TrySignal(out _);
                        }
                        return true;
                    }
                    return false;
                }
                return state != State.Ready;
            }
            finally
            {
                // Un-root the kept alive object in all cases.
                if (_keepAlive.IsAllocated)
                {
                    _keepAlive.Free();
                }
            }
        }
    }

    /// <summary>
    /// Tries to transition from <see cref="State.Awaiting"/> to either <see cref="State.Ready"/> or <see cref="State.Completed"/>, depending on the value of <paramref name="final"/>.
    /// Only the first call (with either value for <paramref name="final"/>) is able to do that. I.e.: <c>TrySetResult()</c> followed by <c>TrySetResult(true)</c> will both return <c>true</c>.
    /// </summary>
    /// <param name="final">Whether this is the final transition to <see cref="State.Completed" /> or just a transition into <see cref="State.Ready"/> from which the task source can be reset back to <see cref="State.None"/>.</param>
    /// <returns><c>true</c> if this is the first call that set the result; otherwise, <c>false</c>.</returns>
    public bool TrySetResult(bool final = false)
    {
        return TryComplete(null, final);
    }

    /// <summary>
    /// Tries to transition from <see cref="State.Awaiting"/> to either <see cref="State.Ready"/> or <see cref="State.Completed"/>, depending on the value of <paramref name="final"/>.
    /// Only the first call is able to do that with the exception of <c>TrySetResult()</c> followed by <c>TrySetResult(true)</c>, which will both return <c>true</c>.
    /// </summary>
    /// <param name="final">Whether this is the final transition to <see cref="State.Completed" /> or just a transition into <see cref="State.Ready"/> from which the task source can be reset back to <see cref="State.None"/>.</param>
    /// <param name="exception">The exception to set as a result of the value task.</param>
    /// <returns><c>true</c> if this is the first call that set the result; otherwise, <c>false</c>.</returns>
    public bool TrySetException(Exception exception, bool final = false)
    {
        return TryComplete(exception, final);
    }

    ValueTaskSourceStatus IValueTaskSource.GetStatus(short token)
        => _valueTaskSource.GetStatus(token);

    void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
        => _valueTaskSource.OnCompleted(continuation, state, token, flags);

    void IValueTaskSource.GetResult(short token)
    {
        try
        {
            // A cancellation recorded by the registered callback takes precedence over the stored result.
            _cancelledToken.ThrowIfCancellationRequested();
            _valueTaskSource.GetResult(token);
        }
        finally
        {
            lock (this)
            {
                State state = _state;

                _hasWaiter = false;
                _cancelledToken = default;

                if (state == State.Ready)
                {
                    // The non-final result has been consumed; make the source reusable again.
                    _valueTaskSource.Reset();
                    _state = State.None;

                    // Propagate the _finalTaskSource result into _valueTaskSource if completed.
                    if (_finalTaskSource.TrySignal(out Exception? exception))
                    {
                        _state = State.Completed;

                        if (exception is not null)
                        {
                            _valueTaskSource.SetException(exception);
                        }
                        else
                        {
                            _valueTaskSource.SetResult(true);
                        }
                    }
                }
            }
        }
    }

    /// <summary>
    /// It remembers the result from <see cref="TryComplete"/> and propagates it to <see cref="_finalTaskSource"/> only after <see cref="TrySignal"/> is called.
    /// Effectively allowing to separate setting of the result from task completion, which is necessary when the resettable portion of the value task source needs to be consumed first.
    /// </summary>
    private struct FinalTaskSource
    {
        // Lazily created, only when somebody asks for the task before the result has been signaled.
        private TaskCompletionSource? _finalTaskSource;
        // The final result (or exception) has been stored.
        private bool _isCompleted;
        // The stored final result has been propagated via TrySignal.
        private bool _isSignaled;
        private Exception? _exception;

        public FinalTaskSource()
        {
            _finalTaskSource = null;
            _isCompleted = false;
            _isSignaled = false;
            _exception = null;
        }

        /// <summary>
        /// Gets the task representing the final completion, creating the underlying completion source on demand.
        /// </summary>
        /// <param name="keepAlive">An object rooted via a <see cref="GCHandle"/> until the task completes; only used when no final result has been stored yet.</param>
        public Task GetTask(object? keepAlive)
        {
            if (_finalTaskSource is null)
            {
                // Already signaled: no completion source is needed, hand out an already completed task.
                if (_isSignaled)
                {
                    return _exception is null
                        ? Task.CompletedTask
                        : Task.FromException(_exception);
                }

                _finalTaskSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
                if (!_isCompleted)
                {
                    // Root the object until the task completes; the continuation releases the handle.
                    GCHandle handle = GCHandle.Alloc(keepAlive);
                    _finalTaskSource.Task.ContinueWith(static (_, state) =>
                    {
                        ((GCHandle)state!).Free();
                    }, handle, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
                }
            }
            return _finalTaskSource.Task;
        }

        /// <summary>
        /// Stores the final result without completing the task.
        /// </summary>
        /// <returns><c>true</c> if this call stored the result; <c>false</c> if a result had already been stored.</returns>
        public bool TryComplete(Exception? exception = null)
        {
            if (_isCompleted)
            {
                return false;
            }

            _exception = exception;
            _isCompleted = true;
            return true;
        }

        /// <summary>
        /// Propagates the stored final result into the task (if one was created) and reports the stored exception.
        /// </summary>
        /// <param name="exception">The stored exception, or <c>null</c> if none was stored.</param>
        /// <returns><c>true</c> if a final result had been stored; otherwise, <c>false</c>.</returns>
        public bool TrySignal(out Exception? exception)
        {
            if (!_isCompleted)
            {
                exception = default;
                return false;
            }

            if (_exception is not null)
            {
                _finalTaskSource?.SetException(_exception);
            }
            else
            {
                _finalTaskSource?.SetResult();
            }

            exception = _exception;
            _isSignaled = true;
            return true;
        }
    }
}