// File: Collections\SmallDictionary.cs
// Web Access
// Project: src\src\Compilers\Core\Portable\Microsoft.CodeAnalysis.csproj (Microsoft.CodeAnalysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis
{
    /// <summary>
    /// Dictionary designed to hold a small number of items.
    /// Compared to the regular Dictionary, the average per-item overhead is roughly the same, but
    /// unlike the regular Dictionary, this one is based on an AVL tree and as such does not require
    /// rehashing when items are added.
    /// It does require rebalancing, but that is allocation-free.
    ///
    /// Major caveats:
    ///  1) There is no Remove method. (It can be added, but we do not seem to use Remove that much.)
    ///  2) foreach [keys|values|pairs] may allocate a small array.
    ///  3) Performance is no longer O(1). At a certain count it becomes slower than the regular Dictionary.
    ///     In comparison to the regular Dictionary on my machine:
    ///        With a trivial number of elements (5 or so) it is more than 2x faster.
    ///        The break-even count is about 120 elements for read and 55 for write operations (with unknown initial size).
    ///        At ushort.MaxValue elements, this dictionary is 6x slower to read and 4x slower to write.
    ///
    /// Generally, this dictionary is a win if the number of elements is small, not known beforehand, or both.
    ///
    /// If the size of the dictionary is known at creation and it is likely to contain more than 10 elements,
    /// then the regular Dictionary is a better choice.
    /// </summary>
    internal sealed class SmallDictionary<K, V> : IEnumerable<KeyValuePair<K, V>>
        where K : notnull
    {
        // Root of the AVL tree; null while the dictionary holds no items.
        private AvlNode? _root;
        // Supplies both key equality and the hash codes by which tree nodes are ordered.
        public readonly IEqualityComparer<K> Comparer;
 
        // https://github.com/dotnet/roslyn/issues/40344
        // Shared empty instance. It is built with a null comparer; lookups on it return
        // before the comparer is touched because _root is null.
        public static readonly SmallDictionary<K, V> Empty = new SmallDictionary<K, V>(null!);
 
        /// <summary>
        /// Creates an empty dictionary that uses the default equality comparer for <typeparamref name="K"/>.
        /// </summary>
        public SmallDictionary()
            : this(EqualityComparer<K>.Default)
        {
        }
 
        /// <summary>
        /// Creates an empty dictionary whose keys are hashed and compared with <paramref name="comparer"/>.
        /// </summary>
        /// <param name="comparer">Stored as-is in <see cref="Comparer"/>; it is not validated here.</param>
        public SmallDictionary(IEqualityComparer<K> comparer) => Comparer = comparer;
 
        /// <summary>
        /// Creates a dictionary that uses <paramref name="comparer"/> and is populated with every
        /// entry of <paramref name="other"/>.
        /// </summary>
        /// <param name="other">Dictionary whose entries are copied one by one through <see cref="Add"/>.</param>
        /// <param name="comparer">Comparer for the new dictionary.</param>
        public SmallDictionary(SmallDictionary<K, V> other, IEqualityComparer<K> comparer)
            : this(comparer)
        {
            // TODO: if comparers are same (often they are), then just need to clone the tree.
            var source = other.GetEnumerator();
            while (source.MoveNext())
            {
                var pair = source.Current;
                Add(pair.Key, pair.Value);
            }
        }
 
        /// <summary>Returns whether the two keys are equal according to <see cref="Comparer"/>.</summary>
        private bool CompareKeys(K k1, K k2) => Comparer.Equals(k1, k2);
 
        /// <summary>Returns the hash code that <see cref="Comparer"/> assigns to <paramref name="k"/>.</summary>
        private int GetHashCode(K k) => Comparer.GetHashCode(k);
 
        /// <summary>
        /// Looks up the value stored for <paramref name="key"/>.
        /// </summary>
        /// <param name="key">Key to search for.</param>
        /// <param name="value">The stored value when found; otherwise the default of <typeparamref name="V"/>.</param>
        /// <returns><c>true</c> if the key is present; otherwise <c>false</c>.</returns>
        public bool TryGetValue(K key, [MaybeNullWhen(returnValue: false)] out V value)
        {
            // An empty tree is answered without consulting the comparer at all.
            if (_root == null)
            {
                value = default!;
                return false;
            }

            var hash = GetHashCode(key);
            return TryGetValue(hash, key, out value!);
        }
 
        /// <summary>
        /// Adds a new entry. Goes through <c>Insert</c> with <c>add: true</c>, which throws
        /// <see cref="InvalidOperationException"/> when an equal key is already present.
        /// </summary>
        public void Add(K key, V value) => Insert(GetHashCode(key), key, value, add: true);
 
        /// <summary>
        /// Gets or sets the value associated with <paramref name="key"/>.
        /// The getter throws <see cref="KeyNotFoundException"/> when the key is absent;
        /// the setter adds the key or overwrites the value of an existing equal key.
        /// </summary>
        public V this[K key]
        {
            get
            {
                V value;
                if (TryGetValue(key, out value!))
                {
                    return value;
                }

                throw new KeyNotFoundException($"Could not find key {key}");
            }

            set => Insert(GetHashCode(key), key, value, add: false);
        }
 
        /// <summary>Returns whether an entry with a key equal to <paramref name="key"/> exists.</summary>
        public bool ContainsKey(K key) => TryGetValue(key, out _);
 
        /// <summary>
        /// Debug-only consistency check of the whole tree. Delegates to
        /// <c>AvlNode.AssertBalanced</c>, which throws <see cref="InvalidOperationException"/>
        /// when a node's stored balance disagrees with its subtree heights or is out of range.
        /// </summary>
        [Conditional("DEBUG")]
        internal void AssertBalanced()
        {
            // AvlNode.AssertBalanced itself exists only under DEBUG, hence the guard around the call.
#if DEBUG
            AvlNode.AssertBalanced(_root);
#endif
        }
 
        /// <summary>
        /// Base of every entry: an immutable key, a mutable value, and an optional link to the
        /// next entry, which is <c>null</c> unless a subclass overrides <see cref="Next"/>.
        /// </summary>
        private abstract class Node
        {
            public readonly K Key;
            public V Value;

            protected Node(K key, V value)
            {
                Key = key;
                Value = value;
            }

            public virtual Node? Next
            {
                get { return null; }
            }
        }
 
        /// <summary>
        /// List-only entry (not part of the tree) that always has a successor, fixed at construction.
        /// </summary>
        private sealed class NodeLinked : Node
        {
            public override Node Next { get; }

            public NodeLinked(K key, V value, Node next)
                : base(key, value)
            {
                Next = next;
            }
        }
 
        /// <summary>
        /// Tree node that also heads a linked list of further entries; the list link is
        /// mutable so that new entries can be spliced in directly behind the head.
        /// </summary>
        private sealed class AvlNodeHead : AvlNode
        {
            public Node next;

            public AvlNodeHead(int hashCode, K key, V value, Node next)
                : base(hashCode, key, value)
            {
                this.next = next;
            }

            public override Node Next
            {
                get { return next; }
            }
        }
 
        // separate class to ensure that HashCode field 
        // is located before other AvlNode fields
        // Balance is also here for better packing of AvlNode on 64bit
        /// <summary>
        /// Entry that additionally carries the key's hash code and the AVL balance factor.
        /// </summary>
        private abstract class HashedNode : Node
        {
            // Hash code of the key; fixed at construction.
            public readonly int HashCode;
            // AVL balance factor; the debug check in AvlNode.AssertBalanced expects it to equal
            // height(left) - height(right).
            public sbyte Balance;
 
            protected HashedNode(int hashCode, K key, V value)
                : base(key, value)
            {
                this.HashCode = hashCode;
            }
        }
 
        /// <summary>
        /// Tree node with left and right children.
        /// </summary>
        private class AvlNode : HashedNode
        {
            public AvlNode? Left;
            public AvlNode? Right;

            public AvlNode(int hashCode, K key, V value)
                : base(hashCode, key, value)
            {
            }

#if DEBUG
            /// <summary>
            /// Recursively checks that every stored balance equals height(left) - height(right)
            /// and has magnitude below 2; throws <see cref="InvalidOperationException"/> otherwise.
            /// </summary>
            /// <returns>Height of the subtree rooted at the given node (0 for <c>null</c>).</returns>
            public static int AssertBalanced(AvlNode? V)
            {
                if (V == null)
                {
                    return 0;
                }

                int leftHeight = AssertBalanced(V.Left);
                int rightHeight = AssertBalanced(V.Right);
                int skew = leftHeight - rightHeight;

                bool consistent = skew == V.Balance && Math.Abs(skew) < 2;
                if (!consistent)
                {
                    throw new InvalidOperationException();
                }

                return 1 + Math.Max(leftHeight, rightHeight);
            }
#endif
        }
 
        /// <summary>
        /// Descends the tree by hash code and, once a node with a matching hash is found,
        /// compares the key against that node and then against its linked entries.
        /// The caller must ensure the tree is not empty.
        /// </summary>
        private bool TryGetValue(int hashCode, K key, [MaybeNullWhen(returnValue: false)] out V value)
        {
            RoslynDebug.Assert(_root is object);
            AvlNode? node = _root;

            do
            {
                int nodeHash = node.HashCode;
                if (nodeHash == hashCode)
                {
                    // Matching hash: the entry is either this node or somewhere in its list.
                    if (CompareKeys(node.Key, key))
                    {
                        value = node.Value;
                        return true;
                    }

                    return GetFromList(node.Next, key, out value!);
                }

                // Larger node hashes send the search left, smaller ones right.
                node = nodeHash > hashCode ? node.Left : node.Right;
            } while (node != null);

            value = default!;
            return false;
        }
 
        /// <summary>
        /// Walks the linked entries starting at <paramref name="next"/> looking for a key equal to
        /// <paramref name="key"/>.
        /// </summary>
        /// <returns><c>true</c> and the stored value when found; otherwise <c>false</c> and the default value.</returns>
        private bool GetFromList(Node? next, K key, [MaybeNullWhen(returnValue: false)] out V value)
        {
            for (var entry = next; entry != null; entry = entry.Next)
            {
                if (CompareKeys(key, entry.Key))
                {
                    value = entry.Value;
                    return true;
                }
            }

            value = default!;
            return false;
        }
 
        /// <summary>
        /// Inserts an entry into the tree, rebalancing with at most one rotation.
        /// When a node with the same hash code already exists, the work is handed to
        /// <c>HandleInsert</c>, which either updates or rejects an equal key (depending on
        /// <paramref name="add"/>) or attaches the entry to that node's list.
        /// </summary>
        /// <param name="hashCode">Hash code of <paramref name="key"/>; determines the position in the tree.</param>
        /// <param name="key">Key of the entry.</param>
        /// <param name="value">Value of the entry.</param>
        /// <param name="add">Passed through to <c>HandleInsert</c> for the equal-hash case.</param>
        private void Insert(int hashCode, K key, V value, bool add)
        {
            AvlNode? currentNode = _root;
 
            // empty tree: the new node simply becomes the root
            if (currentNode == null)
            {
                _root = new AvlNode(hashCode, key, value);
                return;
            }
 
            AvlNode? currentNodeParent = null;
            AvlNode unbalanced = currentNode;
            AvlNode? unbalancedParent = null;
 
            // ====== insert new node
            // also make a note of the last unbalanced node and its parent (for rotation if needed)
            // nodes on the search path from rotation candidate downwards will change balances because of the node added
            // unbalanced node itself will become balanced or will be rotated
            // either way nodes above unbalanced do not change their balance
            for (; ; )
            {
                // read the node's hash code once up front; it is compared against hashCode below
                var hc = currentNode.HashCode;
 
                // remember the deepest node on the path with a non-zero balance
                if (currentNode.Balance != 0)
                {
                    unbalancedParent = currentNodeParent;
                    unbalanced = currentNode;
                }
 
                if (hc > hashCode)
                {
                    if (currentNode.Left == null)
                    {
                        // free slot on the left: attach the new node; currentNode now refers to it
                        var previousNode = currentNode;
                        currentNode = new AvlNode(hashCode, key, value);
                        previousNode.Left = currentNode;
                        break;
                    }
 
                    currentNodeParent = currentNode;
                    currentNode = currentNode.Left;
                }
                else if (hc < hashCode)
                {
                    if (currentNode.Right == null)
                    {
                        // free slot on the right: attach the new node; currentNode now refers to it
                        var previousNode = currentNode;
                        currentNode = new AvlNode(hashCode, key, value);
                        previousNode.Right = currentNode;
                        break;
                    }
 
                    currentNodeParent = currentNode;
                    currentNode = currentNode.Right;
                }
                else // (currentNode.HashCode == hashCode)
                {
                    // same hash: no new tree node is created, so no rebalancing follows
                    this.HandleInsert(currentNode, currentNodeParent, key, value, add);
                    return;
                }
            }
 
            // currentNode is the freshly inserted node from here on
            Debug.Assert(unbalanced != currentNode);
 
            // ====== update balances on the path from unbalanced downwards
            var n = unbalanced;
            do
            {
                Debug.Assert(n.HashCode != hashCode);
 
                // stepping right decrements the balance, stepping left increments it
                if (n.HashCode < hashCode)
                {
                    n.Balance--;
                    n = n.Right!;
                }
                else
                {
                    n.Balance++;
                    n = n.Left!;
                }
            }
            while (n != currentNode);
 
            // ====== rotate unbalanced node if needed
            AvlNode rotated;
            var balance = unbalanced.Balance;
            if (balance == -2)
            {
                // too heavy on the right: single rotation if the right child leans right too, double otherwise
                rotated = unbalanced.Right!.Balance < 0 ?
                    LeftSimple(unbalanced) :
                    LeftComplex(unbalanced);
            }
            else if (balance == 2)
            {
                // too heavy on the left: single rotation if the left child leans left too, double otherwise
                rotated = unbalanced.Left!.Balance > 0 ?
                    RightSimple(unbalanced) :
                    RightComplex(unbalanced);
            }
            else
            {
                // balance stayed within range; nothing to rotate
                return;
            }
 
            // ===== make parent to point to rotated
            if (unbalancedParent == null)
            {
                _root = rotated;
            }
            else if (unbalanced == unbalancedParent.Left)
            {
                unbalancedParent.Left = rotated;
            }
            else
            {
                unbalancedParent.Right = rotated;
            }
        }
 
        /// <summary>
        /// Single left rotation: the right child of <paramref name="unbalanced"/> becomes the
        /// subtree root and <paramref name="unbalanced"/> becomes its left child. Both end with balance 0.
        /// </summary>
        /// <returns>The new subtree root; re-linking it to the parent is left to the caller.</returns>
        private static AvlNode LeftSimple(AvlNode unbalanced)
        {
            RoslynDebug.Assert(unbalanced.Right is object);

            var pivot = unbalanced.Right;

            // the pivot's former left subtree moves under the demoted node
            var handedDown = pivot.Left;
            pivot.Left = unbalanced;
            unbalanced.Right = handedDown;

            pivot.Balance = 0;
            unbalanced.Balance = 0;
            return pivot;
        }
 
        /// <summary>
        /// Single right rotation: the left child of <paramref name="unbalanced"/> becomes the
        /// subtree root and <paramref name="unbalanced"/> becomes its right child. Both end with balance 0.
        /// </summary>
        /// <returns>The new subtree root; re-linking it to the parent is left to the caller.</returns>
        private static AvlNode RightSimple(AvlNode unbalanced)
        {
            RoslynDebug.Assert(unbalanced.Left is object);

            var pivot = unbalanced.Left;

            // the pivot's former right subtree moves under the demoted node
            var handedDown = pivot.Right;
            pivot.Right = unbalanced;
            unbalanced.Left = handedDown;

            pivot.Balance = 0;
            unbalanced.Balance = 0;
            return pivot;
        }
 
        /// <summary>
        /// Double rotation used when <paramref name="unbalanced"/> is too heavy on the right but its
        /// right child does not lean right: the right child's left child becomes the subtree root,
        /// with <paramref name="unbalanced"/> as its left child and the former right child as its right child.
        /// </summary>
        /// <returns>The new subtree root; re-linking it to the parent is left to the caller.</returns>
        private static AvlNode LeftComplex(AvlNode unbalanced)
        {
            RoslynDebug.Assert(unbalanced.Right is object);
            RoslynDebug.Assert(unbalanced.Right.Left is object);
            var right = unbalanced.Right;
            var rightLeft = right.Left;
            // rightLeft's right subtree goes to right, its left subtree goes to unbalanced
            right.Left = rightLeft.Right;
            rightLeft.Right = right;
            unbalanced.Right = rightLeft.Left;
            rightLeft.Left = unbalanced;
 
            // the balances of the two demoted nodes depend on which way rightLeft was leaning
            var rightLeftBalance = rightLeft.Balance;
            rightLeft.Balance = 0;
 
            if (rightLeftBalance < 0)
            {
                // rightLeft leaned right: unbalanced received its shorter left subtree
                right.Balance = 0;
                unbalanced.Balance = 1;
            }
            else
            {
                // rightLeft was even (0) or leaned left (1): right gets 0 or -1 respectively
                right.Balance = (sbyte)-rightLeftBalance;
                unbalanced.Balance = 0;
            }
 
            return rightLeft;
        }
 
        /// <summary>
        /// Double rotation used when <paramref name="unbalanced"/> is too heavy on the left but its
        /// left child does not lean left: the left child's right child becomes the subtree root,
        /// with the former left child as its left child and <paramref name="unbalanced"/> as its right child.
        /// </summary>
        /// <returns>The new subtree root; re-linking it to the parent is left to the caller.</returns>
        private static AvlNode RightComplex(AvlNode unbalanced)
        {
            RoslynDebug.Assert(unbalanced.Left != null);
            RoslynDebug.Assert(unbalanced.Left.Right != null);
            var left = unbalanced.Left;
            var leftRight = left.Right;
            // leftRight's left subtree goes to left, its right subtree goes to unbalanced
            left.Right = leftRight.Left;
            leftRight.Left = left;
            unbalanced.Left = leftRight.Right;
            leftRight.Right = unbalanced;
 
            // the balances of the two demoted nodes depend on which way leftRight was leaning
            var leftRightBalance = leftRight.Balance;
            leftRight.Balance = 0;
 
            if (leftRightBalance < 0)
            {
                // leftRight leaned right: left received its shorter left subtree
                left.Balance = 1;
                unbalanced.Balance = 0;
            }
            else
            {
                // leftRight was even (0) or leaned left (1): unbalanced gets 0 or -1 respectively
                left.Balance = 0;
                unbalanced.Balance = (sbyte)-leftRightBalance;
            }
 
            return leftRight;
        }
 
        /// <summary>
        /// Handles an insert whose hash code matches that of <paramref name="node"/>.
        /// If an equal key is found on the node or in its list, the value is overwritten, or
        /// <see cref="InvalidOperationException"/> is thrown when <paramref name="add"/> is <c>true</c>.
        /// Otherwise the entry is attached through <c>AddNode</c>.
        /// </summary>
        private void HandleInsert(AvlNode node, AvlNode? parent, K key, V value, bool add)
        {
            for (Node? entry = node; entry != null; entry = entry.Next)
            {
                if (!CompareKeys(entry.Key, key))
                {
                    continue;
                }

                if (add)
                {
                    throw new InvalidOperationException();
                }

                entry.Value = value;
                return;
            }

            // no equal key among the entries sharing this hash code
            AddNode(node, parent, key, value);
        }
 
        /// <summary>
        /// Attaches a new entry to a tree node that has the same hash code.
        /// A node that already heads a list gets the entry spliced in right behind it; a plain node
        /// is replaced in the tree by a new list head that holds the new entry and links to the old node.
        /// </summary>
        /// <param name="node">Existing tree node with the same hash code.</param>
        /// <param name="parent">Parent of <paramref name="node"/>, or <c>null</c> when it is the root.</param>
        private void AddNode(AvlNode node, AvlNode? parent, K key, V value)
        {
            if (node is AvlNodeHead head)
            {
                head.next = new NodeLinked(key, value, head.next);
                return;
            }

            // The replacement takes over the old node's place: same hash, balance and children.
            var replacement = new AvlNodeHead(node.HashCode, key, value, node)
            {
                Balance = node.Balance,
                Left = node.Left,
                Right = node.Right,
            };

            if (parent == null)
            {
                _root = replacement;
            }
            else if (node == parent.Left)
            {
                parent.Left = replacement;
            }
            else
            {
                parent.Right = replacement;
            }
        }
 
        /// <summary>Returns a view over the keys of this dictionary; a new wrapper is created on each access.</summary>
        public KeyCollection Keys
        {
            get { return new KeyCollection(this); }
        }
 
        /// <summary>
        /// Enumerable view over the keys of a <see cref="SmallDictionary{K, V}"/>.
        /// </summary>
        internal readonly struct KeyCollection : IEnumerable<K>
        {
            private readonly SmallDictionary<K, V> _dict;

            public KeyCollection(SmallDictionary<K, V> dict)
            {
                _dict = dict;
            }

            /// <summary>
            /// Struct enumerator over the keys. Each tree node is followed by the entries linked
            /// to it before the traversal moves on to another tree node.
            /// </summary>
            public struct Enumerator
            {
                // Pending tree nodes; stays null when the tree is empty or consists of a single node.
                private readonly Stack<AvlNode>? _stack;
                // Next entry to yield from the list of the most recently visited tree node, if any.
                private Node? _next;
                private Node? _current;

                public Enumerator(SmallDictionary<K, V> dict)
                    : this()
                {
                    var root = dict._root;
                    if (root != null)
                    {
                        // left == right only if both are nulls
                        if (root.Left == root.Right)
                        {
                            // single tree node: no stack needed, just walk its list
                            _next = root;
                        }
                        else
                        {
                            _stack = new Stack<AvlNode>(dict.HeightApprox());
                            _stack.Push(root);
                        }
                    }
                }

                public K Current => _current!.Key;

                public bool MoveNext()
                {
                    // first drain the linked entries of the last visited tree node
                    if (_next != null)
                    {
                        _current = _next;
                        _next = _next.Next;
                        return true;
                    }

                    if (_stack == null || _stack.Count == 0)
                    {
                        return false;
                    }

                    var curr = _stack.Pop();
                    _current = curr;
                    _next = curr.Next;

                    PushIfNotNull(curr.Left);
                    PushIfNotNull(curr.Right);

                    return true;
                }

                private void PushIfNotNull(AvlNode? child)
                {
                    if (child != null)
                    {
                        _stack!.Push(child);
                    }
                }
            }

            public Enumerator GetEnumerator()
            {
                return new Enumerator(_dict);
            }

            /// <summary>
            /// Reference-type adapter that exposes the struct <see cref="Enumerator"/> through
            /// <see cref="IEnumerator{T}"/>. <c>Reset</c> is not supported.
            /// </summary>
            public class EnumerableImpl : IEnumerator<K>
            {
                private Enumerator _e;

                public EnumerableImpl(Enumerator e)
                {
                    _e = e;
                }

                K IEnumerator<K>.Current => _e.Current;

                void IDisposable.Dispose()
                {
                }

                object IEnumerator.Current => _e.Current;

                bool IEnumerator.MoveNext()
                {
                    return _e.MoveNext();
                }

                void IEnumerator.Reset()
                {
                    throw new NotSupportedException();
                }
            }

            IEnumerator<K> IEnumerable<K>.GetEnumerator()
            {
                return new EnumerableImpl(GetEnumerator());
            }

            // Previously threw NotImplementedException, which broke any consumer that enumerates
            // through the non-generic interface; it now yields the same sequence as the generic path.
            IEnumerator IEnumerable.GetEnumerator()
            {
                return new EnumerableImpl(GetEnumerator());
            }
        }
 
        /// <summary>Returns a view over the values of this dictionary; a new wrapper is created on each access.</summary>
        public ValueCollection Values
        {
            get { return new ValueCollection(this); }
        }
 
        /// <summary>
        /// Enumerable view over the values of a <see cref="SmallDictionary{K, V}"/>.
        /// </summary>
        internal readonly struct ValueCollection : IEnumerable<V>
        {
            private readonly SmallDictionary<K, V> _dict;

            public ValueCollection(SmallDictionary<K, V> dict)
            {
                _dict = dict;
            }

            /// <summary>
            /// Struct enumerator over the values. Each tree node is followed by the entries linked
            /// to it before the traversal moves on to another tree node.
            /// </summary>
            public struct Enumerator
            {
                // Pending tree nodes; stays null when the tree is empty or consists of a single node.
                private readonly Stack<AvlNode>? _stack;
                // Next entry to yield from the list of the most recently visited tree node, if any.
                private Node? _next;
                private Node? _current;

                public Enumerator(SmallDictionary<K, V> dict)
                    : this()
                {
                    var root = dict._root;
                    if (root == null)
                    {
                        return;
                    }

                    // left == right only if both are nulls
                    if (root.Left == root.Right)
                    {
                        // single tree node: no stack needed, just walk its list
                        _next = root;
                    }
                    else
                    {
                        _stack = new Stack<AvlNode>(dict.HeightApprox());
                        _stack.Push(root);
                    }
                }

                public V Current => _current!.Value;

                public bool MoveNext()
                {
                    // first drain the linked entries of the last visited tree node
                    if (_next != null)
                    {
                        _current = _next;
                        _next = _next.Next;
                        return true;
                    }

                    if (_stack == null || _stack.Count == 0)
                    {
                        return false;
                    }

                    var curr = _stack.Pop();
                    _current = curr;
                    _next = curr.Next;

                    PushIfNotNull(curr.Left);
                    PushIfNotNull(curr.Right);

                    return true;
                }

                private void PushIfNotNull(AvlNode? child)
                {
                    if (child != null)
                    {
                        _stack!.Push(child);
                    }
                }
            }

            public Enumerator GetEnumerator()
            {
                return new Enumerator(_dict);
            }

            /// <summary>
            /// Reference-type adapter that exposes the struct <see cref="Enumerator"/> through
            /// <see cref="IEnumerator{T}"/>. <c>Reset</c> is not implemented.
            /// </summary>
            public class EnumerableImpl : IEnumerator<V>
            {
                private Enumerator _e;

                public EnumerableImpl(Enumerator e)
                {
                    _e = e;
                }

                V IEnumerator<V>.Current => _e.Current;

                void IDisposable.Dispose()
                {
                }

                object? IEnumerator.Current => _e.Current;

                bool IEnumerator.MoveNext()
                {
                    return _e.MoveNext();
                }

                void IEnumerator.Reset()
                {
                    throw new NotImplementedException();
                }
            }

            IEnumerator<V> IEnumerable<V>.GetEnumerator()
            {
                return new EnumerableImpl(GetEnumerator());
            }

            // Previously threw NotImplementedException, which broke any consumer that enumerates
            // through the non-generic interface; it now yields the same sequence as the generic path.
            IEnumerator IEnumerable.GetEnumerator()
            {
                return new EnumerableImpl(GetEnumerator());
            }
        }
 
        /// <summary>
        /// Struct enumerator over the key/value pairs. Each tree node is followed by the entries
        /// linked to it before the traversal moves on to another tree node.
        /// </summary>
        public struct Enumerator
        {
            // Pending tree nodes; stays null when the tree is empty or consists of a single node.
            private readonly Stack<AvlNode>? _stack;
            // Next entry to yield from the list of the most recently visited tree node, if any.
            private Node? _next;
            private Node? _current;

            public Enumerator(SmallDictionary<K, V> dict)
                : this()
            {
                var root = dict._root;
                if (root != null)
                {
                    // left == right only if both are nulls
                    if (root.Left != root.Right)
                    {
                        _stack = new Stack<AvlNode>(dict.HeightApprox());
                        _stack.Push(root);
                    }
                    else
                    {
                        // single tree node: no stack needed, just walk its list
                        _next = root;
                    }
                }
            }

            public KeyValuePair<K, V> Current
            {
                get
                {
                    var entry = _current!;
                    return new KeyValuePair<K, V>(entry.Key, entry.Value);
                }
            }

            public bool MoveNext()
            {
                // first drain the linked entries of the last visited tree node
                var pending = _next;
                if (pending != null)
                {
                    _current = pending;
                    _next = pending.Next;
                    return true;
                }

                var stack = _stack;
                if (stack == null || stack.Count == 0)
                {
                    return false;
                }

                var visited = stack.Pop();
                _current = visited;
                _next = visited.Next;

                // schedule the children: left is pushed first, then right
                var leftChild = visited.Left;
                if (leftChild != null)
                {
                    stack.Push(leftChild);
                }

                var rightChild = visited.Right;
                if (rightChild != null)
                {
                    stack.Push(rightChild);
                }

                return true;
            }
        }
 
        /// <summary>Returns a struct enumerator over the key/value pairs of this dictionary.</summary>
        public Enumerator GetEnumerator() => new Enumerator(this);
 
        /// <summary>
        /// Reference-type adapter that exposes the struct <see cref="Enumerator"/> through
        /// <see cref="IEnumerator{T}"/>. <c>Reset</c> is not implemented and disposal is a no-op.
        /// </summary>
        public class EnumerableImpl : IEnumerator<KeyValuePair<K, V>>
        {
            // Kept as a mutable field so that MoveNext advances this copy in place.
            private Enumerator _e;

            public EnumerableImpl(Enumerator e) => _e = e;

            KeyValuePair<K, V> IEnumerator<KeyValuePair<K, V>>.Current
            {
                get { return _e.Current; }
            }

            object IEnumerator.Current
            {
                get { return _e.Current; }
            }

            bool IEnumerator.MoveNext() => _e.MoveNext();

            void IEnumerator.Reset() => throw new NotImplementedException();

            void IDisposable.Dispose()
            {
            }
        }
 
        /// <summary>Interface-based enumeration: wraps the struct enumerator in a heap-allocated adapter.</summary>
        IEnumerator<KeyValuePair<K, V>> IEnumerable<KeyValuePair<K, V>>.GetEnumerator()
            => new EnumerableImpl(GetEnumerator());
 
        /// <summary>
        /// Non-generic enumeration. Previously threw <see cref="NotImplementedException"/>, which
        /// broke any consumer that enumerates through <see cref="IEnumerable"/>; it now yields the
        /// same key/value pairs as the generic enumerator.
        /// </summary>
        IEnumerator IEnumerable.GetEnumerator()
        {
            return new EnumerableImpl(GetEnumerator());
        }
 
        /// <summary>
        /// Estimates the height of the tree as 1.5 times (integer arithmetic) the number of nodes
        /// on the path from the root down the left children. The enumerators use the result as
        /// the initial capacity of their traversal stack.
        /// </summary>
        private int HeightApprox()
        {
            // height is less than 1.5 * depth(leftmost node)
            var leftDepth = 0;
            for (var node = _root; node != null; node = node.Left)
            {
                leftDepth++;
            }

            return leftDepth + leftDepth / 2;
        }
    }
}