// File: Internal\Infrastructure\KestrelConnection.cs
// Project: src\src\Servers\Kestrel\Core\src\Microsoft.AspNetCore.Server.Kestrel.Core.csproj (Microsoft.AspNetCore.Server.Kestrel.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.Extensions.Logging;
 
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
 
/// <summary>
/// Base class for a Kestrel connection. It implements the heartbeat, completion,
/// lifetime-notification and metrics-context connection features and tracks the
/// task that represents the connection's execution.
/// </summary>
internal abstract class KestrelConnection : IConnectionHeartbeatFeature, IConnectionCompleteFeature, IConnectionLifetimeNotificationFeature, IConnectionMetricsContextFeature
{
    // One message shared by every place that logs a failing OnCompleted callback.
    private const string OnCompletedCallbackError = "An error occurred running an IConnectionCompleteFeature.OnCompleted callback.";

    // Heartbeat registrations; allocated lazily on the first OnHeartbeat call and
    // only touched while _heartbeatLock is held.
    private List<(Action<object> handler, object state)>? _heartbeatHandlers;
    private readonly object _heartbeatLock = new();

    // Completion callbacks; a stack so the most recently registered callback runs first.
    private Stack<KeyValuePair<Func<object, Task>, object>>? _onCompleted;
    private bool _completed;

    private readonly CancellationTokenSource _connectionClosingCts = new();
    private readonly TaskCompletionSource _completionTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
    protected readonly long _id;
    protected readonly ServiceContext _serviceContext;
    protected readonly TransportConnectionManager _transportConnectionManager;

    /// <summary>
    /// Stores the supplied collaborators and exposes the closing token through
    /// <see cref="ConnectionClosedRequested"/>.
    /// </summary>
    /// <param name="id">Identifier stored in <see cref="_id"/>.</param>
    /// <param name="serviceContext">Service context stored in <see cref="_serviceContext"/>.</param>
    /// <param name="transportConnectionManager">Manager stored in <see cref="_transportConnectionManager"/>.</param>
    /// <param name="logger">Logger exposed to derived types through <see cref="Logger"/>.</param>
    /// <param name="connectionMetricsContext">Initial value of <see cref="MetricsContext"/>.</param>
    public KestrelConnection(long id,
                             ServiceContext serviceContext,
                             TransportConnectionManager transportConnectionManager,
                             KestrelTrace logger,
                             ConnectionMetricsContext connectionMetricsContext)
    {
        Logger = logger;
        MetricsContext = connectionMetricsContext;
        ConnectionClosedRequested = _connectionClosingCts.Token;

        _id = id;
        _serviceContext = serviceContext;
        _transportConnectionManager = transportConnectionManager;
    }

    /// <summary>Logger used by this connection and its derived types.</summary>
    protected KestrelTrace Logger { get; }

    /// <summary>Metrics context associated with this connection.</summary>
    public ConnectionMetricsContext MetricsContext { get; set; }

    /// <summary>
    /// Token that is initially bound to the internal closing source, which
    /// <see cref="RequestClose"/> cancels.
    /// </summary>
    public CancellationToken ConnectionClosedRequested { get; set; }

    /// <summary>Task that completes once <see cref="Complete"/> has been called.</summary>
    public Task ExecutionTask => _completionTcs.Task;

    /// <summary>
    /// Invokes every registered heartbeat handler with its state. Handlers run
    /// while the heartbeat lock is held; nothing happens if none were registered.
    /// </summary>
    public void TickHeartbeat()
    {
        lock (_heartbeatLock)
        {
            if (_heartbeatHandlers is { } registrations)
            {
                foreach (var (callback, callbackState) in registrations)
                {
                    callback(callbackState);
                }
            }
        }
    }

    /// <summary>The transport-level connection this instance wraps.</summary>
    public abstract BaseConnectionContext TransportConnection { get; }

    /// <summary>
    /// Registers a handler that <see cref="TickHeartbeat"/> will invoke.
    /// </summary>
    /// <param name="action">Handler to invoke on each heartbeat tick.</param>
    /// <param name="state">State passed to <paramref name="action"/>.</param>
    public void OnHeartbeat(Action<object> action, object state)
    {
        lock (_heartbeatLock)
        {
            _heartbeatHandlers ??= new List<(Action<object> handler, object state)>();
            _heartbeatHandlers.Add((action, state));
        }
    }

