File: System\Collections\Immutable\ImmutableSortedSet_1.Node.cs
Web Access
Project: src\src\libraries\System.Collections.Immutable\src\System.Collections.Immutable.csproj (System.Collections.Immutable)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
 
namespace System.Collections.Immutable
{
    public sealed partial class ImmutableSortedSet<T>
    {
        /// <summary>
        /// A node in the AVL tree storing this set.
        /// </summary>
        [DebuggerDisplay("{_key}")]
        internal sealed class Node : IEnumerable<T>
        {
            /// <summary>
            /// The default empty node.
            /// </summary>
            internal static readonly Node EmptyNode = new Node();
 
            /// <summary>
            /// The key associated with this node.
            /// </summary>
            private readonly T _key = default!;
 
            /// <summary>
            /// A value indicating whether this node has been frozen (made immutable).
            /// </summary>
            /// <remarks>
            /// Nodes must be frozen before ever being observed by a wrapping collection type
            /// to protect collections from further mutations.
            /// </remarks>
            private bool _frozen;
 
            /// <summary>
            /// The depth of the tree beneath this node.
            /// </summary>
            private byte _height; // AVL tree max height <= ~1.44 * log2(maxNodes + 2)
 
            /// <summary>
            /// The number of elements contained by this subtree starting at this node.
            /// </summary>
            /// <remarks>
            /// If this node would benefit from saving 4 bytes, we could have only a few nodes
            /// scattered throughout the graph actually record the count of nodes beneath them.
            /// Those without the count could query their descendants, which would often short-circuit
            /// when they hit a node that *does* include a count field.
            /// </remarks>
            private int _count;
 
            /// <summary>
            /// The left tree.
            /// </summary>
            private Node? _left;
 
            /// <summary>
            /// The right tree.
            /// </summary>
            private Node? _right;
 
            /// <summary>
            /// Initializes a new instance of the <see cref="ImmutableSortedSet{T}.Node"/> class
            /// that is pre-frozen.
            /// </summary>
            private Node()
            {
                _frozen = true; // the empty node is *always* frozen.
                Debug.Assert(this.IsEmpty);
            }
 
            /// <summary>
            /// Initializes a new instance of the <see cref="ImmutableSortedSet{T}.Node"/> class
            /// that is not yet frozen.
            /// </summary>
            /// <param name="key">The value stored by this node.</param>
            /// <param name="left">The left branch.</param>
            /// <param name="right">The right branch.</param>
            /// <param name="frozen">Whether this node is prefrozen.</param>
            private Node(T key, Node left, Node right, bool frozen = false)
            {
                Requires.NotNull(left, nameof(left));
                Requires.NotNull(right, nameof(right));
                Debug.Assert(!frozen || (left._frozen && right._frozen));
 
                _key = key;
                _left = left;
                _right = right;
                _height = checked((byte)(1 + Math.Max(left._height, right._height)));
                _count = 1 + left._count + right._count;
                _frozen = frozen;
            }
 
            /// <summary>
            /// Gets a value indicating whether this instance is empty.
            /// </summary>
            /// <value>
            ///   <c>true</c> if this instance is empty; otherwise, <c>false</c>.
            /// </value>
            public bool IsEmpty
            {
                get { return _left == null; }
            }
 
            /// <summary>
            /// Gets the height of the tree beneath this node.
            /// </summary>
            public int Height
            {
                get { return _height; }
            }
 
            /// <summary>
            /// Gets the left branch of this node.
            /// </summary>
            public Node? Left
            {
                get { return _left; }
            }
 
            /// <summary>
            /// Gets the right branch of this node.
            /// </summary>
            public Node? Right
            {
                get { return _right; }
            }
 
            /// <summary>
            /// Gets the value represented by the current node.
            /// </summary>
            public T Value { get { return _key; } }
 
            /// <summary>
            /// Gets the number of elements contained by this subtree starting at this node.
            /// </summary>
            public int Count
            {
                get { return _count; }
            }
 
            /// <summary>
            /// Gets the key.
            /// </summary>
            internal T Key
            {
                get { return _key; }
            }
 
            /// <summary>
            /// Gets the maximum value in the collection, as defined by the comparer.
            /// </summary>
            /// <value>The maximum value in the set.</value>
            internal T? Max
            {
                get
                {
                    if (this.IsEmpty)
                    {
                        return default;
                    }
 
                    Node n = this;
                    while (!n._right!.IsEmpty)
                    {
                        n = n._right;
                    }
 
                    return n._key;
                }
            }
 
