File: Shared\TestHooks\AsynchronousOperationListenerProvider.cs
Web Access
Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.Remote;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.Shared.TestHooks;
 
/// <summary>
/// use <see cref="IAsynchronousOperationListenerProvider" /> in product code to get
/// <see cref="IAsynchronousOperationListener" /> and use
/// <see cref="AsynchronousOperationListenerProvider" /> in test to get waiter.
/// </summary>
[Shared]
[Export(typeof(IAsynchronousOperationListenerProvider))]
[Export(typeof(AsynchronousOperationListenerProvider))]
internal sealed partial class AsynchronousOperationListenerProvider : IAsynchronousOperationListenerProvider
{
    public static readonly IAsynchronousOperationListenerProvider NullProvider = new NullListenerProvider();
    public static readonly IAsynchronousOperationListener NullListener = new NullOperationListener();
 
    /// <summary>
    /// indicate whether asynchronous listener is enabled or not.
    /// it is tri-state since we want to retrieve this value, if never explicitly set, from environment variable
    /// and then cache it.
    /// we read value from environment variable (RoslynWaiterEnabled) because we want team, that doesn't have
    /// access to Roslyn code (InternalVisibleTo), can use this listener/waiter framework as well. 
    /// those team can enable this without using <see cref="AsynchronousOperationListenerProvider.Enable(bool)" /> API
    /// </summary>
    public static bool? s_enabled = null;
 
    private readonly ConcurrentDictionary<string, AsynchronousOperationListener> _singletonListeners;
    private readonly Func<string, AsynchronousOperationListener> _createCallback;
 
    /// <summary>
    /// indicate whether <see cref="AsynchronousOperationListener.TrackActiveTokens"/> is enabled or not
    /// it is tri-state since we want to retrieve this value, if never explicitly set, from environment variable
    /// and then cache it.
    /// we read value from environment variable (RoslynWaiterDiagnosticTokenEnabled) because we want team, that doesn't have
    /// access to Roslyn code (InternalVisibleTo), can use this listener/waiter framework as well. 
    /// those team can enable this without using <see cref="AsynchronousOperationListenerProvider.EnableDiagnosticTokens(bool)" /> API
    /// </summary>
    private bool? _enableDiagnosticTokens;
 
    /// <summary>
    /// Provides a default value for <see cref="_enableDiagnosticTokens"/>.
    /// </summary>
    private static bool? s_enableDiagnosticTokens;
 
    public static void Enable(bool enable)
        => Enable(enable, diagnostics: null);
 
    public static void Enable(bool enable, bool? diagnostics)
    {
        // right now, made it static so that one can enable it through reflection easy
        // but we can think of some other way
        s_enabled = enable;
        s_enableDiagnosticTokens = diagnostics;
    }
 
    [ImportingConstructor]
    [Obsolete(MefConstruction.ImportingConstructorMessage, error: true)]
    public AsynchronousOperationListenerProvider()
    {
        _singletonListeners = new ConcurrentDictionary<string, AsynchronousOperationListener>(concurrencyLevel: 2, capacity: 20);
        _createCallback = name => new AsynchronousOperationListener(name, DiagnosticTokensEnabled);
    }
 
    public IAsynchronousOperationListener GetListener(string featureName)
    {
        if (!IsEnabled)
        {
            // if listener is not enabled. it always return null listener
            return NullListener;
        }
 
        return _singletonListeners.GetOrAdd(featureName, _createCallback);
    }
 
    /// <summary>
    /// Enable or disable TrackActiveTokens for test
    /// </summary>
    public void EnableDiagnosticTokens(bool enable)
    {
        _enableDiagnosticTokens = enable;
        _singletonListeners.Values.Do(l => l.TrackActiveTokens = enable);
    }
 
    /// <summary>
    /// Get Waiters for listeners for test
    /// </summary>
    public IAsynchronousOperationWaiter GetWaiter(string featureName)
        => (IAsynchronousOperationWaiter)GetListener(featureName);
 