    /// <summary>
    /// Registers a callback to run when the connection completes.
    /// </summary>
    /// <param name="callback">Callback to invoke on completion.</param>
    /// <param name="state">State passed to <paramref name="callback"/>.</param>
    /// <exception cref="InvalidOperationException">
    /// <see cref="FireOnCompletedAsync"/> has already been called.
    /// </exception>
    void IConnectionCompleteFeature.OnCompleted(Func<object, Task> callback, object state)
    {
        if (_completed)
        {
            throw new InvalidOperationException("The connection is already complete.");
        }

        _onCompleted ??= new Stack<KeyValuePair<Func<object, Task>, object>>();
        _onCompleted.Push(new(callback, state));
    }

    /// <summary>
    /// Marks the connection as completed and runs the registered completion
    /// callbacks, most recently registered first. Callback failures are logged
    /// rather than propagated.
    /// </summary>
    /// <returns>A task that completes when all callbacks have run.</returns>
    /// <exception cref="InvalidOperationException">This method was already called.</exception>
    public Task FireOnCompletedAsync()
    {
        if (_completed)
        {
            throw new InvalidOperationException("The connection is already complete.");
        }

        _completed = true;

        return _onCompleted is { Count: > 0 } pendingCallbacks
            ? CompleteAsyncMayAwait(pendingCallbacks)
            : Task.CompletedTask;
    }

    /// <summary>
    /// Runs callbacks without an async state machine for as long as each one
    /// returns a successfully completed task; hands the remainder to
    /// <see cref="CompleteAsyncAwaited"/> as soon as one does not.
    /// </summary>
    private Task CompleteAsyncMayAwait(Stack<KeyValuePair<Func<object, Task>, object>> onCompleted)
    {
        while (onCompleted.Count > 0)
        {
            var (callback, callbackState) = onCompleted.Pop();

            try
            {
                var pending = callback(callbackState);
                if (!pending.IsCompletedSuccessfully)
                {
                    // Not finished (or finished unsuccessfully): continue on the awaiting path.
                    return CompleteAsyncAwaited(pending, onCompleted);
                }
            }
            catch (Exception ex)
            {
                LogOnCompletedCallbackFailure(ex);
            }
        }

        return Task.CompletedTask;
    }

    /// <summary>
    /// Awaits the in-flight callback task, then awaits each remaining callback in
    /// turn. Every failure is logged and does not stop the remaining callbacks.
    /// </summary>
    private async Task CompleteAsyncAwaited(Task currentTask, Stack<KeyValuePair<Func<object, Task>, object>> onCompleted)
    {
        try
        {
            await currentTask;
        }
        catch (Exception ex)
        {
            LogOnCompletedCallbackFailure(ex);
        }

        while (onCompleted.Count > 0)
        {
            var (callback, callbackState) = onCompleted.Pop();

            try
            {
                await callback(callbackState);
            }
            catch (Exception ex)
            {
                LogOnCompletedCallbackFailure(ex);
            }
        }
    }

    /// <summary>Logs an exception thrown by a completion callback.</summary>
    private void LogOnCompletedCallbackFailure(Exception ex)
    {
        Logger.LogError(ex, OnCompletedCallbackError);
    }

    /// <summary>
    /// Cancels the internal closing token source. A call that races with
    /// <see cref="Complete"/> disposing the source is silently ignored.
    /// </summary>
    public void RequestClose()
    {
        try
        {
            _connectionClosingCts.Cancel();
        }
        catch (ObjectDisposedException)
        {
            // The source may already have been disposed by Complete();
            // in that case there is nothing left to cancel, so do nothing.
        }
    }

    /// <summary>
    /// Completes <see cref="ExecutionTask"/> and disposes the closing token source.
    /// </summary>
    public void Complete()
    {
        _completionTcs.TrySetResult();

        _connectionClosingCts.Dispose();
    }

    /// <summary>
    /// Begins a logging scope carrying the connection id when the logger reports
    /// <see cref="LogLevel.Critical"/> as enabled; otherwise returns <c>null</c>.
    /// </summary>
    /// <param name="connectionContext">Connection whose id is attached to the scope.</param>
    /// <returns>The scope to dispose, or <c>null</c> if no scope was created.</returns>
    protected IDisposable? BeginConnectionScope(BaseConnectionContext connectionContext)
        // NOTE(review): Critical is the highest severity, so this presumably means
        // "any logging is enabled" - confirm against KestrelTrace.IsEnabled.
        => Logger.IsEnabled(LogLevel.Critical)
            ? Logger.BeginScope(new ConnectionLogScope(connectionContext.ConnectionId))
            : null;
}