            /// <summary>
            /// Gets the minimum value in the collection, as defined by the comparer.
            /// </summary>
            /// <value>The minimum value in the set.</value>
            internal T? Min
            {
                get
                {
                    if (this.IsEmpty)
                    {
                        return default;
                    }
 
                    Node n = this;
                    while (!n._left!.IsEmpty)
                    {
                        n = n._left;
                    }
 
                    return n._key;
                }
            }
 
            /// <summary>
            /// Gets the element of the set at the given index.
            /// </summary>
            /// <param name="index">The 0-based index of the element in the set to return.</param>
            /// <returns>The element at the given position.</returns>
            internal T this[int index]
            {
                get
                {
                    Requires.Range(index >= 0 && index < this.Count, nameof(index));
                    Debug.Assert(_left != null && _right != null);
 
                    if (index < _left._count)
                    {
                        return _left[index];
                    }
 
                    if (index > _left._count)
                    {
                        return _right[index - _left._count - 1];
                    }
 
                    return _key;
                }
            }
 
            /// <summary>
            /// Gets a read-only reference to the element of the set at the given index.
            /// </summary>
            /// <param name="index">The 0-based index of the element in the set to return.</param>
            /// <returns>A read-only reference to the element at the given position.</returns>
            internal ref readonly T ItemRef(int index)
            {
                Requires.Range(index >= 0 && index < this.Count, nameof(index));
 
                return ref ItemRefUnchecked(index);
            }
 
            private ref readonly T ItemRefUnchecked(int index)
            {
                Debug.Assert(_left != null && _right != null);
 
                if (index < _left._count)
                {
                    return ref _left.ItemRefUnchecked(index);
                }
 
                if (index > _left._count)
                {
                    return ref _right.ItemRefUnchecked(index - _left._count - 1);
                }
 
                return ref _key;
            }
 
            #region IEnumerable<T> Members
 
            /// <summary>
            /// Returns an enumerator that iterates through the collection.
            /// </summary>
            /// <returns>
            /// A <see cref="IEnumerator{T}"/> that can be used to iterate through the collection.
            /// </returns>
            public Enumerator GetEnumerator()
            {
                return new Enumerator(this);
            }
 
            /// <summary>
            /// Returns an enumerator that iterates through the collection.
            /// </summary>
            /// <returns>
            /// A <see cref="IEnumerator{T}"/> that can be used to iterate through the collection.
            /// </returns>
            [ExcludeFromCodeCoverage] // internal and never called, but here for the interface.
            IEnumerator<T> IEnumerable<T>.GetEnumerator()
            {
                return this.GetEnumerator();
            }
 
            /// <summary>
            /// Returns an enumerator that iterates through the collection.
            /// </summary>
            /// <returns>
            /// A <see cref="IEnumerator{T}"/> that can be used to iterate through the collection.
            /// </returns>
            [ExcludeFromCodeCoverage] // internal and never called, but here for the interface.
            IEnumerator IEnumerable.GetEnumerator()
            {
                return this.GetEnumerator();
            }
 
            #endregion
 
            /// <summary>
            /// Returns an enumerator that iterates through the collection.
            /// </summary>
            /// <param name="builder">The builder, if applicable.</param>
            /// <returns>
            /// A <see cref="IEnumerator{T}"/> that can be used to iterate through the collection.
            /// </returns>
            internal Enumerator GetEnumerator(Builder builder)
            {
                return new Enumerator(this, builder);
            }
 
            /// <summary>
            /// See the <see cref="ICollection{T}"/> interface.
            /// </summary>
            internal void CopyTo(T[] array, int arrayIndex)
            {
                Requires.NotNull(array, nameof(array));
                Requires.Range(arrayIndex >= 0, nameof(arrayIndex));
                Requires.Range(array.Length >= arrayIndex + this.Count, nameof(arrayIndex));
                foreach (T item in this)
                {
                    array[arrayIndex++] = item;
                }
            }
 
            /// <summary>
            /// See the <see cref="ICollection{T}"/> interface.
            /// </summary>
            internal void CopyTo(Array array, int arrayIndex)
            {
                Requires.NotNull(array, nameof(array));
                Requires.Range(arrayIndex >= 0, nameof(arrayIndex));
                Requires.Range(array.Length >= arrayIndex + this.Count, nameof(arrayIndex));
 
                foreach (T item in this)
                {
                    array.SetValue(item, arrayIndex++);
                }
            }
 
