File: Dataset\DenseIntArray.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
    using FloatType = System.Single;
#else
    using FloatType = System.Double;
#endif
 
    /// <summary>
    /// Base class for the dense <see cref="IntArray"/> implementations, which store one item per index.
    /// Provides the shared cloning, splitting, enumeration and native histogram sum-up helpers.
    /// </summary>
    internal abstract class DenseIntArray : IntArray, IIntArrayForwardIndexer
    {
        // Native library exporting the C_Sumup_float / C_Sumup_double entry points declared below.
        internal const string NativePath = "FastTreeNative";

        protected DenseIntArray(int length, PerformSumup sumupHandler = null)
        {
            Contracts.Assert(length >= 0);
            Length = length;
            SumupHandler = sumupHandler;
        }

        public override IntArrayType Type => IntArrayType.Dense;

        public override int Length { get; }

        /// <summary>
        /// Gets or sets the value at this index.
        /// Value must be in legal range 0...((2^<see cref="IntArray.BitsPerItem"/>)-1).
        /// </summary>
        /// <param name="index">Index of value to get or set</param>
        /// <returns>The value at this index</returns>
        public abstract int this[int index] { get; set; }

        /// <summary>
        /// Copies this array into a new array with the requested item width and layout.
        /// A requested layout of <c>IntArrayType.Current</c> is treated as dense.
        /// </summary>
        public override IntArray Clone(IntArrayBits bitsPerItem, IntArrayType type)
        {
            IntArrayType targetType = type == IntArrayType.Current ? IntArrayType.Dense : type;
            return New(Length, targetType, bitsPerItem, this);
        }

        /// <summary>
        /// Builds a dense array holding, in order, the items found at <paramref name="itemIndices"/>.
        /// </summary>
        /// <param name="itemIndices">Positions in this array whose items are copied into the result.</param>
        /// <returns>A new dense array with the same item width as this one.</returns>
        public override IntArray Clone(int[] itemIndices)
            => IntArray.New(itemIndices.Length, IntArrayType.Dense, BitsPerItem, itemIndices.Select(idx => this[idx]));

        /// <summary>
        /// Produces one dense array per entry of <paramref name="assignment"/>.
        /// Each result holds the items found at that entry's positions.
        /// </summary>
        public override IntArray[] Split(int[][] assignment)
        {
            var pieces = new IntArray[assignment.Length];
            for (int piece = 0; piece < pieces.Length; piece++)
            {
                int[] positions = assignment[piece];
                pieces[piece] = IntArray.New(positions.Length, IntArrayType.Dense, BitsPerItem, positions.Select(idx => this[idx]));
            }
            return pieces;
        }

        [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
        private static extern unsafe int C_Sumup_float(
            int numBits, byte* pData, int* pIndices, float* pSampleOutputs, double* pSampleOutputWeights,
            FloatType* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin,
            int totalCount, double totalSampleOutputs, double totalSampleOutputWeights);

        [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
        private static extern unsafe int C_Sumup_double(
            int numBits, byte* pData, int* pIndices, double* pSampleOutputs, double* pSampleOutputWeights,
            FloatType* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin,
            int totalCount, double totalSampleOutputs, double totalSampleOutputWeights);

        /// <summary>
        /// Pins the input and histogram buffers and hands them, together with the packed item data,
        /// to the native sum-up routine that matches the compiled <c>FloatType</c>.
        /// </summary>
        /// <param name="data">Pointer to the raw packed item storage of the calling array.</param>
        /// <param name="numBits">Bits per stored item (passed through to the native routine).</param>
        /// <exception>Throws the exception built by <c>Contracts.Except</c> when the native routine returns a negative code.</exception>
        protected static unsafe void SumupCPlusPlusDense(SumupInputData input, FeatureHistogram histogram,
            byte* data, int numBits)
        {
            using (Timer.Time(TimerEvent.SumupCppDense))
            {
                fixed (FloatType* pSumTargetsByBin = histogram.SumTargetsByBin)
                fixed (FloatType* pSampleOutputs = input.Outputs)
                fixed (double* pSumWeightsByBin = histogram.SumWeightsByBin)
                fixed (double* pSampleWeights = input.Weights)
                fixed (int* pIndices = input.DocIndices)
                fixed (int* pCountByBin = histogram.CountByBin)
                {
#if USE_SINGLE_PRECISION
                    int status = C_Sumup_float(numBits, data, pIndices, pSampleOutputs, pSampleWeights,
                        pSumTargetsByBin, pSumWeightsByBin, pCountByBin,
                        input.TotalCount, input.SumTargets, input.SumWeights);
#else
                    int status = C_Sumup_double(numBits, data, pIndices, pSampleOutputs, pSampleWeights,
                        pSumTargetsByBin, pSumWeightsByBin, pCountByBin,
                        input.TotalCount, input.SumTargets, input.SumWeights);
#endif
                    if (status < 0)
                        throw Contracts.Except("CSumup returned error {0}", status);
                }
            }
        }

        /// <summary>
        /// A dense array is its own forward indexer.
        /// </summary>
        public override IIntArrayForwardIndexer GetIndexer() => this;

        #region IEnumerable<int> Members

        /// <summary>
        /// Lazily yields every item from position 0 up to <see cref="Length"/> - 1.
        /// </summary>
        public override IEnumerator<int> GetEnumerator()
        {
            for (int position = 0; position < Length; position++)
                yield return this[position];
        }

        #endregion
    }
 
    /// <summary>
    /// A dense int array that can expose its raw backing storage to a caller-supplied callback.
    /// </summary>
    internal abstract class DenseDataCallbackIntArray : DenseIntArray
    {
        protected DenseDataCallbackIntArray(int length) : base(length) { }

        /// <summary>
        /// Invokes <paramref name="callback"/> with a pointer to this array's underlying data.
        /// NOTE(review): the implementations in this file pin the data with <c>fixed</c> only for the
        /// duration of the call, so the pointer should not be retained after the callback returns.
        /// </summary>
        public abstract void Callback(Action<IntPtr> callback);
    }
 
    /// <summary>
    /// A degenerate "null" feature: every item reads as zero, so only the length is stored.
    /// </summary>
    internal sealed class Dense0BitIntArray : DenseIntArray
    {
        public Dense0BitIntArray(int length)
            : base(length)
        {
        }

        public Dense0BitIntArray(byte[] buffer, ref int position)
            : base(buffer.ToInt(ref position))
        {
        }

        public override IntArrayBits BitsPerItem => IntArrayBits.Bits0;

        /// <summary>
        /// Number of bytes that <c>ToByteArray</c> writes: the base header plus the stored length.
        /// </summary>
        public override int SizeInBytes() => sizeof(int) + base.SizeInBytes();

        /// <summary>
        /// Serializes this array into <paramref name="buffer"/> at <paramref name="position"/>.
        /// It writes the base representation followed by <see cref="DenseIntArray.Length"/>.
        /// The position is advanced past the written bytes.
        /// </summary>
        /// <param name="buffer">Destination byte array.</param>
        /// <param name="position">Write position; advanced past the written bytes.</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            base.ToByteArray(buffer, ref position);
            Length.ToByteArray(buffer, ref position);
        }

        /// <summary>
        /// Always reads 0; writes are accepted only for the value 0 (checked by debug asserts).
        /// </summary>
        public override int this[int index]
        {
            get
            {
                Contracts.Assert(index >= 0 && index < Length);
                return 0;
            }
            set
            {
                Contracts.Assert(index >= 0 && index < Length);
                Contracts.Assert(value == 0);
            }
        }

        /// <summary>
        /// Since every item is 0, bin 0 receives the full target sum, weight sum and document count.
        /// </summary>
        public override void Sumup(SumupInputData input, FeatureHistogram histogram)
        {
            histogram.SumTargetsByBin[0] = input.SumTargets;
            if (histogram.SumWeightsByBin != null)
                histogram.SumWeightsByBin[0] = input.SumWeights;
            histogram.CountByBin[0] = input.TotalCount;
        }
    }
 
    /// <summary>
    /// A class to represent features using 10 bits per item. Items are bit-packed back to back into
    /// a <see cref="uint"/> buffer, item i occupying bit positions [10*i, 10*i + 10).
    /// </summary>
    internal sealed class Dense10BitIntArray : DenseIntArray
    {
        private const int _bits = 10;
        private const int _mask = (1 << _bits) - 1;

        // Packed item bits. The length-based constructor allocates floor(10*len/32) + 2 words.
        // A 64-bit read starting at the word that holds any item's first bit therefore stays inside the array.
        private readonly uint[] _data;

        public override IntArrayBits BitsPerItem { get { return IntArrayBits.Bits10; } }

        public Dense10BitIntArray(int len)
            : base(len)
        {
            _data = new uint[((((long)len) * _bits) >> 5) + 2];
        }

        public Dense10BitIntArray(byte[] buffer, ref int position)
            : base(buffer.ToInt(ref position))
        {
            // NOTE(review): assumes the buffer was written by this class's ToByteArray.
            // If so, the array keeps the same padded size as the length-based constructor produces.
            _data = buffer.ToUIntArray(ref position);
        }

        /// <summary>
        /// Creates an array of <paramref name="len"/> items initialized from <paramref name="values"/>.
        /// At most <paramref name="len"/> values are stored; any further values are ignored.
        /// </summary>
        /// <param name="len">Number of items in the array.</param>
        /// <param name="values">Item values; each is masked to its low 10 bits when stored.</param>
        public Dense10BitIntArray(int len, IEnumerable<int> values)
            : this(len)
        {
            int count = 0;
            long offset = 0;
            foreach (int val in values)
            {
                // Stop after len items. A check of the form (i++ > len) would also admit item number len.
                // That item would be written into the padding words past the logical end of the array.
                if (count++ >= len)
                    break;
                Set(offset, _mask, val);
                offset += _bits;
            }
        }

        /// <summary>
        /// Reads the bits selected by <paramref name="mask"/> starting at bit <paramref name="offset"/>.
        /// The two words at and after the starting word are combined, so a value may straddle a word boundary.
        /// </summary>
        private int Get(long offset, uint mask)
        {
            int minor = (int)(offset & 0x1f);
            int major = (int)(offset >> 5);
            return (int)((uint)((_data[major] | (((ulong)_data[major + 1]) << 32)) >> minor) & mask);
        }

        /// <summary>
        /// Writes the low bits of <paramref name="value"/> selected by <paramref name="mask"/> at bit
        /// <paramref name="offset"/>. All other bits of the two words touched are preserved.
        /// </summary>
        private void Set(long offset, uint mask, int value)
        {
            int minor = (int)(offset & 0x1f);
            int major = (int)(offset >> 5);

            // Portions of the mask that fall in the first and the second word respectively.
            uint major0Mask = mask << minor;
            uint major1Mask = (uint)((((ulong)mask) << minor) >> 32);

            ulong val = ((((ulong)value) & mask) << minor);
            _data[major] = (_data[major] & ~major0Mask) | (uint)val;
            _data[major + 1] = (_data[major + 1] & ~major1Mask) | (uint)(val >> 32);
        }

        /// <summary>
        /// Returns the number of bytes written by the member ToByteArray()
        /// </summary>
        public override int SizeInBytes()
        {
            return _data.SizeInBytes() + sizeof(int) + base.SizeInBytes();
        }

        /// <summary>
        /// Writes a binary representation of this class to a byte buffer, at a given position.
        /// The position is incremented to the end of the representation
        /// </summary>
        /// <param name="buffer">a byte array where the binary representation is written</param>
        /// <param name="position">the position in the byte array</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            base.ToByteArray(buffer, ref position);
            Length.ToByteArray(buffer, ref position);
            _data.ToByteArray(buffer, ref position);
        }

        public sealed override unsafe int this[int index]
        {
            get
            {
                // The read below is an unchecked pointer load, so guard the index in debug builds.
                Contracts.Assert(0 <= index && index < Length);
                // offset = index * 10, the bit position of the item, computed with shifts.
                long offset = index;
                offset = (offset << 3) + (offset << 1);
                int minor = (int)(offset & 0x1f);
                int major = (int)(offset >> 5);
                // A 64-bit load of words major and major+1; it relies on the padding word at the end of _data.
                // NOTE(review): this equals Get's explicit word composition only on little-endian platforms.
                fixed (uint* pData = _data)
                    return (int)(((*(ulong*)(pData + major)) >> minor) & _mask);
            }

            set
            {
                Contracts.Assert(0 <= index && index < Length);
                Contracts.Assert(0 <= value && value < (1 << 10));
                Set(((long)index) * 10, _mask, value);
            }
        }

        /// <summary>
        /// Accumulates every item in order, with no document subset.
        /// Position i uses <paramref name="outputs"/>[i] and, when weights are tracked, <paramref name="weights"/>[i].
        /// </summary>
        private void SumupRoot(FeatureHistogram histogram, FloatType[] outputs, double[] weights)
        {
            int fval;
            long offset = 0;
            for (int i = 0; i < Length; ++i)
            {
                fval = Get(offset, _mask);
                histogram.SumTargetsByBin[fval] += outputs[i];
                if (histogram.SumWeightsByBin != null)
                    histogram.SumWeightsByBin[fval] += weights[i];
                ++histogram.CountByBin[fval];
                offset += _bits;
            }
        }

        /// <summary>
        /// Adds each selected document's output (and weight, when the histogram tracks weights) into
        /// the bin given by that document's 10-bit value, and increments that bin's count.
        /// Outputs and weights are indexed by position within <c>input.DocIndices</c>, not by document id.
        /// </summary>
        public override unsafe void Sumup(SumupInputData input, FeatureHistogram histogram)
        {
            using (Timer.Time(TimerEvent.SumupDense10))
            {
                if (input.DocIndices == null)
                {
                    SumupRoot(histogram, input.Outputs, input.Weights);
                    return;
                }

                int fval = 0;
                fixed (uint* pData = _data)
                fixed (int* pCountByBin = histogram.CountByBin)
                fixed (int* pDocIndicies = input.DocIndices)
                fixed (FloatType* pSumTargetsByBin = histogram.SumTargetsByBin)
                fixed (FloatType* pTargets = input.Outputs)
                {
                    if (histogram.SumWeightsByBin != null)
                    {
                        fixed (double* pSumWeightsByBin = histogram.SumWeightsByBin)
                        fixed (double* pWeights = input.Weights)
                        {
                            for (int ii = 0; ii < input.TotalCount; ++ii)
                            {
                                // Same unchecked 10-bit extraction as the indexer getter.
                                long offset = pDocIndicies[ii];
                                offset = (offset << 3) + (offset << 1);
                                int minor = (int)(offset & 0x1f);
                                int major = (int)(offset >> 5);
                                fval = (int)(((*(ulong*)(pData + major)) >> minor) & _mask);
                                pSumTargetsByBin[fval] += pTargets[ii];
                                pSumWeightsByBin[fval] += pWeights[ii];
                                ++pCountByBin[fval];
                            }
                        }
                    }
                    else
                    {
                        int end = input.TotalCount;
                        for (int ii = 0; ii < end; ++ii)
                        {
                            long offset = pDocIndicies[ii];
                            offset = (offset << 3) + (offset << 1);
                            int minor = (int)(offset & 0x1f);
                            int major = (int)(offset >> 5);
                            fval = (int)(((*(ulong*)(pData + major)) >> minor) & _mask);
                            pSumTargetsByBin[fval] += pTargets[ii];
                            ++pCountByBin[fval];
                        }
                    }
                }
            }
        }
    }
 
    /// <summary>
    /// A class to represent features using 8 bits per item, stored one item per byte.
    /// </summary>
    /// <remarks>Stored values range over 0...255 (see the indexer setter's assertion).</remarks>
    internal sealed class Dense8BitIntArray : DenseDataCallbackIntArray
    {
        private readonly byte[] _data;

        public Dense8BitIntArray(int len)
            : base(len)
        {
            _data = new byte[len];
            RegisterSumupHandlers();
        }

        public Dense8BitIntArray(byte[] buffer, ref int position)
            : base(buffer.ToInt(ref position))
        {
            _data = buffer.ToByteArray(ref position);
            RegisterSumupHandlers();
        }

        public Dense8BitIntArray(int len, IEnumerable<int> values)
            : base(len)
        {
            _data = values.Select(v => (byte)v).ToArray(len);
            RegisterSumupHandlers();
        }

        public override IntArrayBits BitsPerItem => IntArrayBits.Bits8;

        /// <summary>
        /// Passes the native sum-up and the inherited managed <c>Sumup</c> to <c>SetupSumupHandler</c>.
        /// NOTE(review): how the two are chosen between is decided by the base class.
        /// </summary>
        private void RegisterSumupHandlers() => SetupSumupHandler(SumupNative, base.Sumup);

        /// <summary>
        /// Number of bytes that <c>ToByteArray</c> writes: item bytes, the stored length and the base header.
        /// </summary>
        public override int SizeInBytes() => _data.SizeInBytes() + sizeof(int) + base.SizeInBytes();

        /// <summary>
        /// Serializes this array into <paramref name="buffer"/> at <paramref name="position"/>.
        /// The layout is the base representation, then <see cref="DenseIntArray.Length"/>, then the item bytes.
        /// </summary>
        /// <param name="buffer">Destination byte array.</param>
        /// <param name="position">Write position; advanced past the written bytes.</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            base.ToByteArray(buffer, ref position);
            Length.ToByteArray(buffer, ref position);
            _data.ToByteArray(buffer, ref position);
        }

        /// <summary>
        /// Pins the item bytes and passes their address to <paramref name="callback"/> for the duration of the call.
        /// </summary>
        public override unsafe void Callback(Action<IntPtr> callback)
        {
            fixed (byte* pinned = _data)
                callback((IntPtr)pinned);
        }

        public override int this[int index]
        {
            get { return _data[index]; }
            set
            {
                Contracts.Assert(0 <= value && value <= byte.MaxValue);
                _data[index] = (byte)value;
            }
        }

        // Forwards the pinned item bytes to the shared native sum-up, declaring 8 bits per item.
        private unsafe void SumupNative(SumupInputData input, FeatureHistogram histogram)
        {
            fixed (byte* pinned = _data)
                SumupCPlusPlusDense(input, histogram, pinned, 8);
        }

        public override void Sumup(SumupInputData input, FeatureHistogram histogram) => SumupHandler(input, histogram);
    }
 
    /// <summary>
    /// A class to represent features using 4 bits per item, packed two items per byte.
    /// </summary>
    internal sealed class Dense4BitIntArray : DenseIntArray
    {
        /// <summary>
        /// For a given byte, the high 4 bits is the first value, the low 4 bits is the next value.
        /// </summary>
        private readonly byte[] _data;

        public override IntArrayBits BitsPerItem { get { return IntArrayBits.Bits4; } }

        public Dense4BitIntArray(int len)
            : base(len)
        {
            _data = new byte[(len + 1) / 2]; // Two items per byte, rounding up for odd len.
            SetupSumupHandler(SumupNative, base.Sumup);
        }

        /// <summary>
        /// Creates an array of <paramref name="len"/> items initialized from <paramref name="values"/>.
        /// At most <paramref name="len"/> values are stored; any further values are ignored.
        /// </summary>
        /// <param name="len">Number of items in the array.</param>
        /// <param name="values">Item values; only the low 4 bits of each are kept.</param>
        public Dense4BitIntArray(int len, IEnumerable<int> values)
            : base(len)
        {
            _data = new byte[(len + 1) / 2];
            SetupSumupHandler(SumupNative, base.Sumup);

            int currentIndex = 0;
            bool upper = true;
            int consumed = 0;
            foreach (int value in values)
            {
                // Ignore values past len. Without this cap, an extra value would index past _data when len is even.
                // When len is odd it would land in the unused low nibble of the last byte.
                if (consumed++ >= len)
                    break;

                byte b = (byte)value;
                if (upper)
                {
                    // Fresh byte: assigning (not OR-ing) also clears the low nibble.
                    _data[currentIndex] = (byte)(b << 4);
                    upper = false;
                }
                else
                {
                    _data[currentIndex] |= (byte)(b & 0x0f);
                    currentIndex++;
                    upper = true;
                }
            }
        }

        public Dense4BitIntArray(byte[] buffer, ref int position)
            : base(buffer.ToInt(ref position))
        {
            _data = buffer.ToByteArray(ref position);
            SetupSumupHandler(SumupNative, base.Sumup);
        }

        /// <summary>
        /// Returns the number of bytes written by the member ToByteArray()
        /// </summary>
        public override int SizeInBytes()
        {
            return _data.SizeInBytes() + sizeof(int) + base.SizeInBytes();
        }

        /// <summary>
        /// Writes a binary representation of this class to a byte buffer, at a given position.
        /// The position is incremented to the end of the representation
        /// </summary>
        /// <param name="buffer">a byte array where the binary representation is written</param>
        /// <param name="position">the position in the byte array</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            base.ToByteArray(buffer, ref position);
            Length.ToByteArray(buffer, ref position);
            _data.ToByteArray(buffer, ref position);
        }

        public override int this[int index]
        {
            get
            {
                // When Length is odd, index == Length would silently read the padding nibble; catch that in debug.
                Contracts.Assert(0 <= index && index < Length);
                int dataIndex = index / 2;
                bool highBits = (index % 2 == 0);

                byte v = _data[dataIndex];
                if (highBits)
                    v >>= 4;
                else
                    v &= 0x0f;

                return v;
            }

            set
            {
                Contracts.Assert(0 <= index && index < Length);
                Contracts.Assert(0 <= value && value < (1 << 4));
                byte v;
                v = (byte)value;

                int dataIndex = index / 2;
                bool highBits = (index % 2 == 0);
                if (highBits)
                {
                    _data[dataIndex] &= 0x0f;
                    _data[dataIndex] |= (byte)(v << 4);
                }
                else
                {
                    _data[dataIndex] &= 0xf0;
                    _data[dataIndex] |= v;
                }
            }
        }

        /// <summary>
        /// Pins the packed nibbles and forwards them to the shared native sum-up with 4 bits per item.
        /// </summary>
        public void SumupNative(SumupInputData input, FeatureHistogram histogram)
        {
            unsafe
            {
                fixed (byte* pData = _data)
                {
                    SumupCPlusPlusDense(input, histogram, pData, 4);
                }
            }
        }

        public override void Sumup(SumupInputData input, FeatureHistogram histogram) => SumupHandler(input, histogram);
    }
 
    /// <summary>
    /// A class to represent features using 16 bits per item, stored as one <see cref="ushort"/> per item.
    /// </summary>
    internal sealed class Dense16BitIntArray : DenseDataCallbackIntArray
    {
        private readonly ushort[] _data;

        public Dense16BitIntArray(int len)
            : base(len)
        {
            _data = new ushort[len];
            RegisterSumupHandlers();
        }

        public Dense16BitIntArray(int len, IEnumerable<int> values)
            : base(len)
        {
            _data = values.Select(v => (ushort)v).ToArray(len);
            RegisterSumupHandlers();
        }

        public Dense16BitIntArray(byte[] buffer, ref int position)
            : base(buffer.ToInt(ref position))
        {
            _data = buffer.ToUShortArray(ref position);
            RegisterSumupHandlers();
        }

        public override IntArrayBits BitsPerItem => IntArrayBits.Bits16;

        /// <summary>
        /// Passes the native sum-up and the inherited managed <c>Sumup</c> to <c>SetupSumupHandler</c>.
        /// NOTE(review): how the two are chosen between is decided by the base class.
        /// </summary>
        private void RegisterSumupHandlers() => SetupSumupHandler(SumupNative, base.Sumup);

        /// <summary>
        /// Pins the item storage and passes its address to <paramref name="callback"/> for the duration of the call.
        /// </summary>
        public override unsafe void Callback(Action<IntPtr> callback)
        {
            fixed (ushort* pinned = _data)
                callback((IntPtr)pinned);
        }

        /// <summary>
        /// Number of bytes that <c>ToByteArray</c> writes: item storage, the stored length and the base header.
        /// </summary>
        public override int SizeInBytes() => _data.SizeInBytes() + sizeof(int) + base.SizeInBytes();

        /// <summary>
        /// Serializes this array into <paramref name="buffer"/> at <paramref name="position"/>.
        /// The layout is the base representation, then <see cref="DenseIntArray.Length"/>, then the items.
        /// </summary>
        /// <param name="buffer">Destination byte array.</param>
        /// <param name="position">Write position; advanced past the written bytes.</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            base.ToByteArray(buffer, ref position);
            Length.ToByteArray(buffer, ref position);
            _data.ToByteArray(buffer, ref position);
        }

        public override int this[int index]
        {
            get { return _data[index]; }
            set
            {
                Contracts.Assert(0 <= value && value <= ushort.MaxValue);
                _data[index] = (ushort)value;
            }
        }

        /// <summary>
        /// Pins the items and forwards them, reinterpreted as raw bytes, to the shared native sum-up
        /// with 16 bits per item.
        /// </summary>
        public unsafe void SumupNative(SumupInputData input, FeatureHistogram histogram)
        {
            fixed (ushort* pinned = _data)
                SumupCPlusPlusDense(input, histogram, (byte*)pinned, 16);
        }

        public override void Sumup(SumupInputData input, FeatureHistogram histogram) => SumupHandler(input, histogram);
    }
 
    /// <summary>
    /// A class to represent features using 32 bits per item, stored as one <see cref="int"/> per item.
    /// </summary>
    internal sealed class Dense32BitIntArray : DenseDataCallbackIntArray
    {
        private readonly int[] _data;

        public Dense32BitIntArray(int len)
            : base(len)
        {
            _data = new int[len];
            RegisterSumupHandlers();
        }

        public Dense32BitIntArray(int len, IEnumerable<int> values)
            : base(len)
        {
            _data = values.ToArray(len);
            RegisterSumupHandlers();
        }

        public Dense32BitIntArray(byte[] buffer, ref int position)
            : base(buffer.ToInt(ref position))
        {
            _data = buffer.ToIntArray(ref position);
            RegisterSumupHandlers();
        }

        public override IntArrayBits BitsPerItem => IntArrayBits.Bits32;

        /// <summary>
        /// Passes the native sum-up and the inherited managed <c>Sumup</c> to <c>SetupSumupHandler</c>.
        /// NOTE(review): how the two are chosen between is decided by the base class.
        /// </summary>
        private void RegisterSumupHandlers() => SetupSumupHandler(SumupNative, base.Sumup);

        /// <summary>
        /// Pins the item storage and passes its address to <paramref name="callback"/> for the duration of the call.
        /// </summary>
        public override unsafe void Callback(Action<IntPtr> callback)
        {
            fixed (int* pinned = _data)
                callback((IntPtr)pinned);
        }

        /// <summary>
        /// Number of bytes that <c>ToByteArray</c> writes: item storage, the stored length and the base header.
        /// </summary>
        public override int SizeInBytes() => _data.SizeInBytes() + sizeof(int) + base.SizeInBytes();

        /// <summary>
        /// Serializes this array into <paramref name="buffer"/> at <paramref name="position"/>.
        /// The layout is the base representation, then <see cref="DenseIntArray.Length"/>, then the items.
        /// </summary>
        /// <param name="buffer">Destination byte array.</param>
        /// <param name="position">Write position; advanced past the written bytes.</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            base.ToByteArray(buffer, ref position);
            Length.ToByteArray(buffer, ref position);
            _data.ToByteArray(buffer, ref position);
        }

        public override int this[int index]
        {
            get { return _data[index]; }
            set
            {
                Contracts.Assert(value >= 0);
                _data[index] = value;
            }
        }

        /// <summary>
        /// Pins the items and forwards them, reinterpreted as raw bytes, to the shared native sum-up
        /// with 32 bits per item.
        /// </summary>
        public unsafe void SumupNative(SumupInputData input, FeatureHistogram histogram)
        {
            fixed (int* pinned = _data)
                SumupCPlusPlusDense(input, histogram, (byte*)pinned, 32);
        }

        public override void Sumup(SumupInputData input, FeatureHistogram histogram) => SumupHandler(input, histogram);
    }
}