File: Interop\CleanableWeakComHandleTable.cs
Web Access
Project: src\src\VisualStudio\Core\Def\Microsoft.VisualStudio.LanguageServices_pxr0p0dn_wpftmp.csproj (Microsoft.VisualStudio.LanguageServices)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Editor.Shared.Extensions;
using Microsoft.CodeAnalysis.Editor.Shared.Utilities;
using Microsoft.CodeAnalysis.Shared.TestHooks;
using Microsoft.VisualStudio.LanguageServices.Implementation.Utilities;
 
namespace Microsoft.VisualStudio.LanguageServices.Implementation.Interop;
 
/// <summary>
/// Special collection for storing a table of COM objects weakly that provides
/// logic for cleaning up dead references in a time-sliced way. Public members of this
/// collection are affinitized to the foreground thread.
/// </summary>
/// <summary>
/// Special collection for storing a table of COM objects weakly that provides
/// logic for cleaning up dead references in a time-sliced way. Public members of this
/// collection are affinitized to the foreground thread.
/// </summary>
/// <remarks>
/// Dead entries are not removed eagerly. <see cref="Add"/> raises <see cref="NeedsCleanUp"/> once every
/// <see cref="CleanUpThreshold"/> additions, and <see cref="CleanUpDeadObjectsAsync"/> then purges entries whose
/// handle reports that it is no longer alive.
/// </remarks>
internal sealed class CleanableWeakComHandleTable<TKey, TValue> where TValue : class
{
    private const int DefaultCleanUpThreshold = 25;
    private static readonly TimeSpan s_defaultCleanUpTimeSlice = TimeSpan.FromMilliseconds(15);

    private readonly Dictionary<TKey, WeakComHandle<TValue, TValue>> _table;

    // Keys found dead by a clean up pass but not yet removed from _table. This set is shared by interleaved
    // CleanUpDeadObjectsAsync calls. Every key in it is also present in _table, because both the clean up path
    // and Remove() drop a key from the two collections together.
    private readonly HashSet<TKey> _deadKeySet;
    private readonly IThreadingContext _threadingContext;

    /// <summary>
    /// The number of items that may be added to the collection before clean up is recommended
    /// (that is, before <see cref="NeedsCleanUp"/> becomes <see langword="true"/>).
    /// </summary>
    public int CleanUpThreshold { get; }

    /// <summary>
    /// The amount of time that clean up may run before it yields to other foreground operations.
    /// </summary>
    public TimeSpan CleanUpTimeSlice { get; }

    // Counts Add() calls since NeedsCleanUp was last raised; reset to 0 each time CleanUpThreshold is reached.
    private int _itemsAddedSinceLastCleanUp;

    /// <summary>
    /// Set to <see langword="true"/> by <see cref="Add"/> each time <see cref="CleanUpThreshold"/> items have been
    /// added. Reset to <see langword="false"/> when <see cref="CleanUpDeadObjectsAsync"/> begins a clean up.
    /// </summary>
    public bool NeedsCleanUp { get; private set; }

    /// <summary>
    /// Creates an empty table.
    /// </summary>
    /// <param name="threadingContext">Provides the foreground thread used for affinity checks and clean up.</param>
    /// <param name="cleanUpThreshold">
    /// Number of additions after which <see cref="NeedsCleanUp"/> is raised; defaults to 25.
    /// </param>
    /// <param name="cleanUpTimeSlice">
    /// Maximum time clean up runs before yielding; defaults to 15 milliseconds.
    /// </param>
    public CleanableWeakComHandleTable(IThreadingContext threadingContext, int? cleanUpThreshold = null, TimeSpan? cleanUpTimeSlice = null)
    {
        _table = [];
        _deadKeySet = [];
        _threadingContext = threadingContext;
        CleanUpThreshold = cleanUpThreshold ?? DefaultCleanUpThreshold;
        CleanUpTimeSlice = cleanUpTimeSlice ?? s_defaultCleanUpTimeSlice;
    }