            /// <summary>
            /// Adds the specified key to the tree.
            /// </summary>
            /// <param name="key">The key.</param>
            /// <param name="comparer">The comparer.</param>
            /// <param name="mutated">Receives a value indicating whether this node tree has mutated because of this operation.</param>
            /// <returns>The new tree.</returns>
            internal Node Add(T key, IComparer<T> comparer, out bool mutated)
            {
                Requires.NotNull(comparer, nameof(comparer));
 
                if (this.IsEmpty)
                {
                    mutated = true;
                    return new Node(key, this, this);
                }
                else
                {
                    Node result = this;
                    int compareResult = comparer.Compare(key, _key);
                    if (compareResult > 0)
                    {
                        ImmutableSortedSet<T>.Node newRight = _right!.Add(key, comparer, out mutated);
                        if (mutated)
                        {
                            result = this.Mutate(right: newRight);
                        }
                    }
                    else if (compareResult < 0)
                    {
                        ImmutableSortedSet<T>.Node newLeft = _left!.Add(key, comparer, out mutated);
                        if (mutated)
                        {
                            result = this.Mutate(left: newLeft);
                        }
                    }
                    else
                    {
                        mutated = false;
                        return this;
                    }
 
                    return mutated ? MakeBalanced(result) : result;
                }
            }
 
            /// <summary>
            /// Removes the specified key from the tree.
            /// </summary>
            /// <param name="key">The key.</param>
            /// <param name="comparer">The comparer.</param>
            /// <param name="mutated">Receives a value indicating whether this node tree has mutated because of this operation.</param>
            /// <returns>The new tree.</returns>
            internal Node Remove(T key, IComparer<T> comparer, out bool mutated)
            {
                Requires.NotNull(comparer, nameof(comparer));
 
                if (this.IsEmpty)
                {
                    mutated = false;
                    return this;
                }
                else
                {
                    Debug.Assert(_left != null && _right != null);
                    Node result = this;
                    int compare = comparer.Compare(key, _key);
                    if (compare == 0)
                    {
                        // We have a match.
                        mutated = true;
 
                        // If this is a leaf, just remove it
                        // by returning Empty.  If we have only one child,
                        // replace the node with the child.
                        if (_right.IsEmpty && _left.IsEmpty)
                        {
                            result = EmptyNode;
                        }
                        else if (_right.IsEmpty && !_left.IsEmpty)
                        {
                            result = _left;
                        }
                        else if (!_right.IsEmpty && _left.IsEmpty)
                        {
                            result = _right;
                        }
                        else
                        {
                            // We have two children. Remove the next-highest node and replace
                            // this node with it.
                            ImmutableSortedSet<T>.Node successor = _right;
                            while (!successor._left!.IsEmpty)
                            {
                                successor = successor._left;
                            }
 
                            ImmutableSortedSet<T>.Node newRight = _right.Remove(successor._key, comparer, out _);
                            result = successor.Mutate(left: _left, right: newRight);
                        }
                    }
                    else if (compare < 0)
                    {
                        ImmutableSortedSet<T>.Node newLeft = _left.Remove(key, comparer, out mutated);
                        if (mutated)
                        {
                            result = this.Mutate(left: newLeft);
                        }
                    }
                    else
                    {
                        ImmutableSortedSet<T>.Node newRight = _right.Remove(key, comparer, out mutated);
                        if (mutated)
                        {
                            result = this.Mutate(right: newRight);
                        }
                    }
 
                    return result.IsEmpty ? result : MakeBalanced(result);
                }
            }
 
            /// <summary>
            /// Determines whether the specified key is in this tree.
            /// </summary>
            /// <param name="key">The key.</param>
            /// <param name="comparer">The comparer.</param>
            /// <returns>
            ///   <c>true</c> if the tree contains the specified key; otherwise, <c>false</c>.
            /// </returns>
            internal bool Contains(T key, IComparer<T> comparer)
            {
                Requires.NotNull(comparer, nameof(comparer));
                return !this.Search(key, comparer).IsEmpty;
            }
 
