File: Internal\SocketAwaitableEventArgs.cs
Web Access
Project: src\src\Servers\Kestrel\Transport.Sockets\src\Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.csproj (Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.IO.Pipelines;
using System.Net.Sockets;
using System.Threading.Tasks.Sources;
 
namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal;
 
// A slimmed down version of https://github.com/dotnet/runtime/blob/82ca681cbac89d813a3ce397e0c665e6c051ed67/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs#L798 that
// 1. Doesn't support any custom scheduling other than the PipeScheduler (no sync context, no task scheduler)
// 2. Doesn't do ValueTask validation using the token
// 3. Doesn't support usage outside of async/await (doesn't try to capture and restore the execution context)
// 4. Doesn't use cancellation tokens
internal class SocketAwaitableEventArgs : SocketAsyncEventArgs, IValueTaskSource<SocketOperationResult>
{
    private static readonly Action<object?> _continuationCompleted = _ => { };
 
    private readonly PipeScheduler _ioScheduler;
 
    // There are places where we read the _continuation field and then read some other state which we assume to be consistent
    // with the value we read in _continuation. Without a fence, those secondary reads could be reordered with respect to the first.
    // https://github.com/dotnet/runtime/pull/84432
    // https://github.com/dotnet/aspnetcore/issues/50623
    private volatile Action<object?>? _continuation;
 
    public SocketAwaitableEventArgs(PipeScheduler ioScheduler)
        : base(unsafeSuppressExecutionContextFlow: true)
    {
        _ioScheduler = ioScheduler;
    }
 
    protected override void OnCompleted(SocketAsyncEventArgs _)
    {
        var c = _continuation;
 
        if (c != null || (c = Interlocked.CompareExchange(ref _continuation, _continuationCompleted, null)) != null)
        {
            var continuationState = UserToken;
            UserToken = null;
            _continuation = _continuationCompleted; // in case someone's polling IsCompleted
 
            _ioScheduler.Schedule(c, continuationState);
        }
    }
 
    public SocketOperationResult GetResult(short token)
    {
        _continuation = null;
 
        if (SocketError != SocketError.Success)
        {
            return new SocketOperationResult(CreateException(SocketError));
        }
 
        return new SocketOperationResult(BytesTransferred);
    }
 
    protected static SocketException CreateException(SocketError e)
    {
        return new SocketException((int)e);
    }
 
    public ValueTaskSourceStatus GetStatus(short token)
    {
        return !ReferenceEquals(_continuation, _continuationCompleted) ? ValueTaskSourceStatus.Pending :
                SocketError == SocketError.Success ? ValueTaskSourceStatus.Succeeded :
                ValueTaskSourceStatus.Faulted;
    }
 
    public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
    {
        UserToken = state;
        var prevContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null);
        if (ReferenceEquals(prevContinuation, _continuationCompleted))
        {
            UserToken = null;
            ThreadPool.UnsafeQueueUserWorkItem(continuation, state, preferLocal: true);
        }
    }
}