File: Dataset\SegmentIntArray.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
    using FloatType = System.Single;
#else
    using FloatType = System.Double;
#endif
 
    internal sealed class SegmentIntArray : IntArray
    {
        // Bit width used by the items of each segment. A width of 0 marks a segment whose
        // items are all zero; the enumerator yields 0 for those without reading _data.
        private readonly byte[] _segType;
        // Number of items in each segment; parallel to _segType.
        private readonly int[] _segLength;
        // The bit-packed item values, read as 32-bit words at running bit offsets.
        private readonly uint[] _data;
        // Total number of items; returned by Length.
        private readonly int _length;
        // Value returned by BitsPerItem.
        private readonly IntArrayBits _bpi;
 
        /// <summary>
        /// The cost of a transition between segments in bits.
        /// </summary>
        // One byte plus one int, converted from bytes to bits by the shift (5 bytes = 40 bits).
        public const long TransitionCost = (sizeof(byte) + sizeof(int)) << 3;
        // The same transition cost as a 16-bit unsigned value.
        public const ushort U16TransitionCost = (ushort)TransitionCost;
 
        /// <summary>
        /// Gets the bits-per-item value recorded for this array.
        /// </summary>
        public override IntArrayBits BitsPerItem => _bpi;
 
        /// <summary>
        /// Gets the kind of this array, which is always <see cref="IntArrayType.Segmented"/>.
        /// </summary>
        public override IntArrayType Type => IntArrayType.Segmented;
 
        // Delegate definitions so we can store a reference to the native or managed method so we only have to check it once.

        /// <summary>
        /// Signature shared by the native and managed "find optimal path" routines.
        /// </summary>
        /// <param name="array">The values to analyze.</param>
        /// <param name="len">The number of values to consider.</param>
        /// <param name="bitsNeeded">The number of bits needed to encode the largest value.</param>
        /// <param name="bits">Receives the number of bits of the optimal encoding.</param>
        /// <param name="transitions">Receives the number of transitions of the optimal encoding.</param>
        public delegate void PerformSegmentFindOptimalPath(uint[] array, int len, int bitsNeeded, out long bits, out int transitions);

        /// <summary>
        /// Signature shared by the native and managed "find optimal cost" routines.
        /// </summary>
        /// <param name="array">The values to analyze.</param>
        /// <param name="len">The number of values to consider.</param>
        /// <param name="bitsNeeded">The number of bits needed to encode the largest value.</param>
        /// <param name="bits">Receives the number of bits of the optimal encoding.</param>
        public delegate void PerformSegmentFindOptimalCost(uint[] array, int len, int bitsNeeded, out long bits);
 
        /// <summary>
        /// Lazily picks the SegmentFindOptimalCost implementation to use: the native one when
        /// UseFastTreeNative is set, the managed one otherwise. The choice is made when the value
        /// is first requested, so callers never have to repeat the check.
        /// </summary>
        public static Lazy<PerformSegmentFindOptimalCost> SegmentFindOptimalCost =
            new Lazy<PerformSegmentFindOptimalCost>(() =>
            {
                if (!UseFastTreeNative)
                    return ManagedSegmentFindOptimalCost;

                return NativeSegmentFindOptimalCost;
            });
 
        /// <summary>
        /// Lazily picks the SegmentFindOptimalPath implementation to use: the native one when
        /// UseFastTreeNative is set, the managed one otherwise. The choice is made when the value
        /// is first requested, so callers never have to repeat the check.
        /// </summary>
        public static Lazy<PerformSegmentFindOptimalPath> SegmentFindOptimalPath =
            new Lazy<PerformSegmentFindOptimalPath>(() =>
            {
                if (!UseFastTreeNative)
                    return ManagedSegmentFindOptimalPath;

                return NativeSegmentFindOptimalPath;
            });
 
        /// <summary>
        /// Builds a segmented array from exactly <paramref name="length"/> values.
        /// </summary>
        /// <param name="length">The number of items <paramref name="values"/> must produce.</param>
        /// <param name="values">The item values, in order.</param>
        public SegmentIntArray(int length, IEnumerable<int> values)
        {
            using (Timer.Time(TimerEvent.SparseConstruction))
            {
                SetupSumupHandler(SumupCPlusPlus, base.Sumup);

                // Copy the values into a work array while tracking the largest one.
                uint[] work = new uint[length];
                int filled = 0;
                uint largest = 0;
                foreach (int item in values)
                {
                    if (filled >= length)
                        throw Contracts.Except("Length provided to segment vector is inconsistent with value enumeration");

                    uint unsignedItem = (uint)item;
                    work[filled++] = unsignedItem;
                    if (unsignedItem > largest)
                        largest = unsignedItem;
                }

                if (filled != length)
                    throw Contracts.Except("Length provided to segment vector is inconsistent with value enumeration");

                // Find the optimal segmentation, pack the work array accordingly, and adopt
                // the packed representation as our own.
                int widestItemBits = BitsForValue(largest);
                SegmentFindOptimalPath.Value(work, work.Length, widestItemBits, out long totalBits, out int segmentCount);
                SegmentIntArray packed = FromWorkArray(work, work.Length, totalBits, segmentCount);

                _bpi = packed._bpi;
                _length = packed._length;
                _data = packed._data;
                _segLength = packed._segLength;
                _segType = packed._segType;
            }
        }
 
        /// <summary>
        /// Reads a segmented array from <paramref name="buffer"/>, starting at
        /// <paramref name="position"/>.
        /// </summary>
        /// <param name="buffer">The buffer holding the serialized fields.</param>
        /// <param name="position">The read offset, passed by reference to every read.</param>
        public SegmentIntArray(byte[] buffer, ref int position)
        {
            // The read order matches the order in which ToByteArray writes the fields
            // after its base call: bits-per-item, length, segment types, segment lengths, data.
            int bitsPerItem = buffer.ToInt(ref position);
            int itemCount = buffer.ToInt(ref position);
            byte[] segmentTypes = buffer.ToByteArray(ref position);
            int[] segmentLengths = buffer.ToIntArray(ref position);
            uint[] packedData = buffer.ToUIntArray(ref position);

            _bpi = (IntArrayBits)bitsPerItem;
            _length = itemCount;
            _segType = segmentTypes;
            _segLength = segmentLengths;
            _data = packedData;
        }
 
        /// <summary>
        /// Wraps already-packed segment data. The arrays are stored as given (not copied),
        /// and the bits-per-item value is fixed at <see cref="IntArrayBits.Bits32"/>.
        /// </summary>
        private SegmentIntArray(byte[] segType, int[] segLen, uint[] data, int len)
        {
            _bpi = IntArrayBits.Bits32;
            _length = len;
            _data = data;
            _segLength = segLen;
            _segType = segType;
        }
 
        /// <summary>
        /// Writes this array into <paramref name="buffer"/>: first whatever the base class
        /// writes, then bits-per-item, length, segment types, segment lengths and packed data.
        /// </summary>
        /// <param name="buffer">The destination buffer.</param>
        /// <param name="position">The write offset, passed by reference to every write.</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            base.ToByteArray(buffer, ref position);

            // Scalar fields first, in the order the deserializing constructor reads them.
            int bitsPerItem = (int)_bpi;
            bitsPerItem.ToByteArray(buffer, ref position);
            _length.ToByteArray(buffer, ref position);

            // Then the three arrays.
            _segType.ToByteArray(buffer, ref position);
            _segLength.ToByteArray(buffer, ref position);
            _data.ToByteArray(buffer, ref position);
        }
 
        /// <summary>
        /// Returns the base size plus the size of the two int-sized scalar fields and the
        /// three arrays.
        /// </summary>
        public override int SizeInBytes()
        {
            int total = base.SizeInBytes();
            total += 2 * sizeof(int); // bits-per-item and length
            total += _segType.SizeInBytes();
            total += _segLength.SizeInBytes();
            total += _data.SizeInBytes();
            return total;
        }
 
        /// <summary>
        /// Reads the value of width <paramref name="bits"/> stored at bit
        /// <paramref name="offset"/> of the packed data.
        /// </summary>
        private int Get(long offset, byte bits)
        {
            // All ones shifted left by the width, then inverted: the low "bits" bits set.
            uint lowBitsMask = ~(uint.MaxValue << bits);
            return Get(offset, lowBitsMask);
        }
 
        /// <summary>
        /// Reads 32 bits starting at bit <paramref name="offset"/> of the packed data and
        /// applies <paramref name="mask"/> to them.
        /// </summary>
        private int Get(long offset, uint mask)
        {
            int wordIndex = (int)(offset >> 5);
            int bitInWord = (int)(offset & 0x1f);

            // Join the addressed word and its successor into one 64-bit window, so a value
            // straddling the word boundary can be read with a single shift.
            ulong window = (ulong)_data[wordIndex] | ((ulong)_data[wordIndex + 1] << 32);
            return (int)((uint)(window >> bitInWord) & mask);
        }
 
        /// <summary>
        /// Returns the number of bits needed to represent <paramref name="val"/>, that is
        /// the position of its highest set bit plus one (0 for a value of 0).
        /// </summary>
        public static int BitsForValue(uint val)
        {
            int bitCount = 0;
            while (val != 0)
            {
                val >>= 1;
                ++bitCount;
            }
            return bitCount;
        }
 
        /// <summary>
        /// Finds the bits necessary for the optimal variable bit encoding of this
        /// array. If we are also finding the actual optimal path, it can only work
        /// when <paramref name="bitsForMaxItem"/> is at most 21, because the upper bits
        /// of every value are used as scratch space while the path is computed.
        ///
        /// This is a considerably less efficient managed analogue to the
        /// C_SegmentFindOptimalPath and C_SegmentFindOptimalCost functions.
        /// It is used by the class only when not using the unmanaged library.
        /// </summary>
        /// <param name="ivalues">The values for which we should find the optimal cost. If
        /// <paramref name="findPath"/> is active the array is modified in place: on return
        /// the most significant 5 bits of each element hold the bit width chosen for that
        /// item, and the remaining bits hold the item's value.</param>
        /// <param name="bitsForMaxItem">This should be the maximum number of bits necessary
        /// to encode the largest item in that array, or a higher value. Owing to the nature
        /// of the values as 32 bit quantities this value should be in the range [0,32], and
        /// it cannot exceed 21 if we are also finding the path.</param>
        /// <param name="findPath">Whether we should find the best path, by also storing the
        /// optimal path in the 5 most significant bits of each value.</param>
        /// <param name="bits">The number of bits necessary for the optimal encoding.</param>
        /// <param name="transitions">The number of transitions necessary in the
        /// optimal encoding (only if findPath is true; otherwise 0).</param>
        /// <param name="max">The maximum element in the ivalues array.</param>
        public static void StatsOfBestEncoding(uint[] ivalues, int bitsForMaxItem, bool findPath, out long bits, out int transitions, out uint max)
        {
            // Validate before allocating: a negative bitsForMaxItem would otherwise fail inside
            // the array allocation and never reach the descriptive exception below.
            if (bitsForMaxItem > 32 || bitsForMaxItem < 0)
                throw Contracts.Except("Bits for max item must be in range [0,32], {0} is illegal", bitsForMaxItem);
            if (bitsForMaxItem > 21 && findPath)
                throw Contracts.Except("Cannot use more than 21 bits if also storing the actual optimal path");

            // state[b] is the cost of encoding the items seen so far when the current item uses
            // width b, relative to the cheapest width (it is re-normalized after every item).
            byte[] state = new byte[bitsForMaxItem + 1];
            byte firstvalid;

            max = 0;
            // Start with the cost of one transition.
            bits = TransitionCost;
            for (int i = 0; i < ivalues.Length; ++i)
            {
                uint val = (uint)ivalues[i];
                if (val > max)
                    max = val;
                // Bit b is set when width b is entered by a transition at this item.
                uint transmap = 0;
                // Count the bits this value needs, marking every narrower width as unusable.
                for (firstvalid = 0; val > 0; val >>= 1, firstvalid++)
                {
                    state[firstvalid] = 0xff;
                }
                byte beststate = 0;
                byte bestcost = 0xff;
                for (byte b = firstvalid; b <= bitsForMaxItem; ++b)
                {
                    if (state[b] <= TransitionCost)
                    {
                        // We should stay.
                        state[b] += b;
                    }
                    else
                    {
                        // We should transition.
                        state[b] = (byte)(TransitionCost + b);
                        transmap |= (uint)(1 << b);
                    }
                    if (bestcost > state[b])
                    {
                        bestcost = state[beststate = b];
                    }
                }
                // Re-normalize so the cheapest usable width has cost 0.
                for (byte b = firstvalid; b <= bitsForMaxItem; ++b)
                {
                    state[b] -= bestcost;
                }
                bits += bestcost;
                if (findPath)
                {
                    // Scratch layout: best width in bits 27-31, needed width in bits 22-26,
                    // the transition map in bits >= firstvalid, and the value itself below that.
                    ivalues[i] = ((((uint)beststate) << 27) | (((uint)firstvalid) << 22) | transmap | (uint)ivalues[i]);
                }
            }
            if (bitsForMaxItem < 32 && (((uint)1) << bitsForMaxItem) <= max)
            {
                throw Contracts.Except(
                    "Maximum specified bits {0} was not actually sufficient to encode maximum value {1}",
                    bitsForMaxItem, max);
            }
            transitions = 0;
            if (findPath)
            {
                // Walk backwards to recover the chosen width of every item. "back" is 1 when the
                // item after the current one started a new segment, in which case the current
                // item picks its own best width; otherwise it keeps the width of its successor.
                int back = 1;
                int bitness = 0;
                for (int i = ivalues.Length - 1; i >= 0; --i)
                {
                    bitness = back != 0 ? (int)(ivalues[i] >> 27) : bitness;
                    transitions += back;
                    back = (int)((ivalues[i] >> bitness) & 1);
                    // Strip the scratch bits. The shift count is taken modulo 32, so only the
                    // needed-width field (bits 22-26) determines the mask.
                    ivalues[i] &= (uint)((1 << ((int)(ivalues[i] >> 22))) - 1);
                    ivalues[i] |= (uint)(bitness << 27);
                }
            }
        }
 
        /// <summary>
        /// Not implemented for segmented arrays; always throws.
        /// </summary>
        public override IntArray Clone(IntArrayBits bitsPerItem, IntArrayType type)
            => throw Contracts.ExceptNotImpl();
 
        /// <summary>
        /// Enumerates all items in order, segment by segment.
        /// </summary>
        public override IEnumerator<int> GetEnumerator()
        {
            long bitOffset = 0;
            for (int segment = 0; segment < _segType.Length; ++segment)
            {
                int remaining = _segLength[segment];
                byte width = _segType[segment];

                if (width != 0)
                {
                    for (; remaining > 0; --remaining)
                    {
                        yield return Get(bitOffset, width);
                        bitOffset += width;
                    }
                }
                else
                {
                    // Zero-width segments hold only zeros, so no data needs to be read.
                    // This shortcut matters a great deal for our often sparse features.
                    for (; remaining > 0; --remaining)
                    {
                        yield return 0;
                    }
                }
            }
        }
 
        /// <summary>
        /// Creates a new forward indexer positioned at the start of this array.
        /// </summary>
        public override IIntArrayForwardIndexer GetIndexer() => new SegmentIntArrayIndexer(this);
 
        /// <summary>
        /// Gets the number of items in this array.
        /// </summary>
        public override int Length => _length;
 
        /// <summary>
        /// Builds one new segmented array per entry of <paramref name="assignment"/>, each
        /// containing the items at the indices listed in that entry.
        /// </summary>
        /// <param name="assignment">For every part, the indices of the items it receives.</param>
        /// <returns>The parts, in the order of <paramref name="assignment"/>.</returns>
        public override IntArray[] Split(int[][] assignment)
        {
            // Every part gets its own indexer, since an indexer only moves forward.
            SegmentIntArray BuildPart(int[] indices)
            {
                var forward = new SegmentIntArrayIndexer(this);
                int count = indices.Length;
                return new SegmentIntArray(count, indices.Select(index => forward[index]));
            }

            return assignment.Select(indices => BuildPart(indices)).ToArray();
        }
 
        /// <summary>
        /// Creates a new array holding only the items selected by <paramref name="itemIndices"/>.
        /// </summary>
        /// <param name="itemIndices">The indices of the items to copy into the new array.</param>
        /// <returns>The new array, with one item per entry of <paramref name="itemIndices"/>.</returns>
        public override IntArray Clone(int[] itemIndices)
        {
            var forward = new SegmentIntArrayIndexer(this);
            int count = itemIndices.Length;

            // The projection is lazy; it is evaluated while the new array is being constructed.
            IEnumerable<int> selected = itemIndices.Select(index => forward[index]);
            return new SegmentIntArray(count, selected);
        }
 
        /// <summary>
        /// Indexer over a <see cref="SegmentIntArray"/> that remembers the segment of the last
        /// lookup. It only ever advances to later segments, so lookups are expected to be made
        /// at indices that do not fall before the current segment.
        /// </summary>
        private class SegmentIntArrayIndexer : IIntArrayForwardIndexer
        {
            private readonly SegmentIntArray _array;
            private int _nextIndex; // index where the next segment begins

            private long _currentBit; // the bit offset where the current segment begins
            private int _currentIndex; // the index where the current segment begins
            private byte _currentType; // the type (bit width) of the current segment
            private int _currentSegment; // position of the current segment in the segment arrays

            public SegmentIntArrayIndexer(SegmentIntArray array)
            {
                _array = array;
                _currentSegment = 0;
                _currentBit = 0;

                if (_array._segType.Length > 0)
                {
                    _currentIndex = 0;
                    _currentType = _array._segType[0];
                    _nextIndex = _array._segLength[0];
                }
                else
                {
                    // Handle the edge case where we have a completely empty array.
                    _currentIndex = _array.Length;
                    _currentType = 0;
                    _nextIndex = _currentIndex;
                }
            }

            #region IIntArrayForwardIndexer Members

            /// <summary>
            /// Gets the item at <paramref name="virtualIndex"/>, first advancing to the segment
            /// that contains it.
            /// </summary>
            public int this[int virtualIndex]
            {
                get
                {
                    while (_nextIndex <= virtualIndex)
                    {
                        // Widen to long before multiplying: item count times bit width can
                        // exceed the range of a 32-bit int for large segments.
                        _currentBit += (long)(_nextIndex - _currentIndex) * _currentType;
                        _currentIndex = _nextIndex;
                        _currentType = _array._segType[++_currentSegment];
                        _nextIndex += _array._segLength[_currentSegment];
                    }
                    long bitoffset = _currentBit + (long)(virtualIndex - _currentIndex) * _currentType;
                    int major = (int)(bitoffset >> 5);
                    // Read a 64-bit window spanning two words, shift the item down to bit 0,
                    // and keep only its _currentType low bits.
                    return (int)(((long)_array._data[major] | (((long)_array._data[major + 1]) << 32)) >> (int)(bitoffset & 0x1f)) & ((1 << _currentType) - 1);
                }
            }

            #endregion
        }
 
        /// <summary>
        /// Packs a work array into a <see cref="SegmentIntArray"/>. Each work item carries its
        /// bit width in its 5 most significant bits and its value in the remaining 27 bits.
        /// </summary>
        /// <param name="workArray">The annotated work items.</param>
        /// <param name="len">The number of work items to pack.</param>
        /// <param name="bits">The total encoding cost in bits, including transition costs.</param>
        /// <param name="transitions">The number of segments in the encoding.</param>
        /// <returns>The packed array.</returns>
        public static SegmentIntArray FromWorkArray(uint[] workArray, int len, long bits, int transitions)
        {
            // What remains after removing the per-segment transition cost is pure payload.
            long payloadBits = bits - (long)transitions * (long)TransitionCost;
            byte[] segTypes = new byte[transitions];
            int[] segLengths = new int[transitions];
            uint[] packed = new uint[(payloadBits >> 5) + 2];

            int wordIndex = 0;      // next word of "packed" to write
            int pendingBits = 0;    // bits currently buffered in "accumulator"
            int previousWidth = -1; // width of the previous item; -1 before the first one
            int runLength = 0;      // items seen so far in the current segment
            int segmentCount = 0;   // segments opened so far
            ulong accumulator = 0;

            for (int i = 0; i < len; ++i)
            {
                uint item = workArray[i];
                int width = (int)(item >> 27);

                accumulator |= ((ulong)(item & 0x07ffffff)) << pendingBits;
                pendingBits += width;
                if (pendingBits >= 32)
                {
                    // A full word is ready; flush it and keep the overflow bits.
                    packed[wordIndex++] = (uint)accumulator;
                    accumulator >>= 32;
                    pendingBits -= 32;
                }

                if (width != previousWidth)
                {
                    // A change of width opens a new segment and closes the previous one.
                    segTypes[segmentCount++] = (byte)width;
                    if (runLength > 0)
                        segLengths[segmentCount - 2] = runLength;

                    previousWidth = width;
                    runLength = 0;
                }
                runLength++;
            }

            // Close the last segment and flush whatever is still buffered.
            if (runLength > 0)
                segLengths[segmentCount - 1] = runLength;
            packed[wordIndex] = (uint)accumulator;
            packed[wordIndex + 1] = (uint)(accumulator >> 32);

            return new SegmentIntArray(segTypes, segLengths, packed, len);
        }
 
        /// <summary>
        /// Dispatches to the native path finder suited to <paramref name="bitsNeeded"/>:
        /// the 15-bit variant for up to 15 bits, the 21-bit variant for up to 21 bits.
        /// Larger widths are rejected with an exception.
        /// </summary>
        public static void NativeSegmentFindOptimalPath(uint[] array, int len, int bitsNeeded, out long bits, out int transitions)
        {
            if (bitsNeeded > 31)
                throw Contracts.Except("Segment array cannot represent more than 31 bits");

            if (bitsNeeded > 21)
                throw Contracts.ExceptNotImpl("Segment array pathfinder currently does not support more than 21 bits");

            if (bitsNeeded > 15)
                SegmentFindOptimalPath21(array, len, out bits, out transitions);
            else
                SegmentFindOptimalPath15(array, len, out bits, out transitions);
        }
 
        /// <summary>
        /// Dispatches to the native cost finder suited to <paramref name="bitsNeeded"/>:
        /// the 15-bit variant for up to 15 bits, the 31-bit variant for up to 31 bits.
        /// Larger widths are rejected with an exception.
        /// </summary>
        public static void NativeSegmentFindOptimalCost(uint[] array, int len, int bitsNeeded, out long bits)
        {
            if (bitsNeeded > 31)
                throw Contracts.Except("Segment array cannot represent more than 31 bits");

            if (bitsNeeded > 15)
                SegmentFindOptimalCost31(array, len, out bits);
            else
                SegmentFindOptimalCost15(array, len, out bits);
        }
 
        /// <summary>
        /// Pins <paramref name="array"/> and calls the native C_SegmentFindOptimalPath7,
        /// returning its results through the out parameters.
        /// </summary>
        public static unsafe void SegmentFindOptimalPath7(uint[] array, int len, out long bits, out int transitions)
        {
            // Locals receive the native results, since their addresses can be taken directly.
            long nativeBits = 0;
            int nativeTransitions = 0;
            fixed (uint* pValues = array)
            {
                C_SegmentFindOptimalPath7(pValues, len, &nativeBits, &nativeTransitions);
            }
            transitions = nativeTransitions;
            bits = nativeBits;
        }
 
        /// <summary>
        /// Pins <paramref name="array"/> and calls the native C_SegmentFindOptimalPath15,
        /// returning its results through the out parameters.
        /// </summary>
        public static unsafe void SegmentFindOptimalPath15(uint[] array, int len, out long bits, out int transitions)
        {
            // Locals receive the native results, since their addresses can be taken directly.
            long nativeBits = 0;
            int nativeTransitions = 0;
            fixed (uint* pValues = array)
            {
                C_SegmentFindOptimalPath15(pValues, len, &nativeBits, &nativeTransitions);
            }
            transitions = nativeTransitions;
            bits = nativeBits;
        }
 
        /// <summary>
        /// Pins <paramref name="array"/> and calls the native C_SegmentFindOptimalPath21,
        /// returning its results through the out parameters.
        /// </summary>
        public static unsafe void SegmentFindOptimalPath21(uint[] array, int len, out long bits, out int transitions)
        {
            // Locals receive the native results, since their addresses can be taken directly.
            long nativeBits = 0;
            int nativeTransitions = 0;
            fixed (uint* pValues = array)
            {
                C_SegmentFindOptimalPath21(pValues, len, &nativeBits, &nativeTransitions);
            }
            transitions = nativeTransitions;
            bits = nativeBits;
        }
 
        /// <summary>
        /// Pins <paramref name="array"/> and calls the native C_SegmentFindOptimalCost15,
        /// returning its result through <paramref name="bits"/>.
        /// </summary>
        public static unsafe void SegmentFindOptimalCost15(uint[] array, int len, out long bits)
        {
            // A local receives the native result, since its address can be taken directly.
            long nativeBits = 0;
            fixed (uint* pValues = array)
            {
                C_SegmentFindOptimalCost15(pValues, len, &nativeBits);
            }
            bits = nativeBits;
        }
 
        /// <summary>
        /// Pins <paramref name="array"/> and calls the native C_SegmentFindOptimalCost31,
        /// returning its result through <paramref name="bits"/>.
        /// </summary>
        public static unsafe void SegmentFindOptimalCost31(uint[] array, int len, out long bits)
        {
            // A local receives the native result, since its address can be taken directly.
            long nativeBits = 0;
            fixed (uint* pValues = array)
            {
                C_SegmentFindOptimalCost31(pValues, len, &nativeBits);
            }
            bits = nativeBits;
        }
        // Name of the native library that all DllImport declarations below bind to.
        internal const string NativePath = "FastTreeNative";