            /// <summary>
            /// Freezes this node and all descendant nodes so that any mutations require a new instance of the nodes.
            /// </summary>
            internal void Freeze()
            {
                // If this node is frozen, all its descendants must already be frozen.
                if (!_frozen)
                {
                    Debug.Assert(_left != null && _right != null);
                    _left.Freeze();
                    _right.Freeze();
                    _frozen = true;
                }
            }
 
            /// <summary>
            /// Searches for the specified key.
            /// </summary>
            /// <param name="key">The key to search for.</param>
            /// <param name="comparer">The comparer.</param>
            /// <returns>The matching node, or <see cref="EmptyNode"/> if no match was found.</returns>
            internal Node Search(T key, IComparer<T> comparer)
            {
                Requires.NotNull(comparer, nameof(comparer));
 
                if (this.IsEmpty)
                {
                    return this;
                }
                else
                {
                    int compare = comparer.Compare(key, _key);
                    if (compare == 0)
                    {
                        return this;
                    }
                    else if (compare > 0)
                    {
                        return _right!.Search(key, comparer);
                    }
                    else
                    {
                        return _left!.Search(key, comparer);
                    }
                }
            }
 
            /// <summary>
            /// Searches for the specified key.
            /// </summary>
            /// <param name="key">The key to search for.</param>
            /// <param name="comparer">The comparer.</param>
            /// <returns>The matching node, or <see cref="EmptyNode"/> if no match was found.</returns>
            internal int IndexOf(T key, IComparer<T> comparer)
            {
                Requires.NotNull(comparer, nameof(comparer));
 
                if (this.IsEmpty)
                {
                    return -1;
                }
                else
                {
                    Debug.Assert(_left != null && _right != null);
                    int compare = comparer.Compare(key, _key);
                    if (compare == 0)
                    {
                        return _left.Count;
                    }
                    else if (compare > 0)
                    {
                        int result = _right.IndexOf(key, comparer);
                        bool missing = result < 0;
                        if (missing)
                        {
                            result = ~result;
                        }
 
                        result = _left.Count + 1 + result;
                        if (missing)
                        {
                            result = ~result;
                        }
 
                        return result;
                    }
                    else
                    {
                        return _left.IndexOf(key, comparer);
                    }
                }
            }
 
            /// <summary>
            /// Returns an <see cref="IEnumerable{T}"/> that iterates over this
            /// collection in reverse order.
            /// </summary>
            /// <returns>
            /// An enumerator that iterates over the <see cref="ImmutableSortedSet{T}"/>
            /// in reverse order.
            /// </returns>
            internal IEnumerator<T> Reverse()
            {
                return new Enumerator(this, reverse: true);
            }
 
            #region Tree balancing methods
 
            /// <summary>
            /// AVL rotate left operation.
            /// </summary>
            /// <param name="tree">The tree.</param>
            /// <returns>The rotated tree.</returns>
            private static Node RotateLeft(Node tree)
            {
                Requires.NotNull(tree, nameof(tree));
                Debug.Assert(!tree.IsEmpty);
 
                if (tree._right!.IsEmpty)
                {
                    return tree;
                }
 
                ImmutableSortedSet<T>.Node right = tree._right;
                return right.Mutate(left: tree.Mutate(right: right._left!));
            }
 
            /// <summary>
            /// AVL rotate right operation.
            /// </summary>
            /// <param name="tree">The tree.</param>
            /// <returns>The rotated tree.</returns>
            private static Node RotateRight(Node tree)
            {
                Requires.NotNull(tree, nameof(tree));
                Debug.Assert(!tree.IsEmpty);
 
                if (tree._left!.IsEmpty)
                {
                    return tree;
                }
 
                ImmutableSortedSet<T>.Node left = tree._left;
                return left.Mutate(right: tree.Mutate(left: left._right!));
            }
 
            /// <summary>
            /// AVL rotate double-left operation.
            /// </summary>
            /// <param name="tree">The tree.</param>
            /// <returns>The rotated tree.</returns>
            private static Node DoubleLeft(Node tree)
            {
                Requires.NotNull(tree, nameof(tree));
                Debug.Assert(!tree.IsEmpty);
 
                if (tree._right!.IsEmpty)
                {
                    return tree;
                }
 
                Node rotatedRightChild = tree.Mutate(right: RotateRight(tree._right));
                return RotateLeft(rotatedRightChild);
            }
 
