File: src\Servers\Kestrel\shared\test\TestApplicationErrorLogger.cs
Web Access
Project: src\src\Servers\Kestrel\Transport.Quic\test\Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Tests.csproj (Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.Extensions.Logging;
 
namespace Microsoft.AspNetCore.InternalTesting;
 
/// <summary>
/// An <see cref="ILogger"/> for tests that records every message and scope it receives and fails fast
/// (by throwing from <see cref="Log{TState}"/>) on critical errors, ungraceful shutdowns and
/// transport connection errors.
/// </summary>
public class TestApplicationErrorLogger : ILogger
{
    // Application errors are logged using 13 as the eventId.
    private const int ApplicationErrorEventId = 13;

    // Event ids checked (together with the event name) when deciding whether a message must fail the test.
    private const int ConnectionErrorEventId = 14;
    private const int NotAllConnectionsClosedGracefullyEventId = 16;
    private const int NotAllConnectionsAbortedEventId = 21;

    // Predicate and completion source installed by WaitForMessage and consumed by Log.
    private Func<LogMessage, bool> _messageFilter;
    private TaskCompletionSource<LogMessage> _messageFilterTcs;

    /// <summary>
    /// Exception types that do not trigger the critical-error throw in <see cref="Log{TState}"/>.
    /// The match is on the exact runtime type of the logged exception.
    /// </summary>
    public List<Type> IgnoredExceptions { get; } = new List<Type>();

    /// <summary>
    /// When <c>true</c> (the default), logging at <see cref="LogLevel.Critical"/> throws unless the
    /// logged exception's type is listed in <see cref="IgnoredExceptions"/>.
    /// </summary>
    public bool ThrowOnCriticalErrors { get; set; } = true;

    /// <summary>
    /// When <c>true</c> (the default), logging the "not all connections closed gracefully" or
    /// "not all connections aborted" events throws.
    /// </summary>
    public bool ThrowOnUngracefulShutdown { get; set; } = true;

    /// <summary>Every message that passed through <see cref="Log{TState}"/> without causing a throw.</summary>
    public ConcurrentQueue<LogMessage> Messages { get; } = new ConcurrentQueue<LogMessage>();

    /// <summary>The states passed to <see cref="BeginScope{TState}"/> whose scopes have not been removed yet.</summary>
    public ConcurrentQueue<object> Scopes { get; } = new ConcurrentQueue<object>();

    /// <summary>Number of recorded messages logged at <see cref="LogLevel.Error"/>.</summary>
    public int TotalErrorsLogged => Messages.Count(message => message.LogLevel == LogLevel.Error);

    /// <summary>Number of recorded messages logged at <see cref="LogLevel.Critical"/>.</summary>
    public int CriticalErrorsLogged => Messages.Count(message => message.LogLevel == LogLevel.Critical);

    /// <summary>Number of recorded messages whose event id is the application-error id (13).</summary>
    public int ApplicationErrorsLogged => Messages.Count(message => message.EventId.Id == ApplicationErrorEventId);

    /// <summary>
    /// Registers a filter and returns a task that completes with the first subsequently logged message
    /// for which the filter returns <c>true</c>.
    /// </summary>
    /// <param name="messageFilter">Predicate evaluated against each message recorded by <see cref="Log{TState}"/>.</param>
    /// <returns>A task producing the matching <see cref="LogMessage"/>.</returns>
    /// <exception cref="InvalidOperationException">A previous wait has not completed yet.</exception>
    public Task<LogMessage> WaitForMessage(Func<LogMessage, bool> messageFilter)
    {
        if (_messageFilterTcs != null)
        {
            throw new InvalidOperationException($"{nameof(WaitForMessage)} cannot be called concurrently.");
        }

        _messageFilterTcs = new TaskCompletionSource<LogMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
        _messageFilter = messageFilter;

        return _messageFilterTcs.Task;
    }

    /// <summary>
    /// Records <paramref name="state"/> in <see cref="Scopes"/>.
    /// </summary>
    /// <param name="state">The scope state to record.</param>
    /// <returns>
    /// A disposable whose callback dequeues one entry from <see cref="Scopes"/>. Because
    /// <see cref="Scopes"/> is a queue, the entry removed is the oldest one, which is not necessarily
    /// the one added by this call.
    /// </returns>
    public IDisposable BeginScope<TState>(TState state)
    {
        Scopes.Enqueue(state);

        return new Disposable(() => { Scopes.TryDequeue(out _); });
    }

    /// <summary>Always returns <c>true</c>; every level is captured.</summary>
    public bool IsEnabled(LogLevel logLevel)
    {
        return true;
    }

    /// <summary>
    /// Records the message in <see cref="Messages"/> and completes a pending <see cref="WaitForMessage"/>
    /// task if the message satisfies its filter. Messages that indicate a test failure throw instead and
    /// are not recorded.
    /// </summary>
    /// <param name="logLevel">Severity of the entry.</param>
    /// <param name="eventId">Event id of the entry.</param>
    /// <param name="state">State passed to <paramref name="formatter"/>.</param>
    /// <param name="exception">Exception associated with the entry, if any.</param>
    /// <param name="formatter">Produces the message text from <paramref name="state"/> and <paramref name="exception"/>.</param>
    /// <exception cref="Exception">
    /// The entry is a non-ignored critical error (with <see cref="ThrowOnCriticalErrors"/> set), an ungraceful
    /// shutdown event (with <see cref="ThrowOnUngracefulShutdown"/> set), or a "ConnectionError" event.
    /// </exception>
    public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func<TState, Exception, string> formatter)
    {
        var exceptionIsIgnored = IgnoredExceptions.Contains(exception?.GetType());

        if (logLevel == LogLevel.Critical && ThrowOnCriticalErrors && !exceptionIsIgnored)
        {
            var log = FormatEntry(logLevel, eventId, state, exception, formatter);

            Console.WriteLine(log);

            throw new Exception($"Unexpected critical error. {log}", exception);
        }

        // Fail tests where not all the connections close during server shutdown.
        if (ThrowOnUngracefulShutdown &&
            ((eventId.Id == NotAllConnectionsClosedGracefullyEventId && eventId.Name == nameof(KestrelTrace.NotAllConnectionsClosedGracefully)) ||
             (eventId.Id == NotAllConnectionsAbortedEventId && eventId.Name == nameof(KestrelTrace.NotAllConnectionsAborted))))
        {
            var log = FormatEntry(logLevel, eventId, state, exception, formatter);
            throw new Exception($"Shutdown failure. {log}");
        }

        // We don't use nameof here because this is logged by the transports and we don't know which one is
        // referenced in this shared source file.
        if (eventId.Id == ConnectionErrorEventId && eventId.Name == "ConnectionError")
        {
            var log = FormatEntry(logLevel, eventId, state, exception, formatter);
            throw new Exception($"Unexpected connection error. {log}");
        }

        var logMessage = new LogMessage
        {
            LogLevel = logLevel,
            EventId = eventId,
            Exception = exception,
            Message = formatter(state, exception)
        };

        Messages.Enqueue(logMessage);

        if (_messageFilter?.Invoke(logMessage) == true)
        {
            var localTcs = _messageFilterTcs;
            // need to set tcs to null before calling TrySetResult
            // to prevent the next WaitForMessage possibly throwing for a non-null tcs
            _messageFilterTcs = null;

            // The filter stays installed after a wait completes, so a later matching message can arrive
            // when no wait is pending; in that case there is nothing to signal.
            localTcs?.TrySetResult(logMessage);
        }
    }

    // Builds the entry description embedded in the exceptions thrown from Log.
    private static string FormatEntry<TState>(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func<TState, Exception, string> formatter)
    {
        return $"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception}";
    }

    /// <summary>A single captured log entry.</summary>
    public class LogMessage
    {
        /// <summary>Severity the entry was logged at.</summary>
        public LogLevel LogLevel { get; set; }

        /// <summary>Event id the entry was logged with.</summary>
        public EventId EventId { get; set; }

        /// <summary>Exception passed to the logger, or <c>null</c> if none was supplied.</summary>
        public Exception Exception { get; set; }

        /// <summary>Text produced by the formatter for this entry.</summary>
        public string Message { get; set; }
    }
}