File: src\Workspaces\SharedUtilitiesAndExtensions\Compiler\Core\Utilities\IDictionaryExtensions.cs
Web Access
Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.PooledObjects;
 
namespace Roslyn.Utilities;
 
/// <summary>
/// Extension helpers for <see cref="IDictionary{TKey, TValue}"/> (plus a few related dictionary types), mostly for
/// using a dictionary as a multi-map whose values are collections.
/// </summary>
internal static class IDictionaryExtensions
{
    /// <summary>
    /// Returns the value stored for <paramref name="key"/>. If there is none, computes one with
    /// <paramref name="function"/>, adds it to <paramref name="dictionary"/>, and returns it.
    /// </summary>
    /// <remarks>
    /// Copied from <c>ConcurrentDictionary.GetOrAdd</c> because <see cref="IDictionary{TKey, TValue}"/> has no
    /// equivalent. Unlike the concurrent version, this method performs no locking.
    /// </remarks>
    public static V GetOrAdd<K, V>(this IDictionary<K, V> dictionary, K key, Func<K, V> function)
        where K : notnull
    {
        if (!dictionary.TryGetValue(key, out var value))
        {
            value = function(key);
            dictionary.Add(key, value);
        }

        return value;
    }

    /// <summary>
    /// Like <see cref="GetOrAdd{K, V}(IDictionary{K, V}, K, Func{K, V})"/>, but passes <paramref name="arg"/>
    /// through to <paramref name="function"/> so the factory does not have to capture it.
    /// </summary>
    public static V GetOrAdd<K, V, TArg>(this IDictionary<K, V> dictionary, K key, Func<K, TArg, V> function, TArg arg)
        where K : notnull
    {
        if (!dictionary.TryGetValue(key, out var value))
        {
            value = function(key, arg);
            dictionary.Add(key, value);
        }

        return value;
    }