            /// <summary>
            /// AVL rotate double-right operation.
            /// </summary>
            /// <param name="tree">The tree.</param>
            /// <returns>The rotated tree.</returns>
            private static Node DoubleRight(Node tree)
            {
                Requires.NotNull(tree, nameof(tree));
                Debug.Assert(!tree.IsEmpty);
 
                if (tree._left!.IsEmpty)
                {
                    return tree;
                }
 
                Node rotatedLeftChild = tree.Mutate(left: RotateLeft(tree._left));
                return RotateRight(rotatedLeftChild);
            }
 
            /// <summary>
            /// Returns a value indicating whether the tree is in balance.
            /// </summary>
            /// <param name="tree">The tree.</param>
            /// <returns>0 if the tree is in balance, a positive integer if the right side is heavy, or a negative integer if the left side is heavy.</returns>
            private static int Balance(Node tree)
            {
                Requires.NotNull(tree, nameof(tree));
                Debug.Assert(!tree.IsEmpty);
 
                return tree._right!._height - tree._left!._height;
            }
 
            /// <summary>
            /// Determines whether the specified tree is right heavy.
            /// </summary>
            /// <param name="tree">The tree.</param>
            /// <returns>
            /// <c>true</c> if [is right heavy] [the specified tree]; otherwise, <c>false</c>.
            /// </returns>
            private static bool IsRightHeavy(Node tree)
            {
                Requires.NotNull(tree, nameof(tree));
                Debug.Assert(!tree.IsEmpty);
                return Balance(tree) >= 2;
            }
 
            /// <summary>
            /// Determines whether the specified tree is left heavy.
            /// </summary>
            private static bool IsLeftHeavy(Node tree)
            {
                Requires.NotNull(tree, nameof(tree));
                Debug.Assert(!tree.IsEmpty);
                return Balance(tree) <= -2;
            }
 
            /// <summary>
            /// Balances the specified tree.
            /// </summary>
            /// <param name="tree">The tree.</param>
            /// <returns>A balanced tree.</returns>
            private static Node MakeBalanced(Node tree)
            {
                Requires.NotNull(tree, nameof(tree));
                Debug.Assert(!tree.IsEmpty);
 
                if (IsRightHeavy(tree))
                {
                    return Balance(tree._right!) < 0 ? DoubleLeft(tree) : RotateLeft(tree);
                }
 
                if (IsLeftHeavy(tree))
                {
                    return Balance(tree._left!) > 0 ? DoubleRight(tree) : RotateRight(tree);
                }
 
                return tree;
            }
 
            #endregion
 
            /// <summary>
            /// Creates a node tree that contains the contents of a list.
            /// </summary>
            /// <param name="items">An indexable list with the contents that the new node tree should contain.</param>
            /// <param name="start">The starting index within <paramref name="items"/> that should be captured by the node tree.</param>
            /// <param name="length">The number of elements from <paramref name="items"/> that should be captured by the node tree.</param>
            /// <returns>The root of the created node tree.</returns>
            internal static Node NodeTreeFromList(IReadOnlyList<T> items, int start, int length)
            {
                Requires.NotNull(items, nameof(items));
                Debug.Assert(start >= 0);
                Debug.Assert(length >= 0);
 
                if (length == 0)
                {
                    return EmptyNode;
                }
 
                int rightCount = (length - 1) / 2;
                int leftCount = (length - 1) - rightCount;
                Node left = NodeTreeFromList(items, start, leftCount);
                Node right = NodeTreeFromList(items, start + leftCount + 1, rightCount);
                return new Node(items[start + leftCount], left, right, true);
            }
 
            /// <summary>
            /// Creates a node mutation, either by mutating this node (if not yet frozen) or by creating a clone of this node
            /// with the described changes.
            /// </summary>
            /// <param name="left">The left branch of the mutated node.</param>
            /// <param name="right">The right branch of the mutated node.</param>
            /// <returns>The mutated (or created) node.</returns>
            private Node Mutate(Node? left = null, Node? right = null)
            {
                Debug.Assert(_left != null && _right != null);
                if (_frozen)
                {
                    return new Node(_key, left ?? _left, right ?? _right);
                }
                else
                {
                    if (left != null)
                    {
                        _left = left;
                    }
 
                    if (right != null)
                    {
                        _right = right;
                    }
 
                    _height = checked((byte)(1 + Math.Max(_left._height, _right._height)));
                    _count = 1 + _left._count + _right._count;
                    return this;
                }
            }
        }
    }
}