File: Utils\SequencePool.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Internal.Utilities
{
    using Conditional = System.Diagnostics.ConditionalAttribute;
 
    /// <summary>
    /// A dictionary of uint sequences of variable length. Stores the sequences as
    /// byte sequences encoded with LEB128. Empty sequences (or null) are also valid.
    /// Each distinct sequence is assigned an integer ID equal to the number of sequences
    /// added before it.
    /// </summary>
    [BestFriend]
    internal sealed class SequencePool
    {
        // uint sequences are hashed into _mask+1 buckets. _buckets contains the ID of the first
        // sequence that falls in it (or -1 if it is empty).
        // We store the sequences by using LEB128 encoding, and storing the resulting byte sequence.
        // For each sequence, we store its start index in the byte array and its hash value.
        // This information for sequence with ID id is stored in _start[id] and _hash[id]
        // respectively. In addition, we store the ID of the next sequence in the same bucket
        // in _next[id].

        // Number of buckets minus 1. The number of buckets must be a power of two.
        private int _mask;
        // The i'th entry contains the ID of the first sequence in the i'th bucket.
        private int[] _buckets;

        // The number of uint sequences. The arrays _next, _start and _hash have logical
        // length _idLim, and are indexed by the ID's of the sequences.
        private int _idLim;
        // For 0 <= id < _idLim, the ID of the next uint sequence in the same bucket as sequence id.
        private int[] _next;
        // Indexed by id. Contains the starting index in _bytes of the sequences. The logical length
        // of this array is _idLim+1, with the last item being the total count of bytes.
        private int[] _start;
        // Indexed by id. Contains the hashes of the sequences. The hashing is performed on the uncompressed uint
        // sequences.
        private uint[] _hash;

        // The byte sequences. The logical length of _bytes is _start[_idLim].
        private byte[] _bytes;

        /// <summary>The number of distinct sequences stored in the pool.</summary>
        public int Count { get { return _idLim; } }

        /// <summary>
        /// Creates an empty pool with 32 buckets and small initial buffers.
        /// </summary>
        public SequencePool()
        {
            _mask = 31;
            _buckets = Utils.CreateArray<int>(_mask + 1, -1);

            _next = new int[10];
            _start = new int[11];
            _hash = new uint[10];
            _bytes = new byte[40];

            AssertValid();
        }

        /// <summary>
        /// Deserializes a pool previously written by <see cref="Save(BinaryWriter)"/>. The hash table
        /// is rebuilt by decoding every stored sequence and re-hashing it.
        /// </summary>
        /// <param name="reader">The reader positioned at the start of the serialized pool.</param>
        public SequencePool(BinaryReader reader)
        {
            // *** Binary format ***
            // int: _idLim (the number of sequences)
            // int[]: _start (length is _idLim+1)
            // byte[]: _bytes (length is _start[_idLim])

            _idLim = reader.ReadInt32();
            Contracts.CheckDecode(0 <= _idLim && _idLim < int.MaxValue);
            _start = reader.ReadIntArray(_idLim + 1);
            Contracts.CheckDecode(Utils.Size(_start) > 0 && _start[0] == 0);
            Contracts.CheckDecode(_start[_idLim] >= 0);
            _bytes = reader.ReadByteArray(_start[_idLim]);
            // Pad the buffers up to the same minimum sizes the default constructor uses.
            if (_idLim < 10)
                Array.Resize(ref _start, 11);
            if (Utils.Size(_bytes) < 40)
                Array.Resize(ref _bytes, 40);

            // Choose the bucket count: the smallest power of 2 that is greater than max(_idLim, 31),
            // so there are at least 32 buckets and strictly more buckets than sequences.
            int ibit = Utils.IbitHigh((uint)Math.Max(_idLim, 31));
            Contracts.Assert(4 <= ibit && ibit <= 31);
            if (ibit < 31)
                ibit++;
            _mask = (1 << ibit) - 1;

            _buckets = Utils.CreateArray<int>(_mask + 1, -1);

            _hash = new uint[Math.Max(_idLim, 10)];
            _next = new int[Math.Max(_idLim, 10)];

            uint[] sequence = null;
            var cb = _start[_idLim];
            for (int id = 0; id < _idLim; id++)
            {
                // Validate the untrusted offsets before decoding: they must be non-decreasing and in range.
                Contracts.CheckDecode(_start[id] <= _start[id + 1] && _start[id + 1] <= cb);
                int count = Leb128ToUIntArray(_bytes, _start[id], _start[id + 1], ref sequence);
                _hash[id] = Hashing.HashSequence(sequence, 0, count);
                int i = GetBucketIndex(_hash[id]);
                _next[id] = _buckets[i];
                _buckets[i] = id;
            }

            AssertValid();
        }

        /// <summary>
        /// Serializes the pool. Only the sequence count, start offsets and encoded bytes are
        /// written; hashes and bucket chains are recomputed on load.
        /// </summary>
        /// <param name="writer">The destination writer.</param>
        public void Save(BinaryWriter writer)
        {
            AssertValid();

            // *** Binary format ***
            // int: _idLim (the number of sequences)
            // int[]: _start (length is _idLim+1)
            // byte[]: _bytes (length is _start[_idLim])

            writer.Write(_idLim);
#if DEBUG
            for (int id = 0; id < _idLim; id++)
                Contracts.Assert(_start[id] <= _start[id + 1]);
#endif
            writer.WriteIntsNoCount(_start.AsSpan(0, _idLim + 1));
            writer.WriteBytesNoCount(_bytes, _start[_idLim]);
        }

        // Debug-only consistency check of the internal array sizes and invariants.
        [Conditional("DEBUG")]
        private void AssertValid()
        {
            // Number of buckets must be a power of two.
            Contracts.AssertValue(_buckets);
            Contracts.Assert(_buckets.Length == _mask + 1);
            Contracts.Assert(Utils.IsPowerOfTwo(_mask + 1));

            Contracts.Assert(0 <= _idLim && Math.Max(10, _idLim) <= Utils.Size(_start) - 1);
            Contracts.Assert(Math.Max(10, _idLim) <= Utils.Size(_hash));
            Contracts.Assert(Math.Max(10, _idLim) <= Utils.Size(_next));
            Contracts.Assert(_start[0] == 0);
            Contracts.Assert(0 <= _start[_idLim] && Math.Max(40, _start[_idLim]) <= Utils.Size(_bytes));
        }

        // Returns the ID at the head of the chain for the bucket that hash maps to, or -1 if empty.
        private int GetFirstIdInBucket(uint hash)
        {
            return _buckets[GetBucketIndex(hash)];
        }

        // Maps a hash to a bucket index using the low bits selected by _mask.
        private int GetBucketIndex(uint hash)
        {
            return (int)hash & _mask;
        }

        // Returns the ID of the requested sequence, or -1 if it is not found.
        // Also outputs the hash of sequence[min..lim) so callers can reuse it when adding.
        private int GetCore(uint[] sequence, int min, int lim, out uint hash)
        {
            AssertValid();
            Contracts.Assert(0 <= min && min <= lim && lim <= Utils.Size(sequence));

            hash = Hashing.HashSequence(sequence, min, lim);

            for (int idCur = GetFirstIdInBucket(hash); idCur >= 0; idCur = _next[idCur])
            {
                Contracts.Assert(0 <= idCur && idCur < _idLim);
                // Cheap rejection before decoding the stored bytes.
                if (_hash[idCur] != hash)
                    continue;

                var ibCur = _start[idCur];
                var ibLim = _start[idCur + 1];
                for (int i = min; ; i++)
                {
                    Contracts.Assert(ibCur <= ibLim);
                    if (i >= lim)
                    {
                        // Need to make sure that we have reached the end of the sequence in the pool at the
                        // same time that we reached the end of sequence.
                        if (ibCur == ibLim)
                            return idCur;
                        break;
                    }
                    // The stored sequence is shorter than the query.
                    if (ibCur >= ibLim)
                        break;
                    uint decoded;
                    var success = TryDecodeOne(_bytes, ref ibCur, ibLim, out decoded);
                    Contracts.Assert(success);
                    if (sequence[i] != decoded)
                        break;
                }
            }
            return -1;
        }

        /// <summary>
        /// Returns true if the sequence was added, or false if it was already in the pool.
        /// </summary>
        /// <param name="sequence">The array containing the sequence to add to the pool.</param>
        /// <param name="min">The location in the array of the first sequence element.</param>
        /// <param name="lim">The exclusive end of the sequence.</param>
        /// <param name="id">The ID of the sequence: the newly assigned ID if it was added, otherwise
        /// the ID of the equal sequence already in the pool.</param>
        /// <returns>True if the sequence was added, false if the sequence was already present in the pool.</returns>
        public bool TryAdd(uint[] sequence, int min, int lim, out int id)
        {
            Contracts.Check(0 <= min && min <= lim && lim <= Utils.Size(sequence));

            uint hash;
            id = GetCore(sequence, min, lim, out hash);
            if (id >= 0)
                return false;
            id = _idLim;
            AddCore(sequence, min, lim, hash);
            Contracts.Assert(id == _idLim - 1);
            return true;
        }

        /// <summary>
        /// Find the given sequence in the pool. If not found, returns -1.
        /// </summary>
        /// <param name="sequence">An integer sequence</param>
        /// <param name="min">The starting index of the sequence to find in the pool</param>
        /// <param name="lim">The exclusive end index of the sequence to find in the pool</param>
        /// <returns>The ID of the sequence if it is found, -1 otherwise</returns>
        public int Get(uint[] sequence, int min, int lim)
        {
            Contracts.Check(0 <= min && min <= lim && lim <= Utils.Size(sequence));

            uint hash;
            return GetCore(sequence, min, lim, out hash);
        }

        /// <summary>
        /// Appends sequence[min..lim) with the precomputed <paramref name="hash"/> as a new entry whose
        /// ID is the current <see cref="Count"/>. Does NOT check for whether the item is already present.
        /// </summary>
        private void AddCore(uint[] sequence, int min, int lim, uint hash)
        {
            Contracts.Assert(0 <= min && min <= lim && lim <= Utils.Size(sequence));
            Contracts.Assert(Hashing.HashSequence(sequence, min, lim) == hash);

            // _start must have room for index _idLim + 1, the end offset of the new sequence.
            if (_idLim + 1 >= _start.Length)
            {
                Contracts.Check(_start.Length != Utils.ArrayMaxSize, "Cannot allocate memory for the sequence pool");
                Contracts.Assert(_idLim + 1 == _start.Length);
                Array.Resize(ref _start, GetGrownSize(_start.Length, (long)_idLim + 2));
            }

            // _hash and _next grow together and must have room for index _idLim.
            Contracts.Assert(_hash.Length >= _next.Length);
            if (_idLim >= _next.Length)
            {
                Contracts.Check(_next.Length != Utils.ArrayMaxSize, "Cannot allocate memory for the sequence pool");
                Contracts.Assert(_idLim == _next.Length);
                int size = GetGrownSize(_next.Length, (long)_idLim + 1);
                Array.Resize(ref _hash, size);
                Array.Resize(ref _next, size);
            }

            // LEB128 uses at most 5 bytes per uint, so this bounds the encoded size of the new sequence.
            var cbMax = checked(5 * (lim - min));
            var ibLim = _start[_idLim];
            if (ibLim > _bytes.Length - cbMax)
            {
                // Widen to long before adding: ibLim + cbMax may exceed int.MaxValue, and an int
                // addition would silently wrap to a negative minimum size.
                int size = GetGrownSize(_bytes.Length, (long)ibLim + cbMax);
                // Clamping to ArrayMaxSize may leave too little room (including when _bytes is already at
                // ArrayMaxSize). Fail cleanly instead of overrunning _bytes while encoding.
                Contracts.Check(size - ibLim >= cbMax, "Cannot allocate memory for the sequence pool");
                Array.Resize(ref _bytes, size);
            }
            Contracts.Assert(_idLim < _next.Length);
            Contracts.Assert(ibLim <= _bytes.Length - cbMax);

            // Link the new ID at the head of its bucket's chain.
            int i = GetBucketIndex(hash);
            _next[_idLim] = _buckets[i];
            _hash[_idLim] = hash;
            _buckets[i] = _idLim;
            _idLim++;
            // The new sequence starts where the previous one ended; encoding advances the end offset in place.
            _start[_idLim] = _start[_idLim - 1];
            UIntArrayToLeb128(sequence, min, lim, _bytes, ref _start[_idLim]);

            // Double the bucket count once the number of sequences reaches it.
            if (_idLim >= _buckets.Length)
                GrowTable();

            AssertValid();
        }

        /// <summary>
        /// Returns the size an array of the given <paramref name="length"/> should grow to:
        /// 1.5 times its length, or <paramref name="minSize"/> if larger, clamped to
        /// <see cref="Utils.ArrayMaxSize"/>. Callers must check that the result is large enough.
        /// </summary>
        private static int GetGrownSize(int length, long minSize)
        {
            long newSize = Math.Max((long)length + length / 2, minSize);
            return (newSize > Utils.ArrayMaxSize) ? Utils.ArrayMaxSize : (int)newSize;
        }

        // Doubles the number of buckets and relinks every existing sequence into its new bucket.
        private void GrowTable()
        {
            AssertValid();

            int size = checked(2 * _buckets.Length);
            _buckets = Utils.CreateArray<int>(size, -1);
            _mask = size - 1;

            for (int id = 0; id < _idLim; id++)
            {
                int i = GetBucketIndex(_hash[id]);
                _next[id] = _buckets[i];
                _buckets[i] = id;
            }

            AssertValid();
        }

        /// <summary>
        /// Decodes the sequence with the given ID into <paramref name="sequence"/>, resizing it if it is too
        /// small, and returns the number of elements written.
        /// </summary>
        public int GetById(int id, ref uint[] sequence)
        {
            Contracts.Check(0 <= id && id < _idLim);
            return Leb128ToUIntArray(_bytes, _start[id], _start[id + 1], ref sequence);
        }

        // Encodes values[min..lim) as LEB128 into bytes starting at ib, advancing ib past the written bytes.
        // Asserts that bytes has room for the worst case of 5 bytes per value.
        private static void UIntArrayToLeb128(uint[] values, int min, int lim, byte[] bytes, ref int ib)
        {
            Contracts.Assert(bytes.Length >= ib + 5 * (lim - min));
            uint value;
            for (int i = min; i < lim; i++)
            {
                value = values[i];
                // Copied from Utils.WriteLEB128Int
                while (value >= 0x80)
                {
                    bytes[ib++] = (byte)(value | 0x80);
                    value >>= 7;
                }
                bytes[ib++] = (byte)value;
            }
        }

        // Decodes one LEB128 uint from bytes[ib..ibLim), advancing ib past it. Returns false if the input
        // ends before a terminating byte or the fifth byte carries more than the 4 remaining value bits.
        private static bool TryDecodeOne(byte[] bytes, ref int ib, int ibLim, out uint value)
        {
            value = 0;
            int shift = 0;
            for (; ib < ibLim; ib++)
            {
                uint bCur = bytes[ib];
                if (shift == 4 * 7 && bCur > 0x0F)
                    return false;

                value |= (((bCur & (uint)0x7F)) << shift);
                shift += 7;
                if ((bCur & 0x80) == 0)
                {
                    ib++;
                    return true;
                }
            }
            return false;
        }

        // Decodes all LEB128 values in bytes[min..lim) into sequence (resizing it if needed) and returns the count.
        // Throws a decode error on malformed input.
        private static int Leb128ToUIntArray(byte[] bytes, int min, int lim, ref uint[] sequence)
        {
            Contracts.Assert(0 <= min && min <= lim && lim <= Utils.Size(bytes));

            int cur = min;
            int count = 0;
            while (cur < lim)
            {
                if (Utils.Size(sequence) <= count)
                {
                    // Each value takes at least one byte, so lim - min bounds the element count.
                    Contracts.Assert(count < lim - min);
                    Array.Resize(ref sequence, lim - min);
                }
                Contracts.CheckDecode(TryDecodeOne(bytes, ref cur, lim, out sequence[count]));
                Contracts.Assert(cur <= lim);
                count++;
            }
            return count;
        }
    }
}