    /// <summary>
    /// Returns the value stored for <paramref name="key"/>, or <c>default(TValue)</c> when the key is absent.
    /// </summary>
    public static TValue? GetValueOrDefault<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key)
        where TKey : notnull
    {
        if (dictionary.TryGetValue(key, out var value))
        {
            return value;
        }

        // The return type is already TValue?, so no null-forgiving operator is needed.
        return default;
    }

    /// <summary>
    /// Returns the value stored for <paramref name="key"/>, or <paramref name="defaultValue"/> when the key is absent.
    /// </summary>
    [return: NotNullIfNotNull(nameof(defaultValue))]
    public static TValue? GetValueOrDefault<TKey, TValue>(this Dictionary<TKey, TValue> dictionary, TKey key, TValue? defaultValue)
        where TKey : notnull
    {
        if (dictionary.TryGetValue(key, out var value))
        {
            return value;
        }

        return defaultValue;
    }

    /// <summary>
    /// Adds <paramref name="value"/> to the collection stored under <paramref name="key"/>. On first use of the
    /// key, a new empty <typeparamref name="TCollection"/> is created and stored.
    /// </summary>
    public static void MultiAdd<TKey, TValue, TCollection>(this IDictionary<TKey, TCollection> dictionary, TKey key, TValue value)
        where TKey : notnull
        where TCollection : ICollection<TValue>, new()
    {
        if (!dictionary.TryGetValue(key, out var collection))
        {
            collection = new TCollection();
            dictionary.Add(key, collection);
        }

        collection.Add(value);
    }

    /// <summary>
    /// Adds <paramref name="value"/> to the builder stored under <paramref name="key"/>. On first use of the key,
    /// a builder is obtained from <c>ArrayBuilder&lt;TValue&gt;.GetInstance()</c> and stored.
    /// </summary>
    /// <remarks>NOTE(review): builders come from a pool; presumably the caller must free them when done — confirm.</remarks>
    public static void MultiAdd<TKey, TValue>(this IDictionary<TKey, ArrayBuilder<TValue>> dictionary, TKey key, TValue value)
        where TKey : notnull
    {
        if (!dictionary.TryGetValue(key, out var builder))
        {
            builder = ArrayBuilder<TValue>.GetInstance();
            dictionary.Add(key, builder);
        }

        builder.Add(value);
    }

    /// <summary>
    /// Adds every element of <paramref name="values"/> to the builder stored under <paramref name="key"/>. On first
    /// use of the key, a builder is obtained from <c>ArrayBuilder&lt;TValue&gt;.GetInstance()</c> and stored.
    /// </summary>
    /// <remarks>The builder is created and stored even when <paramref name="values"/> is empty.</remarks>
    public static void MultiAddRange<TKey, TValue>(this IDictionary<TKey, ArrayBuilder<TValue>> dictionary, TKey key, IEnumerable<TValue> values)
        where TKey : notnull
    {
        if (!dictionary.TryGetValue(key, out var builder))
        {
            builder = ArrayBuilder<TValue>.GetInstance();
            dictionary.Add(key, builder);
        }

        builder.AddRange(values);
    }

    /// <summary>
    /// Adds <paramref name="value"/> to the immutable set stored under <paramref name="key"/>, replacing the stored set.
    /// </summary>
    /// <param name="comparer">
    /// Used only when a new set is created for a key with no entry. An existing set keeps its own comparer.
    /// </param>
    /// <returns><see langword="true"/> if the set changed; <see langword="false"/> if it already contained the value.</returns>
    public static bool MultiAdd<TKey, TValue>(this IDictionary<TKey, ImmutableHashSet<TValue>> dictionary, TKey key, TValue value, IEqualityComparer<TValue>? comparer = null)
        where TKey : notnull
    {
        if (dictionary.TryGetValue(key, out var set))
        {
            var updated = set.Add(value);

            // ImmutableHashSet.Add returns the same instance when the value was already present.
            if (set == updated)
            {
                return false;
            }

            dictionary[key] = updated;
            return true;
        }
        else
        {
            dictionary[key] = ImmutableHashSet.Create(comparer, value);
            return true;
        }
    }

    /// <summary>
    /// Replaces the array stored under <paramref name="key"/> with a copy that has <paramref name="value"/>
    /// appended. A missing entry, or a stored <c>default</c> array, is treated as empty.
    /// </summary>
    /// <remarks>
    /// The <c>IEquatable</c> constraint is not used by this body. It is kept so that overload resolution for
    /// existing callers is unchanged.
    /// </remarks>
    public static void MultiAdd<TKey, TValue>(this IDictionary<TKey, ImmutableArray<TValue>> dictionary, TKey key, TValue value)
        where TKey : notnull
        where TValue : IEquatable<TValue>
    {
        // A default (uninitialized) ImmutableArray would throw on Add, so normalize it to empty.
        if (!dictionary.TryGetValue(key, out var existingArray) || existingArray.IsDefault)
        {
            existingArray = [];
        }

        dictionary[key] = existingArray.Add(value);
    }

    /// <summary>
    /// Replaces the array stored under <paramref name="key"/> with a copy that has <paramref name="value"/>
    /// appended. When there is no existing element and <paramref name="defaultArray"/> is exactly
    /// <c>[value]</c>, that array instance is stored instead of allocating a new one.
    /// </summary>
    /// <param name="defaultArray">
    /// A cached single-element array to reuse. It is ignored when it is default, empty, or has more than one
    /// element, because it would not represent <c>[value]</c> in those cases.
    /// </param>
    public static void MultiAdd<TKey, TValue>(this IDictionary<TKey, ImmutableArray<TValue>> dictionary, TKey key, TValue value, ImmutableArray<TValue> defaultArray)
        where TKey : notnull
        where TValue : IEquatable<TValue>
    {
        // A default (uninitialized) ImmutableArray would throw on IsEmpty/Add, so normalize it to empty.
        if (!dictionary.TryGetValue(key, out var existingArray) || existingArray.IsDefault)
        {
            existingArray = [];
        }

        // EqualityComparer<TValue>.Default tolerates a null 'value', unlike calling value.Equals(...) directly.
        var canReuseDefaultArray =
            existingArray.IsEmpty &&
            !defaultArray.IsDefault &&
            defaultArray.Length == 1 &&
            EqualityComparer<TValue>.Default.Equals(value, defaultArray[0]);

        dictionary[key] = canReuseDefaultArray ? defaultArray : existingArray.Add(value);
    }

    /// <summary>
    /// Removes <paramref name="value"/> from the collection stored under <paramref name="key"/>. Removes the key
    /// entirely once its collection becomes empty. Does nothing when the key is absent.
    /// </summary>
    public static void MultiRemove<TKey, TValue, TCollection>(this IDictionary<TKey, TCollection> dictionary, TKey key, TValue value)
        where TKey : notnull
        where TCollection : ICollection<TValue>
    {
        if (dictionary.TryGetValue(key, out var collection))
        {
            collection.Remove(value);

            if (collection.Count == 0)
            {
                dictionary.Remove(key);
            }
        }
    }

    /// <summary>
    /// Returns a dictionary in which <paramref name="value"/> has been removed from the set stored under
    /// <paramref name="key"/>. The key is dropped when its set becomes empty. The input dictionary is returned
    /// when the key is absent.
    /// </summary>
    public static ImmutableDictionary<TKey, ImmutableHashSet<TValue>> MultiRemove<TKey, TValue>(this ImmutableDictionary<TKey, ImmutableHashSet<TValue>> dictionary, TKey key, TValue value)
        where TKey : notnull
    {
        if (dictionary.TryGetValue(key, out var collection))
        {
            collection = collection.Remove(value);
            if (collection.IsEmpty)
            {
                return dictionary.Remove(key);
            }
            else
            {
                return dictionary.SetItem(key, collection);
            }
        }

        return dictionary;
    }

    /// <summary>
    /// Private implementation that the immutable-set <c>MultiRemove</c> overloads delegate to.
    /// It needs a distinct name because generic constraints are not part of a method's signature. A
    /// <c>MultiRemove&lt;TKey, TValue, TSet&gt;</c> would therefore duplicate
    /// <see cref="MultiRemove{TKey, TValue, TCollection}(IDictionary{TKey, TCollection}, TKey, TValue)"/>.
    /// </summary>
    private static void MultiRemoveSet<TKey, TValue, TSet>(this IDictionary<TKey, TSet> dictionary, TKey key, TValue value)
        where TKey : notnull
        where TSet : IImmutableSet<TValue>
    {
        if (!dictionary.TryGetValue(key, out var set))
        {
            return;
        }

        // IImmutableSet<T>.Remove is typed as returning the interface. The concrete sets routed here
        // (ImmutableHashSet / ImmutableSortedSet) return their own type, so the cast succeeds.
        var updated = (TSet)set.Remove(value);
        if (updated.Count == 0)
        {
            dictionary.Remove(key);
        }
        else if (!ReferenceEquals(updated, set))
        {
            // Only write back when the set actually changed.
            dictionary[key] = updated;
        }
    }

    /// <summary>
    /// Removes <paramref name="value"/> from the set stored under <paramref name="key"/>. Drops the key once its
    /// set becomes empty.
    /// </summary>
    public static void MultiRemove<TKey, TValue>(this IDictionary<TKey, ImmutableHashSet<TValue>> dictionary, TKey key, TValue value)
        where TKey : notnull
    {
        MultiRemoveSet(dictionary, key, value);
    }

    /// <summary>
    /// Removes <paramref name="value"/> from the sorted set stored under <paramref name="key"/>. Drops the key once
    /// its set becomes empty.
    /// </summary>
    public static void MultiRemove<TKey, TValue>(this IDictionary<TKey, ImmutableSortedSet<TValue>> dictionary, TKey key, TValue value)
        where TKey : notnull
    {
        MultiRemoveSet(dictionary, key, value);
    }

    /// <summary>
    /// Removes the first occurrence of <paramref name="value"/>, compared with
    /// <see cref="EqualityComparer{T}.Default"/>, from the array stored under <paramref name="key"/>. Drops the key
    /// once its array is empty, consistent with the other <c>MultiRemove</c> overloads.
    /// </summary>
    public static void MultiRemove<TKey, TValue>(this IDictionary<TKey, ImmutableArray<TValue>> dictionary, TKey key, TValue value)
        where TKey : notnull
    {
        if (!dictionary.TryGetValue(key, out var collection))
        {
            return;
        }

        // Fast path: removing the sole element drops the entry without building an empty array.
        if (collection.Length == 1 && EqualityComparer<TValue>.Default.Equals(collection[0], value))
        {
            dictionary.Remove(key);
            return;
        }

        var updated = collection.Remove(value);
        if (updated.IsEmpty)
        {
            // Only reachable when an empty array was stored; don't leave the stale entry behind.
            dictionary.Remove(key);
        }
        else if (updated != collection)
        {
            // ImmutableArray.Remove returns the same instance when the value is absent; skip the redundant write.
            dictionary[key] = updated;
        }
    }
}