    /// <summary>
    /// Cleans up references to dead objects in the table. This operation will yield to other foreground operations
    /// any time execution exceeds <see cref="CleanUpTimeSlice"/>.
    /// </summary>
    /// <param name="listener">Tracks this operation and supplies the delay used when yielding.</param>
    /// <remarks>
    /// Does nothing unless <see cref="NeedsCleanUp"/> is set. Multiple calls may interleave at their yield points;
    /// they cooperate through the shared dead-key set.
    /// </remarks>
    public async Task CleanUpDeadObjectsAsync(IAsynchronousOperationListener listener)
    {
        using var _ = listener.BeginAsyncOperation(nameof(CleanUpDeadObjectsAsync));

        Debug.Assert(_threadingContext.JoinableTaskContext.IsOnMainThread, "This method is optimized for cases where calls do not yield before checking NeedsCleanUp.");

        await _threadingContext.JoinableTaskFactory.SwitchToMainThreadAsync(_threadingContext.DisposalToken);

        if (!NeedsCleanUp)
        {
            return;
        }

        // Immediately mark as not needing cleanup; this operation will clean up the table by the time it returns.
        NeedsCleanUp = false;

        var timeSlice = new TimeSlice(CleanUpTimeSlice);

        await CollectDeadKeysAsync().ConfigureAwait(true);
        await RemoveDeadKeysAsync().ConfigureAwait(true);
        return;

        // Local functions
        async Task CollectDeadKeysAsync()
        {
            // Makes passes over _table, adding the keys of entries that are not alive to _deadKeySet. The method
            // returns after one complete pass that did not have to yield. Dead keys found during that final pass
            // stay in _deadKeySet for the caller's RemoveDeadKeysAsync.
            //
            // If the time slice has expired when a dead entry is found, the pass yields. It then removes the dead
            // keys collected so far and restarts with a fresh enumerator, because the table may have changed while
            // this method was yielded. The enumerator is never used across a yield.
            //
            // ⚠ This method may interleave with other asynchronous calls to CleanUpDeadObjectsAsync.
            var cleanUpEnumerator = _table.GetEnumerator();
            while (cleanUpEnumerator.MoveNext())
            {
                var pair = cleanUpEnumerator.Current;
                if (!pair.Value.IsAlive())
                {
                    _deadKeySet.Add(pair.Key);

                    if (timeSlice.IsOver)
                    {
                        // Yield before processing items found so far.
                        await ResetTimeSliceAsync().ConfigureAwait(true);

                        // Process items found prior to exceeding the time slice. Due to interleaving, it is
                        // possible for this call to process items found by another asynchronous call to
                        // CollectDeadKeysAsync, or for another asynchronous call to RemoveDeadKeysAsync to process
                        // all items prior to this call.
                        await RemoveDeadKeysAsync().ConfigureAwait(true);

                        // Obtain a new enumerator since the previous one may be invalidated.
                        cleanUpEnumerator = _table.GetEnumerator();
                    }
                }
            }
        }

        async Task RemoveDeadKeysAsync()
        {
            // Drains _deadKeySet, yielding whenever the time slice expires. Each key is removed from both
            // _deadKeySet and _table before any possible yield. This keeps the "dead keys are in _table"
            // invariant intact for interleaved callers.
            while (_deadKeySet.Count > 0)
            {
                // Fully process one item from _deadKeySet before the possibility of yielding
                var key = _deadKeySet.First();

                _deadKeySet.Remove(key);

                Debug.Assert(_table.ContainsKey(key), "Key not found in table.");
                _table.Remove(key);

                if (timeSlice.IsOver)
                {
                    await ResetTimeSliceAsync().ConfigureAwait(true);
                }
            }
        }

        async Task ResetTimeSliceAsync()
        {
            // Yield to other foreground work, then start a fresh time slice.
            await listener.Delay(DelayTimeSpan.NearImmediate, _threadingContext.DisposalToken).ConfigureAwait(true);
            timeSlice = new TimeSlice(CleanUpTimeSlice);
        }
    }

    /// <summary>
    /// Adds <paramref name="value"/> to the table under <paramref name="key"/>, holding it through a weak COM handle.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
    /// <exception cref="InvalidOperationException">
    /// The table already contains <paramref name="key"/>. This includes entries that are dead but not yet cleaned up.
    /// </exception>
    public void Add(TKey key, TValue value)
    {
        _threadingContext.ThrowIfNotOnUIThread();

        if (value == null)
        {
            throw new ArgumentNullException(nameof(value));
        }

        if (_table.ContainsKey(key))
        {
            throw new InvalidOperationException($"Key already exists in table: {key?.ToString() ?? "<null>"}.");
        }

        // Raise the clean up flag once every CleanUpThreshold additions.
        _itemsAddedSinceLastCleanUp++;
        if (_itemsAddedSinceLastCleanUp >= CleanUpThreshold)
        {
            NeedsCleanUp = true;
            _itemsAddedSinceLastCleanUp = 0;
        }

        _table.Add(key, new WeakComHandle<TValue, TValue>(value));
    }

    /// <summary>
    /// Removes <paramref name="key"/> from the table, including any pending dead-key bookkeeping for it.
    /// </summary>
    /// <returns>
    /// The handle's <c>ComAggregateObject</c> if the key was present (which may itself be <see langword="null"/>);
    /// otherwise <see langword="null"/>.
    /// </returns>
    public TValue Remove(TKey key)
    {
        _threadingContext.ThrowIfNotOnUIThread();

        // Keep _deadKeySet a subset of _table's keys for any in-flight clean up.
        _deadKeySet.Remove(key);

        if (_table.TryGetValue(key, out var handle))
        {
            _table.Remove(key);
            return handle.ComAggregateObject;
        }

        return null;
    }

    /// <summary>
    /// Returns whether <paramref name="key"/> has an entry in the table. This includes entries that are dead but
    /// not yet cleaned up, which <see cref="TryGetValue"/> would report as absent.
    /// </summary>
    public bool ContainsKey(TKey key)
    {
        _threadingContext.ThrowIfNotOnUIThread();

        return _table.ContainsKey(key);
    }

    /// <summary>
    /// Looks up the object stored under <paramref name="key"/>.
    /// </summary>
    /// <returns>
    /// <see langword="true"/> only if the key is present and its handle's <c>ComAggregateObject</c> is not
    /// <see langword="null"/>.
    /// </returns>
    public bool TryGetValue(TKey key, out TValue value)
    {
        _threadingContext.ThrowIfNotOnUIThread();

        if (_table.TryGetValue(key, out var handle))
        {
            value = handle.ComAggregateObject;
            return value != null;
        }

        value = null;
        return false;
    }

    /// <summary>
    /// Enumerates the <c>ComAggregateObject</c> of every entry currently in the table.
    /// </summary>
    /// <remarks>
    /// Entries whose <c>ComAggregateObject</c> is <see langword="null"/> (which <see cref="TryGetValue"/> treats as
    /// absent) are yielded as <see langword="null"/>. Like every other public member, enumeration must happen on
    /// the foreground thread.
    /// </remarks>
    public IEnumerable<TValue> Values
    {
        get
        {
            // Iterator bodies run lazily, so this check executes on the first MoveNext().
            // That is the point where _table is actually read.
            _threadingContext.ThrowIfNotOnUIThread();

            foreach (var keyValuePair in _table)
            {
                yield return keyValuePair.Value.ComAggregateObject;
            }
        }
    }
}