File: Dataset\IntArray.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
    using FloatType = System.Single;
#else
#endif
 
    /// <summary>
    /// Storage layouts an int array can use. The numeric value of the layout is written as the
    /// first header integer of the binary form, so the values below must stay stable.
    /// </summary>
    internal enum IntArrayType
    {
        Dense = 0,
        Sparse = 1,
        Repeat = 2,
        Segmented = 3,
        // Not a concrete layout: the value-based factory rejects it.
        // NOTE(review): presumably means "keep the current layout" — confirm at call sites.
        Current = 4,
    }

    /// <summary>
    /// Supported widths, in bits, for each stored item. The numeric value equals the bit width
    /// and is written as the second header integer of the binary form.
    /// </summary>
    internal enum IntArrayBits
    {
        Bits32 = 32,
        Bits16 = 16,
        Bits10 = 10,
        Bits8 = 8,
        Bits4 = 4,
        Bits1 = 1,
        Bits0 = 0,
    }
 
    /// <summary>
    /// Base class for compact representations of an array of integers, stored in one of several
    /// layouts (<see cref="IntArrayType"/>) with a given width per item (<see cref="IntArrayBits"/>).
    /// Within this class the stored values are used as feature-histogram bin indices (see <see cref="Sumup"/>).
    /// </summary>
    internal abstract class IntArray : IEnumerable<int>
    {
        // Bit flags selecting which representations may be produced:
        //   0x1 - allow the 10-bit dense width in NumBitsNeeded.
        //   0x4 - allow Compress to choose the run-length (Repeat) layout.
        // NOTE(review): flag 0x2 was listed here without a description and is not read in this file.
        public static int CompatibilityLevel = 0;

        /// <summary>
        /// The virtual length of the array
        /// </summary>
        public abstract int Length { get; }

        /// <summary>
        /// True when the process runs on x86/x64, in which case the native code paths should be used;
        /// otherwise the managed fallbacks are used.
        /// </summary>
        public static bool UseFastTreeNative => RuntimeInformation.ProcessArchitecture == Architecture.X64 || RuntimeInformation.ProcessArchitecture == Architecture.X86;

        /// <summary>
        /// Returns the number of bytes written by <see cref="ToByteArray"/>. The base implementation
        /// accounts only for the two header integers (type and bits per item).
        /// </summary>
        public virtual int SizeInBytes()
        {
            return 2 * sizeof(int);
        }

        /// <summary>
        /// Writes a binary representation of this class to a byte buffer, at a given position.
        /// The position is incremented to the end of the representation.
        /// </summary>
        /// <param name="buffer">a byte array where the binary representation is written</param>
        /// <param name="position">the position in the byte array</param>
        public virtual void ToByteArray(byte[] buffer, ref int position)
        {
            // Header: layout first, then bit width; New(byte[], ref int) reads them in this order.
            ((int)Type).ToByteArray(buffer, ref position);
            ((int)BitsPerItem).ToByteArray(buffer, ref position);
        }

        /// <summary>
        /// The number of bits used to store each item (second header integer of the binary form).
        /// </summary>
        public abstract IntArrayBits BitsPerItem { get; }

        /// <summary>
        /// The storage layout of this array (first header integer of the binary form).
        /// </summary>
        public abstract IntArrayType Type { get; }

        /// <summary>
        /// Returns the smallest supported number of bits per item that can represent
        /// <paramref name="numValues"/> distinct values, i.e. every value in [0, numValues).
        /// The 10-bit width is only returned when bit 0x1 of <see cref="CompatibilityLevel"/> is set.
        /// </summary>
        /// <param name="numValues">Number of distinct values to represent; must be non-negative.</param>
        public static IntArrayBits NumBitsNeeded(int numValues)
        {
            Contracts.CheckParam(numValues >= 0, nameof(numValues));
            if (numValues <= (1 << 0))
                return IntArrayBits.Bits0;
            else if (numValues <= (1 << 1))
                return IntArrayBits.Bits1;
            else if (numValues <= (1 << 4))
                return IntArrayBits.Bits4;
            else if (numValues <= (1 << 8))
                return IntArrayBits.Bits8;
            else if ((CompatibilityLevel & 1) != 0 && numValues <= (1 << 10))
                return IntArrayBits.Bits10;
            else if (numValues <= (1 << 16))
                return IntArrayBits.Bits16;
            else
                return IntArrayBits.Bits32;
        }

        /// <summary>
        /// Creates an int array of the given layout and bit width, filled from <paramref name="values"/>.
        /// </summary>
        /// <param name="length">The virtual length of the array; must be non-negative.</param>
        /// <param name="type">The layout to create; must be defined and not <see cref="IntArrayType.Current"/>.</param>
        /// <param name="bitsPerItem">The width per item. <see cref="IntArrayBits.Bits0"/> always yields a
        /// zero-bit dense array; <see cref="IntArrayBits.Bits1"/> is stored with 4 bits in the dense layout.</param>
        /// <param name="values">The values to store.</param>
        public static IntArray New(int length, IntArrayType type, IntArrayBits bitsPerItem, IEnumerable<int> values)
        {
            Contracts.CheckParam(length >= 0, nameof(length));
            Contracts.CheckParam(Enum.IsDefined(typeof(IntArrayType), type) && type != IntArrayType.Current, nameof(type));
            Contracts.CheckParam(Enum.IsDefined(typeof(IntArrayBits), bitsPerItem), nameof(bitsPerItem));
            Contracts.CheckValue(values, nameof(values));

            if (type == IntArrayType.Dense || bitsPerItem == IntArrayBits.Bits0)
            {
                if (bitsPerItem == IntArrayBits.Bits0)
                {
                    Contracts.Assert(values.All(x => x == 0));
                    return new Dense0BitIntArray(length);
                }
                // There is no 1-bit dense implementation, so Bits1 falls into the 4-bit case.
                else if (bitsPerItem <= IntArrayBits.Bits4)
                    return new Dense4BitIntArray(length, values);
                else if (bitsPerItem <= IntArrayBits.Bits8)
                    return new Dense8BitIntArray(length, values);
                else if (bitsPerItem <= IntArrayBits.Bits10)
                    return new Dense10BitIntArray(length, values);
                else if (bitsPerItem <= IntArrayBits.Bits16)
                    return new Dense16BitIntArray(length, values);
                else
                    return new Dense32BitIntArray(length, values);
            }
            else if (type == IntArrayType.Sparse)
                return new DeltaSparseIntArray(length, bitsPerItem, values);
            else if (type == IntArrayType.Repeat)
                return new DeltaRepeatIntArray(length, bitsPerItem, values);
            else if (type == IntArrayType.Segmented)
                // Segmented should probably not be used in this way.
                return new SegmentIntArray(length, values);
            return null;
        }

        /// <summary>
        /// Creates an empty int array of the given length, layout and bit width.
        /// </summary>
        /// <param name="length">The virtual length of the array; must be non-negative.</param>
        /// <param name="type">The layout to create. Only <see cref="IntArrayType.Dense"/> and
        /// <see cref="IntArrayType.Sparse"/> are constructed (any layout yields a zero-bit dense array
        /// when <paramref name="bitsPerItem"/> is <see cref="IntArrayBits.Bits0"/>).</param>
        /// <param name="bitsPerItem">The width per item; must be a defined value.</param>
        /// <returns>The new array, or <c>null</c> when <paramref name="type"/> is a layout this overload
        /// cannot create empty (Repeat, Segmented, Current) and the width is not zero.</returns>
        public static IntArray New(int length, IntArrayType type, IntArrayBits bitsPerItem)
        {
            Contracts.CheckParam(length >= 0, nameof(length));
            // Dense and Sparse are the layouts built below. The other defined values remain accepted
            // so existing callers keep their behavior (zero-bit dense array for Bits0, otherwise null).
            Contracts.CheckParam(Enum.IsDefined(typeof(IntArrayType), type), nameof(type));
            Contracts.CheckParam(Enum.IsDefined(typeof(IntArrayBits), bitsPerItem), nameof(bitsPerItem));

            if (type == IntArrayType.Dense || bitsPerItem == IntArrayBits.Bits0)
            {
                if (bitsPerItem == IntArrayBits.Bits0)
                    return new Dense0BitIntArray(length);
                // There is no 1-bit dense implementation, so Bits1 falls into the 4-bit case.
                else if (bitsPerItem <= IntArrayBits.Bits4)
                    return new Dense4BitIntArray(length);
                else if (bitsPerItem <= IntArrayBits.Bits8)
                    return new Dense8BitIntArray(length);
                else if (bitsPerItem <= IntArrayBits.Bits10)
                    return new Dense10BitIntArray(length);
                else if (bitsPerItem <= IntArrayBits.Bits16)
                    return new Dense16BitIntArray(length);
                else
                    return new Dense32BitIntArray(length);
            }
            else if (type == IntArrayType.Sparse)
                return new DeltaSparseIntArray(length, bitsPerItem);
            // Repeat, Segmented and Current have no length-only construction here.
            return null;
        }

        /// <summary>
        /// Creates a new int array given a byte representation
        /// </summary>
        /// <param name="buffer">the byte array representation of the array. The buffer can be larger than needed since the caller might be re-using buffers from a pool</param>
        /// <param name="position">the position in the byte array</param>
        /// <returns>the int array object, or <c>null</c> if the stored layout is not recognized.
        /// A dense layout with an unrecognized bit width is read as a 32-bit dense array.</returns>
        public static IntArray New(byte[] buffer, ref int position)
        {
            IntArrayType type = (IntArrayType)buffer.ToInt(ref position);
            IntArrayBits bitsPerItem = (IntArrayBits)buffer.ToInt(ref position);

            if (type == IntArrayType.Dense)
            {
                if (bitsPerItem == IntArrayBits.Bits0)
                    return new Dense0BitIntArray(buffer, ref position);
                else if (bitsPerItem == IntArrayBits.Bits4)
                    return new Dense4BitIntArray(buffer, ref position);
                else if (bitsPerItem == IntArrayBits.Bits8)
                    return new Dense8BitIntArray(buffer, ref position);
                else if (bitsPerItem == IntArrayBits.Bits10)
                    return new Dense10BitIntArray(buffer, ref position);
                else if (bitsPerItem == IntArrayBits.Bits16)
                    return new Dense16BitIntArray(buffer, ref position);
                else
                    return new Dense32BitIntArray(buffer, ref position);
            }
            else if (type == IntArrayType.Sparse)
                return new DeltaSparseIntArray(buffer, ref position);
            else if (type == IntArrayType.Repeat)
                return new DeltaRepeatIntArray(buffer, ref position);
            else if (type == IntArrayType.Segmented)
                return new SegmentIntArray(buffer, ref position);
            return null;
        }

        /// <summary>
        /// Clones the contents of this IntArray into a new IntArray
        /// </summary>
        /// <param name="bitsPerItem">The number of bits per item in the created IntArray</param>
        /// <param name="type">The type of the new IntArray</param>
        public abstract IntArray Clone(IntArrayBits bitsPerItem, IntArrayType type);

        /// <summary>
        /// Clone an IntArray containing only the items indexed by <paramref name="itemIndices"/>
        /// </summary>
        /// <param name="itemIndices"> item indices will be contained in the cloned IntArray  </param>
        /// <returns> The cloned IntArray </returns>
        public abstract IntArray Clone(int[] itemIndices);

        /// <summary>
        /// Splits this array into several arrays.
        /// NOTE(review): presumably <c>assignment[k]</c> lists the item indices that go into output k — confirm against implementations.
        /// </summary>
        public abstract IntArray[] Split(int[][] assignment);

        /// <summary>
        /// Gets an indexer into the array
        /// </summary>
        /// <returns>An indexer into the array; indices must be requested in non-decreasing order.</returns>
        public abstract IIntArrayForwardIndexer GetIndexer();

        // Used in the child classes so we can set either the native or managed Sumup method one time and then
        // never have to check again.
        protected delegate void PerformSumup(SumupInputData input, FeatureHistogram histogram);

        // Handler so the child classes don't have to redefine it. If they don't have different logic for native vs managed
        // code then they don't need to use this.
        protected PerformSumup SumupHandler { get; set; }

        // Helper to setup the SumupHandler for the derived classes that need it.
        protected void SetupSumupHandler(PerformSumup native, PerformSumup managed) => SumupHandler = UseFastTreeNative ? native : managed;

        /// <summary>
        /// Accumulates <paramref name="input"/> into <paramref name="histogram"/>. For each of the
        /// <c>input.TotalCount</c> items, the item's bin is read from this array (through
        /// <c>input.DocIndices</c> when it is non-null). The item's output is added to that bin's
        /// target sum and the bin count is incremented. When the histogram tracks weights, the
        /// item's weight is also added to that bin's weight sum.
        /// Throws (via <c>Contracts.Except</c>) when a bin falls outside the histogram.
        /// </summary>
        public virtual void Sumup(SumupInputData input, FeatureHistogram histogram)
        {
            Contracts.Assert((input.Weights == null) == (histogram.SumWeightsByBin == null));
            if (histogram.SumWeightsByBin != null)
            {
                SumupWeighted(input, histogram);
                return;
            }
            IIntArrayForwardIndexer indexer = GetIndexer();
            for (int i = 0; i < input.TotalCount; i++)
            {
                int featureBin = input.DocIndices == null ? indexer[i] : indexer[input.DocIndices[i]];
                if (featureBin < 0
                    || featureBin >= histogram.SumTargetsByBin.Length
                    || featureBin >= histogram.NumFeatureValues)
                {
                    throw Contracts.Except("Feature bin {0} is invalid", featureBin);
                }

                histogram.SumTargetsByBin[featureBin] += input.Outputs[i];
                ++histogram.CountByBin[featureBin];
            }
        }

        // Weighted counterpart of Sumup: additionally accumulates input.Weights into SumWeightsByBin.
        private void SumupWeighted(SumupInputData input, FeatureHistogram histogram)
        {
            Contracts.AssertValue(histogram.SumWeightsByBin);
            Contracts.AssertValue(input.Weights);
            IIntArrayForwardIndexer indexer = GetIndexer();
            for (int i = 0; i < input.TotalCount; i++)
            {
                int featureBin = input.DocIndices == null ? indexer[i] : indexer[input.DocIndices[i]];
                if (featureBin < 0
                    || featureBin >= histogram.SumTargetsByBin.Length
                    || featureBin >= histogram.NumFeatureValues)
                {
                    throw Contracts.Except("Feature bin {0} is invalid", featureBin);
                }

                histogram.SumTargetsByBin[featureBin] += input.Outputs[i];
                histogram.SumWeightsByBin[featureBin] += input.Weights[i];
                ++histogram.CountByBin[featureBin];
            }
        }

        public abstract IEnumerator<int> GetEnumerator();

        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>
        /// XOR of the hash codes of all stored values. The result is order-insensitive, and equal
        /// values occurring an even number of times cancel out.
        /// NOTE(review): Equals is not overridden in this class, so this hash is not paired with value equality here.
        /// </summary>
        public override int GetHashCode()
        {
            int hash = 0;
            foreach (int i in this)
                hash ^= i.GetHashCode();
            return hash;
        }

        /// <summary>
        /// Finds the most space efficient representation of the feature
        /// (with slight slack cut for dense features). The behavior of
        /// this method depends upon the static value <see cref="CompatibilityLevel"/>.
        /// </summary>
        /// <param name="workarray">Should be non-null if you want it to
        /// consider segment arrays. When given, it must hold at least <see cref="Length"/>
        /// elements; it receives a copy of the values (as <c>uint</c>).</param>
        /// <returns>Returns a more space efficient version of the array,
        /// or the item itself if that is impossible, somehow.</returns>
        public IntArray Compress(uint[] workarray = null)
        {
            int len = Length;
            Contracts.CheckParam(workarray == null || workarray.Length >= len, nameof(workarray));

            int maxval = 0;
            int zerocount = 0;
            int runs = 0;        // Number of items equal to their predecessor.
            int last = -1;
            int overflows = 0;   // Runs that exceed what a one-byte repeat count can cover.
            int zoverflows = 0;  // The subset of those overflows that are runs of zeros.
            int runnow = 0;      // Length of the current run of equal values (reset when the value changes).
            IIntArrayForwardIndexer ind = GetIndexer();
            for (int i = 0; i < len; ++i)
            {
                int val = ind[i];
                if (workarray != null)
                    workarray[i] = (uint)val;
                if (val == 0)
                    zerocount++;
                else if (val > maxval)
                    maxval = val;
                if (last == val)
                {
                    runs++;
                    if (++runnow > byte.MaxValue)
                    {
                        // We have 256 items in a row the same.
                        overflows++;
                        if (val == 0)
                            zoverflows++;
                        runnow = 0;
                    }
                }
                else
                {
                    // A new run starts here; repeats of earlier runs must not count toward its overflow.
                    runnow = 0;
                }
                last = val;
            }
            // Estimate the costs of the available options.
            IntArrayBits classicBits = IntArray.NumBitsNeeded(maxval + 1);
            long denseBits = (long)classicBits * (long)len;
            long sparseBits = (long)(Math.Max((int)classicBits, 8) + 8) * (long)(len - zerocount + zoverflows);
            long rleBits = (long)(classicBits + 8) * (long)(len - runs + overflows);
            long segBits = long.MaxValue;
            int segTransitions = 0;
            if (workarray != null)
            {
                int bits = SegmentIntArray.BitsForValue((uint)maxval);
                if (bits <= 21)
                {
                    SegmentIntArray.SegmentFindOptimalPath.Value(workarray, len,
                        bits, out segBits, out segTransitions);
                }
            }
            if ((IntArray.CompatibilityLevel & 0x4) == 0)
            {
                // Run-length encoding is only considered when compatibility flag 0x4 is set.
                rleBits = long.MaxValue;
            }
            long bestCost = Math.Min(Math.Min(Math.Min(denseBits, sparseBits), rleBits), segBits);
            IntArrayType bestType = IntArrayType.Dense;
            if (bestCost >= denseBits * 98 / 100)
            {
                // Cut the dense bits a wee bit of slack.
            }
            else if (bestCost == sparseBits)
            {
                bestType = IntArrayType.Sparse;
            }
            else if (bestCost == rleBits)
            {
                bestType = IntArrayType.Repeat;
            }
            else
            {
                bestType = IntArrayType.Segmented;
            }
            if (bestType == Type && classicBits == BitsPerItem)
            {
                return this;
            }
            IntArray bins = null;
            if (bestType != IntArrayType.Segmented)
            {
                bins = IntArray.New(len, bestType, classicBits, this);
            }
            else
            {
                bins = SegmentIntArray.FromWorkArray(workarray, len, segBits, segTransitions);
            }
            return bins;
        }
    }
 
    /// <summary>
    /// Read-only accessor over the values of an <see cref="IntArray"/> that only supports
    /// moving forward: successive requests must use non-decreasing indices.
    /// </summary>
    internal interface IIntArrayForwardIndexer
    {
        /// <summary>
        /// Reads the value stored at position <paramref name="index"/>.
        /// </summary>
        /// <param name="index">Position to read; must not be smaller than any index previously requested from this indexer.</param>
        /// <returns>The integer stored at <paramref name="index"/>.</returns>
        int this[int index]
        {
            get;
        }
    }
}