File: NoAssertContext.cs
Web Access
Project: src\src\System.Windows.Forms.Primitives\tests\TestUtilities\System.Windows.Forms.Primitives.TestUtilities.csproj (System.Windows.Forms.Primitives.TestUtilities)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#nullable enable
 
using System.Collections.Concurrent;
 
namespace System;
 
/// <summary>
///  Use (within a using) to eat asserts.
/// </summary>
public sealed class NoAssertContext : IDisposable
{
    // For any given thread we don't need to lock to decide how to route messages, as any messages for that
    // given thread will not happen while we're in the constructor or dispose method on that thread. That
    // means we can safely check to see if we've hooked our thread without locking (outside of using a
    // concurrent collection to make sure the collection is in a known state).
    //
    // We do, however need to lock around hooking/unhooking our custom listener to make sure that we
    // are rerouting correctly if multiple threads are creating/disposing this class concurrently.
 
    private static readonly Lock s_lock = new();
    private static bool s_hooked;
    private static bool s_hasDefaultListener;
    private static bool s_hasThrowingListener;
 
    private static readonly ConcurrentDictionary<int, int> s_suppressedThreads = new();
 
    // "Default" is the listener that terminates the process when debug assertions fail.
    private static readonly TraceListener? s_defaultListener = Trace.Listeners["Default"];
    private static readonly NoAssertListener s_noAssertListener = new();
 
    public NoAssertContext()
    {
        s_suppressedThreads.AddOrUpdate(Environment.CurrentManagedThreadId, 1, (key, oldValue) => oldValue + 1);
 
        // Lock to make sure we are hooked properly if two threads come into the constructor/dispose at the same time.
        lock (s_lock)
        {
            if (!s_hooked)
            {
                // Hook our custom listener first so we don't lose assertions from other threads when
                // we disconnect the default listener.
                Trace.Listeners.Add(s_noAssertListener);
                if (s_defaultListener is not null && Trace.Listeners.Contains(s_defaultListener))
                {
                    s_hasDefaultListener = true;
                    Trace.Listeners.Remove(s_defaultListener);
                }
 
                if (Trace.Listeners.OfType<ThrowingTraceListener>().FirstOrDefault() is { } throwingTraceListener)
                {
                    s_hasThrowingListener = true;
                    Trace.Listeners.Remove(throwingTraceListener);
                }
 
                s_hooked = true;
            }
        }
    }
 
    public void Dispose()
    {
        GC.SuppressFinalize(this);
 
        int currentThread = Environment.CurrentManagedThreadId;
        if (s_suppressedThreads.TryRemove(currentThread, out int count))
        {
            if (count > 1)
            {
                // We're in a nested assert context on a given thread, re-add with a decremented count.
                // This doesn't need to be atomic as we're currently on the thread that would care about
                // being rerouted.
                s_suppressedThreads.TryAdd(currentThread, --count);
            }
        }
 
        lock (s_lock)
        {
            if (s_hooked && s_suppressedThreads.IsEmpty)
            {
                // We're the first to hit the need to unhook. Add the default listener back first to
                // ensure we don't lose any asserts from other threads.
                if (s_hasDefaultListener)
                {
                    Trace.Listeners.Add(s_defaultListener!);
                }
 
                if (s_hasThrowingListener)
                {
                    Trace.Listeners.Add(ThrowingTraceListener.Instance);
                }
 
                Trace.Listeners.Remove(s_noAssertListener);
                s_hooked = false;
            }
        }
    }
 
#pragma warning disable CA1821 // Remove empty Finalizers
    ~NoAssertContext()
#pragma warning restore CA1821
    {
        // We need this class to be used in a using to effectively rationalize about a test.
        throw new InvalidOperationException($"Did not dispose {nameof(NoAssertContext)}");
    }
 
    private class NoAssertListener : TraceListener
    {
        public NoAssertListener()
            : base(typeof(NoAssertListener).FullName)
        {
        }
 
        private static TraceListener? DefaultListener
        {
            get
            {
                if (s_hasThrowingListener)
                    return ThrowingTraceListener.Instance;
                else if (s_hasDefaultListener)
                    return s_defaultListener;
                else
                    return null;
            }
        }
 
        public override void Fail(string? message)
        {
            if (!s_suppressedThreads.TryGetValue(Environment.CurrentManagedThreadId, out _))
            {
                DefaultListener?.Fail(message);
            }
        }
 
        public override void Fail(string? message, string? detailMessage)
        {
            if (!s_suppressedThreads.TryGetValue(Environment.CurrentManagedThreadId, out _))
            {
                DefaultListener?.Fail(message, detailMessage);
            }
        }
 
        // Write and WriteLine are virtual
 
        public override void Write(string? message)
        {
            if (!s_suppressedThreads.TryGetValue(Environment.CurrentManagedThreadId, out _))
            {
                DefaultListener?.Write(message);
            }
        }
 
        public override void WriteLine(string? message)
        {
            if (!s_suppressedThreads.TryGetValue(Environment.CurrentManagedThreadId, out _))
            {
                DefaultListener?.WriteLine(message);
            }
        }
    }
}