    /// <summary>
    /// Wait for all of the <see cref="IAsynchronousOperationWaiter"/> instances to finish their
    /// work.
    /// </summary>
    /// <remarks>
    /// This is a very handy method for debugging hangs in the unit test.  Set a break point in the 
    /// loop, dig into the waiters and see all of the active <see cref="IAsyncToken"/> values 
    /// representing the remaining work.
    /// </remarks>
    public async Task WaitAllAsync(Workspace? workspace, string[]? featureNames = null, Action? eventProcessingAction = null, TimeSpan? timeout = null)
    {
        var startTime = Stopwatch.StartNew();
        var smallTimeout = TimeSpan.FromMilliseconds(10);
 
        RemoteHostClient? remoteHostClient = null;
        if (workspace?.Services.GetService<IRemoteHostClientProvider>() is { } remoteHostClientProvider)
        {
            remoteHostClient = await remoteHostClientProvider.TryGetRemoteHostClientAsync(CancellationToken.None).ConfigureAwait(false);
        }
 
        List<Task>? tasks = null;
        while (true)
        {
            var waiters = GetCandidateWaiters(featureNames);
            tasks = waiters.Select(x => x.ExpeditedWaitAsync()).Where(t => !t.IsCompleted).ToList();
 
            if (remoteHostClient is not null)
            {
                var isCompleted = await remoteHostClient.TryInvokeAsync<IRemoteAsynchronousOperationListenerService, bool>(
                    (service, cancellationToken) => service.IsCompletedAsync(featureNames.ToImmutableArrayOrEmpty(), cancellationToken),
                    CancellationToken.None).ConfigureAwait(false);
                if (isCompleted.HasValue && !isCompleted.Value)
                {
                    tasks.Add(remoteHostClient.TryInvokeAsync<IRemoteAsynchronousOperationListenerService>(
                        (service, cancellationToken) => service.ExpeditedWaitAsync(featureNames.ToImmutableArrayOrEmpty(), cancellationToken),
                        CancellationToken.None).AsTask());
                }
            }
 
            if (tasks.Count == 0)
            {
                // no new pending tasks
                break;
            }
 
            do
            {
                // wait for all current tasks to be done for the time given
                if (Task.WaitAll([.. tasks], smallTimeout))
                {
                    // current set of tasks are done.
                    // see whether there are new tasks added while we were waiting
                    break;
                }
 
                // certain test requires some event queues to be processed
                // for waiter tasks to finish such as Dispatcher queue
                eventProcessingAction?.Invoke();
 
                // in unit test where it uses fake foreground task scheduler such as StaTaskScheduler
                // we need to yield for the scheduler to run inlined tasks. If we are not processing events, we
                // switch to the thread pool for the continuations since the yield will only let operations at the
                // same or higher priority to execute prior to the continuation.
                var continueOnCapturedContext = eventProcessingAction is object;
                await Task.Delay(smallTimeout).ConfigureAwait(continueOnCapturedContext);
 
                if (startTime.Elapsed > timeout && timeout != Timeout.InfiniteTimeSpan)
                {
                    throw new TimeoutException();
                }
            } while (true);
        }
 
        foreach (var task in tasks)
        {
            if (task.Exception != null)
            {
                throw task.Exception;
            }
        }
    }
 
    public bool HasPendingWaiter(params string[] featureNames)
    {
        var waiters = GetCandidateWaiters(featureNames);
        return waiters.Any(w => w.HasPendingWork);
    }
 
    /// <summary>
    /// Get all saved DiagnosticAsyncToken to investigate tests failure easier
    /// </summary>
    public List<AsynchronousOperationListener.DiagnosticAsyncToken> GetTokens()
        => _singletonListeners.Values.Where(l => l.TrackActiveTokens).SelectMany(l => l.ActiveDiagnosticTokens).ToList();
 
    internal static bool IsEnabled
    {
        get
        {
            if (!s_enabled.HasValue)
            {
                // if s_enabled has never been set, check environment variable to see whether it should be enabled.
                var enabled = Environment.GetEnvironmentVariable("RoslynWaiterEnabled");
                s_enabled = string.Equals(enabled, "1", StringComparison.OrdinalIgnoreCase) || string.Equals(enabled, "True", StringComparison.OrdinalIgnoreCase);
            }
 
            return s_enabled.Value;
        }
    }
 
    internal bool DiagnosticTokensEnabled
    {
        get
        {
            if (!_enableDiagnosticTokens.HasValue)
            {
                if (s_enableDiagnosticTokens.HasValue)
                {
                    _enableDiagnosticTokens = s_enableDiagnosticTokens;
                }
                else
                {
                    // if _enableDiagnosticTokens has never been set, check environment variable to see whether it should be enabled.
                    var enabled = Environment.GetEnvironmentVariable("RoslynWaiterDiagnosticTokenEnabled");
                    _enableDiagnosticTokens = string.Equals(enabled, "1", StringComparison.OrdinalIgnoreCase) || string.Equals(enabled, "True", StringComparison.OrdinalIgnoreCase);
                }
            }
 
            return _enableDiagnosticTokens.Value;
        }
    }
 
    private IEnumerable<IAsynchronousOperationWaiter> GetCandidateWaiters(string[]? featureNames)
    {
        if (featureNames == null || featureNames.Length == 0)
        {
            return _singletonListeners.Values.Cast<IAsynchronousOperationWaiter>();
        }
 
        return _singletonListeners.Where(kv => featureNames.Contains(kv.Key)).Select(kv => (IAsynchronousOperationWaiter)kv.Value);
    }
}