File: src\Workspaces\SharedUtilitiesAndExtensions\Compiler\Core\Utilities\EnumerableConditionalWeakTable.cs
Web Access
Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
 
namespace Roslyn.Utilities;
 
#if NET
// A global using alias can't be used here because the type has generic parameters; extension types would make that possible.
 
internal readonly struct EnumerableConditionalWeakTable<TKey, TValue>() : IEnumerable<KeyValuePair<TKey, TValue>>
    where TKey : class
    where TValue : class
{
    // On this target ConditionalWeakTable already supports enumeration, so every member simply forwards to it.
    private readonly ConditionalWeakTable<TKey, TValue> _table = new();

    /// <summary>
    /// Lock object intended for writers; it is the underlying table instance.
    /// None of the members of this struct acquire it themselves.
    /// </summary>
    public object WriteLock
    {
        get { return _table; }
    }

    /// <summary>
    /// Looks up the value associated with <paramref name="key"/>.
    /// </summary>
    public bool TryGetValue(TKey key, [NotNullWhen(true)] out TValue? value)
    {
        return _table.TryGetValue(key, out value);
    }

    /// <summary>
    /// Adds a new entry by delegating to <see cref="ConditionalWeakTable{TKey, TValue}.Add"/>.
    /// </summary>
    public void Add(TKey key, TValue value)
    {
        _table.Add(key, value);
    }

    /// <summary>
    /// Adds the entry, or replaces the value when <paramref name="key"/> is already present.
    /// </summary>
    public void AddOrUpdate(TKey key, TValue value)
    {
        _table.AddOrUpdate(key, value);
    }

    /// <summary>
    /// Removes the entry for <paramref name="key"/>; returns whether an entry was removed.
    /// </summary>
    public bool Remove(TKey key)
    {
        return _table.Remove(key);
    }

    /// <summary>
    /// Enumerates the entries currently held by the underlying table.
    /// </summary>
    public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
    {
        IEnumerable<KeyValuePair<TKey, TValue>> entries = _table;
        return entries.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
#else
internal sealed class EnumerableConditionalWeakTable<TKey, TValue> : IEnumerable<KeyValuePair<TKey, TValue>>
    where TKey : class
    where TValue : class
{
    /// <summary>
    /// Pairs a key with its value so that both can be recovered from a single weak reference during enumeration.
    /// Apart from the strong reference held as the value in <c>_table</c>, this class refers to boxes only
    /// through weak references.
    /// </summary>
    private sealed class Box(TKey key, TValue value)
    {
        public readonly TKey Key = key;
        public readonly TValue Value = value;
    }

    private readonly ConditionalWeakTable<TKey, Box> _table = new();

    // Weak references to the boxes stored in _table, in the order they were added (an update re-appends the key).
    // Used only for enumeration; references whose box has been collected are pruned during writes.
    private ImmutableList<WeakReference<Box>> _items = [];

    /// <summary>
    /// Object that <see cref="Add"/>, <see cref="AddOrUpdate"/> and <see cref="Remove"/> lock on.
    /// Callers that lock on it are serialized with those writes.
    /// </summary>
    public object WriteLock => _table;

    /// <summary>
    /// Looks up the value associated with <paramref name="key"/>. Does not take <see cref="WriteLock"/>.
    /// </summary>
    public bool TryGetValue(TKey key, [NotNullWhen(true)] out TValue? value)
    {
        if (_table.TryGetValue(key, out var box))
        {
            value = box.Value;
            return true;
        }

        value = null;
        return false;
    }

    /// <summary>
    /// Adds a new entry. The duplicate-key check is performed by <see cref="ConditionalWeakTable{TKey, TValue}.Add"/>
    /// before <c>_items</c> is modified.
    /// </summary>
    public void Add(TKey key, TValue value)
    {
        lock (WriteLock)
        {
            AddNoLock(key, value);
            PruneCollectedNoLock();
        }
    }

    /// <summary>
    /// Adds the entry, or replaces the existing value for <paramref name="key"/>.
    /// </summary>
    public void AddOrUpdate(TKey key, TValue value)
    {
        lock (WriteLock)
        {
            // RemoveNoLock prunes collected entries only when it finds the key. Prune explicitly otherwise
            // (as Add does); without this, repeated AddOrUpdate calls with new keys would grow _items without bound.
            if (!RemoveNoLock(key))
            {
                PruneCollectedNoLock();
            }

            AddNoLock(key, value);
        }
    }

    /// <summary>
    /// Removes the entry for <paramref name="key"/>; returns whether an entry was removed.
    /// </summary>
    public bool Remove(TKey key)
    {
        lock (WriteLock)
        {
            return RemoveNoLock(key);
        }
    }

    /// <summary>
    /// Stores a new box for the key in the table and appends a weak reference to it. Caller must hold <see cref="WriteLock"/>.
    /// </summary>
    private void AddNoLock(TKey key, TValue value)
    {
        var box = new Box(key, value);
        _table.Add(key, box);
        _items = _items.Add(new WeakReference<Box>(box));
    }

    /// <summary>
    /// Removes the key from the table and, in the same pass over <c>_items</c>, drops its weak reference together
    /// with any references whose box has been collected. Returns false, without touching <c>_items</c>, when the key
    /// is absent. Caller must hold <see cref="WriteLock"/>.
    /// </summary>
    private bool RemoveNoLock(TKey key)
    {
        if (!_table.TryGetValue(key, out var box))
        {
            return false;
        }

        Contract.ThrowIfFalse(_table.Remove(key));
        _items = _items.RemoveAll(item => !item.TryGetTarget(out var target) || ReferenceEquals(target, box));
        return true;
    }

    /// <summary>
    /// Drops the weak references in <c>_items</c> whose box has been collected. Caller must hold <see cref="WriteLock"/>.
    /// </summary>
    private void PruneCollectedNoLock()
        => _items = _items.RemoveAll(WeakReferenceExtensions.IsNull);

    /// <summary>
    /// Enumerates live entries from the immutable <c>_items</c> list captured when enumeration starts,
    /// skipping references whose box has been collected. Does not take <see cref="WriteLock"/>.
    /// </summary>
    public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
    {
        foreach (var item in _items)
        {
            if (item.TryGetTarget(out var box))
            {
                yield return KeyValuePairUtil.Create(box.Key, box.Value);
            }
        }
    }

    IEnumerator IEnumerable.GetEnumerator()
        => GetEnumerator();
}
#endif