#pragma warning disable TLC_GeneralName // Externs follow their own rules.
        // Native path finder invoked by SegmentFindOptimalPath21; results are written
        // through pBits and pTransitions.
        [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
        private static extern unsafe void C_SegmentFindOptimalPath21(uint* valv, int valc, long* pBits, int* pTransitions);

        // Native path finder invoked by SegmentFindOptimalPath15; results are written
        // through pBits and pTransitions.
        [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
        private static extern unsafe void C_SegmentFindOptimalPath15(uint* valv, int valc, long* pBits, int* pTransitions);

        // Native path finder invoked by SegmentFindOptimalPath7; results are written
        // through pBits and pTransitions.
        [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
        private static extern unsafe void C_SegmentFindOptimalPath7(uint* valv, int valc, long* pBits, int* pTransitions);

        // Native cost finder invoked by SegmentFindOptimalCost15; the result is written
        // through pBits.
        [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
        private static extern unsafe void C_SegmentFindOptimalCost15(uint* valv, int valc, long* pBits);

        // Native cost finder invoked by SegmentFindOptimalCost31; the result is written
        // through pBits.
        [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
        private static extern unsafe void C_SegmentFindOptimalCost31(uint* valv, int valc, long* pBits);

        // Native histogram sum-up used by SumupCPlusPlus when USE_SINGLE_PRECISION is defined.
        // SumupCPlusPlus treats a negative return value as an error.
        [DllImport(NativePath)]
        private static extern unsafe int C_SumupSegment_float(
            uint* pData, byte* pSegType, int* pSegLength, int* pIndices,
            float* pSampleOutputs, double* pSampleOutputWeights,
            float* pSumTargetsByBin, double* pSumWeightsByBin,
            int* pCountByBin, int totalCount, double totalSampleOutputs);

        // Native histogram sum-up used by SumupCPlusPlus when USE_SINGLE_PRECISION is not defined.
        // SumupCPlusPlus treats a negative return value as an error.
        [DllImport(NativePath)]
        private static extern unsafe int C_SumupSegment_double(
            uint* pData, byte* pSegType, int* pSegLength, int* pIndices,
            double* pSampleOutputs, double* pSampleOutputWeights,
            double* pSumTargetsByBin, double* pSumWeightsByBin,
            int* pCountByBin, int totalCount, double totalSampleOutputs);
#pragma warning restore TLC_GeneralName
 
        /// <summary>
        /// Runs the native segment sum-up: pins the histogram buffers, the input buffers and
        /// this array's packed representation, then calls the native routine matching the
        /// configured floating-point precision.
        /// </summary>
        /// <param name="input">The sum-up input (outputs, weights, document indices and totals).</param>
        /// <param name="histogram">The histogram whose per-bin buffers are handed to the native code.</param>
        public unsafe void SumupCPlusPlus(SumupInputData input, FeatureHistogram histogram)
        {
            using (Timer.Time(TimerEvent.SumupSegment))
            {
                fixed (FloatType* targetSums = histogram.SumTargetsByBin)
                fixed (FloatType* outputs = input.Outputs)
                fixed (double* weightSums = histogram.SumWeightsByBin)
                fixed (double* outputWeights = input.Weights)
                fixed (uint* packedData = _data)
                fixed (byte* segTypes = _segType)
                fixed (int* segLengths = _segLength)
                fixed (int* docIndices = input.DocIndices)
                fixed (int* binCounts = histogram.CountByBin)
                {
                    int status;
#if USE_SINGLE_PRECISION
                    status = C_SumupSegment_float(
                        packedData, segTypes, segLengths, docIndices, outputs, outputWeights,
                        targetSums, weightSums, binCounts, input.TotalCount, input.SumTargets);
#else
                    status = C_SumupSegment_double(
                        packedData, segTypes, segLengths, docIndices, outputs, outputWeights,
                        targetSums, weightSums, binCounts, input.TotalCount, input.SumTargets);
#endif
                    // A negative return value signals failure.
                    if (status < 0)
                        throw Contracts.Except("CSumup returned error {0}", status);
                }
            }
        }
        /// <summary>
        /// Managed counterpart of the native path finder: forwards to
        /// <see cref="StatsOfBestEncoding"/> with path finding enabled and drops the reported
        /// maximum. <paramref name="len"/> is not used by this implementation.
        /// </summary>
        public static void ManagedSegmentFindOptimalPath(uint[] array, int len, int bitsNeeded, out long bits, out int transitions)
        {
            StatsOfBestEncoding(array, bitsNeeded, true, out bits, out transitions, out uint _);
        }
 
        /// <summary>
        /// Managed counterpart of the native cost finder: forwards to
        /// <see cref="StatsOfBestEncoding"/> with path finding disabled and drops the reported
        /// transition count and maximum. <paramref name="len"/> is not used by this implementation.
        /// </summary>
        public static void ManagedSegmentFindOptimalCost(uint[] array, int len, int bitsNeeded, out long bits)
        {
            StatsOfBestEncoding(array, bitsNeeded, false, out bits, out int _, out uint _);
        }
 
        /// <summary>
        /// Times the sum-up and hands it to the configured handler; does nothing further when
        /// the array is empty.
        /// </summary>
        public override void Sumup(SumupInputData input, FeatureHistogram histogram)
        {
            using (Timer.Time(TimerEvent.SumupSegment))
            {
                // Only a non-empty array has anything to accumulate.
                if (_length != 0)
                {
                    SumupHandler(input, histogram);
                }
            }
        }
 
    }
}