// File: src\Workspaces\SharedUtilitiesAndExtensions\Compiler\Core\Utilities\BidirectionalMap.cs
// Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
 
namespace Roslyn.Utilities;
 
/// <summary>
/// An immutable one-to-one map that supports lookup in both directions: from a key to its value, and from a value
/// back to its key. Operations that "modify" the map return a new instance and leave the current one unchanged.
/// </summary>
/// <typeparam name="TKey">The type of the keys in the forward direction.</typeparam>
/// <typeparam name="TValue">The type of the values in the forward direction (the keys of the backward direction).</typeparam>
internal sealed class BidirectionalMap<TKey, TValue> : IBidirectionalMap<TKey, TValue>
    where TKey : notnull
    where TValue : notnull
{
    /// <summary>A map with no entries that uses the default equality comparers in both directions.</summary>
    public static readonly IBidirectionalMap<TKey, TValue> Empty =
        new BidirectionalMap<TKey, TValue>(ImmutableDictionary.Create<TKey, TValue>(), ImmutableDictionary.Create<TValue, TKey>());

    private readonly ImmutableDictionary<TKey, TValue> _forwardMap;
    private readonly ImmutableDictionary<TValue, TKey> _backwardMap;

    /// <summary>
    /// Creates a map containing the given key/value pairs.
    /// </summary>
    /// <param name="pairs">The entries of the map. The sequence is enumerated exactly once.</param>
    /// <param name="keyComparer">Comparer used for keys, or <see langword="null"/> for the default comparer.</param>
    /// <param name="valueComparer">Comparer used for values, or <see langword="null"/> for the default comparer.</param>
    /// <exception cref="System.ArgumentException">
    /// Thrown by the underlying immutable dictionaries when the pairs map one key to two different values, or one
    /// value to two different keys.
    /// </exception>
    public BidirectionalMap(IEnumerable<KeyValuePair<TKey, TValue>> pairs, IEqualityComparer<TKey>? keyComparer = null, IEqualityComparer<TValue>? valueComparer = null)
    {
        // Enumerate 'pairs' only once: it may be lazy, expensive, or single-use. The materialized forward map is then
        // the source for the backward map, which also keeps the two directions consistent with each other.
        _forwardMap = ImmutableDictionary.CreateRange(keyComparer, pairs);
        _backwardMap = CreateBackwardMap(_forwardMap, valueComparer);
    }

    /// <summary>
    /// Creates a map containing the given key/value tuples.
    /// </summary>
    /// <param name="pairs">The entries of the map. The sequence is enumerated exactly once.</param>
    /// <param name="keyComparer">Comparer used for keys, or <see langword="null"/> for the default comparer.</param>
    /// <param name="valueComparer">Comparer used for values, or <see langword="null"/> for the default comparer.</param>
    public BidirectionalMap(IEnumerable<(TKey key, TValue value)> pairs, IEqualityComparer<TKey>? keyComparer = null, IEqualityComparer<TValue>? valueComparer = null)
        : this(pairs.Select(static p => KeyValuePairUtil.Create(p.key, p.value)), keyComparer, valueComparer)
    {
    }

    private BidirectionalMap(ImmutableDictionary<TKey, TValue> forwardMap, ImmutableDictionary<TValue, TKey> backwardMap)
    {
        _forwardMap = forwardMap;
        _backwardMap = backwardMap;
    }

    /// <summary>Builds the value-to-key dictionary that mirrors <paramref name="forwardMap"/>.</summary>
    private static ImmutableDictionary<TValue, TKey> CreateBackwardMap(ImmutableDictionary<TKey, TValue> forwardMap, IEqualityComparer<TValue>? valueComparer)
        => ImmutableDictionary.CreateRange(valueComparer, forwardMap.Select(static p => KeyValuePairUtil.Create(p.Value, p.Key)));

    /// <summary>Looks up the value associated with <paramref name="key"/>.</summary>
    public bool TryGetValue(TKey key, [NotNullWhen(true)] out TValue? value)
        => _forwardMap.TryGetValue(key, out value);

    /// <summary>Looks up the key associated with <paramref name="value"/>.</summary>
    public bool TryGetKey(TValue value, [NotNullWhen(true)] out TKey? key)
        => _backwardMap.TryGetValue(value, out key);

    /// <summary>Returns whether <paramref name="key"/> is present as a key.</summary>
    public bool ContainsKey(TKey key)
        => _forwardMap.ContainsKey(key);

    /// <summary>Returns whether <paramref name="value"/> is present as a value.</summary>
    public bool ContainsValue(TValue value)
        => _backwardMap.ContainsKey(value);

    /// <summary>
    /// Returns a map without <paramref name="key"/> and its associated value, or this instance if the key is absent.
    /// </summary>
    public IBidirectionalMap<TKey, TValue> RemoveKey(TKey key)
    {
        if (!_forwardMap.TryGetValue(key, out var value))
        {
            return this;
        }

        // Remove the entry from both directions so the maps stay in sync.
        return new BidirectionalMap<TKey, TValue>(
            _forwardMap.Remove(key),
            _backwardMap.Remove(value));
    }

    /// <summary>
    /// Returns a map without <paramref name="value"/> and its associated key, or this instance if the value is absent.
    /// </summary>
    public IBidirectionalMap<TKey, TValue> RemoveValue(TValue value)
    {
        if (!_backwardMap.TryGetValue(value, out var key))
        {
            return this;
        }

        // Remove the entry from both directions so the maps stay in sync.
        return new BidirectionalMap<TKey, TValue>(
            _forwardMap.Remove(key),
            _backwardMap.Remove(value));
    }

    /// <summary>
    /// Returns a map with the additional pair <paramref name="key"/> ↔ <paramref name="value"/>.
    /// </summary>
    /// <exception cref="System.ArgumentException">
    /// Thrown by the underlying immutable dictionaries when <paramref name="key"/> is already mapped to a different
    /// value, or <paramref name="value"/> is already mapped to a different key. This instance is never modified.
    /// </exception>
    public IBidirectionalMap<TKey, TValue> Add(TKey key, TValue value)
        => new BidirectionalMap<TKey, TValue>(
            _forwardMap.Add(key, value),
            _backwardMap.Add(value, key));

    /// <summary>All keys of the map.</summary>
    public IEnumerable<TKey> Keys => _forwardMap.Keys;

    /// <summary>All values of the map.</summary>
    public IEnumerable<TValue> Values => _backwardMap.Keys;

    /// <summary>Whether the map has no entries.</summary>
    public bool IsEmpty => _backwardMap.IsEmpty;

    /// <summary>The number of key/value pairs in the map.</summary>
    public int Count
    {
        get
        {
            // Both directions must always hold the same number of entries.
            Debug.Assert(_forwardMap.Count == _backwardMap.Count);
            return _backwardMap.Count;
        }
    }

    /// <summary>Returns the value for <paramref name="key"/>, or <see langword="default"/> if the key is absent.</summary>
    public TValue? GetValueOrDefault(TKey key)
        => TryGetValue(key, out var result) ? result : default;

    /// <summary>Returns the key for <paramref name="value"/>, or <see langword="default"/> if the value is absent.</summary>
    public TKey? GetKeyOrDefault(TValue value)
        => TryGetKey(value, out var result) ? result : default;

    /// <summary>Gets the value associated with <paramref name="key"/>.</summary>
    /// <exception cref="KeyNotFoundException">Thrown when <paramref name="key"/> is not present.</exception>
    public TValue this[TKey key]
    {
        get
        {
            if (TryGetValue(key, out var result))
            {
                return result;
            }

            throw new KeyNotFoundException();
        }
    }

    /// <summary>Gets the key associated with <paramref name="value"/>.</summary>
    /// <exception cref="KeyNotFoundException">Thrown when <paramref name="value"/> is not present.</exception>
    public TKey this[TValue value]
    {
        get
        {
            if (TryGetKey(value, out var result))
            {
                return result;
            }

            throw new KeyNotFoundException();
        }
    }
}