File: Dataset\SparseIntArray.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
    using FloatType = System.Single;
#else
    using FloatType = System.Double;
#endif
 
    /// <summary>
    /// This <see cref="IntArray"/> implementation represents a sequence of values using parallel
    /// arrays of both values, as well as deltas indicating the number of values to the next
    /// explicit value. Values "between" these deltas are implicitly zero.
    ///
    /// Note that it is possible to misuse the deltas by making some of them themselves 0, allowing
    /// us to represent multiple values per row. In this case, <see cref="IntArray.GetEnumerator"/>
    /// and <see cref="IntArray.GetIndexer"/> will not have sensible values, but
    /// <see cref="IntArray.Sumup"/> will work and accumulate the same target into multiple bins.
    /// This trick should be used with caution, and is useful for the <see cref="NHotFeatureFlock"/>.
    /// </summary>
    internal sealed class DeltaSparseIntArray : IntArray
    {
        private readonly DenseIntArray _values;
        private readonly byte[] _deltas;
        private readonly int _length;
 
        /// <summary>
        /// Constructs a sparse array of <paramref name="length"/> items with no explicit entries,
        /// so every position is an implicit zero.
        /// </summary>
        /// <param name="length">The total number of (virtual) items in the array.</param>
        /// <param name="bitsPerItem">The requested number of bits per stored value. Anything up to
        /// <see cref="IntArrayBits.Bits8"/> is widened to <see cref="IntArrayBits.Bits8"/>;
        /// <see cref="IntArrayBits.Bits0"/> is rejected.</param>
        public DeltaSparseIntArray(int length, IntArrayBits bitsPerItem)
        {
            if (bitsPerItem == IntArrayBits.Bits0)
                throw Contracts.Except("Use dense arrays for 0 bits");
            if (bitsPerItem <= IntArrayBits.Bits8)
                bitsPerItem = IntArrayBits.Bits8;

            _length = length;

            // Previously the backing arrays were left null, so any later member call on an
            // instance built this way failed with a NullReferenceException. An empty pair of
            // arrays is the representation the rest of the class already treats as "all zeros"
            // (see the _deltas.Length > 0 checks in the leaf sumup and in the indexer).
            _values = IntArray.New(0, IntArrayType.Dense, bitsPerItem) as DenseIntArray;
            _deltas = new byte[0];
        }
 
        /// <summary>
        /// Builds the sparse representation by scanning a full (dense) enumeration of values.
        /// An entry is stored for every non-zero value, and also whenever the gap since the last
        /// stored entry reaches <see cref="byte.MaxValue"/> (so that every delta fits in a byte).
        /// </summary>
        /// <param name="length">The expected number of items produced by <paramref name="values"/>.</param>
        /// <param name="bitsPerItem">The requested bits per stored value; widened to at least
        /// <see cref="IntArrayBits.Bits8"/>. <see cref="IntArrayBits.Bits0"/> is rejected.</param>
        /// <param name="values">The values of the array, one per item, in order.</param>
        public DeltaSparseIntArray(int length, IntArrayBits bitsPerItem, IEnumerable<int> values)
        {
            using (Timer.Time(TimerEvent.SparseConstruction))
            {
                var storedValues = new List<int>();
                var storedDeltas = new List<byte>();

                // Number of items covered by the entries stored so far.
                int covered = 0;
                // Distance from the last stored entry to the item currently being examined.
                byte gap = 0;

                foreach (int item in values)
                {
                    bool mustStore = item != 0 || gap == byte.MaxValue;
                    if (mustStore)
                    {
                        storedDeltas.Add(gap);
                        storedValues.Add(item);
                        covered += gap;
                        gap = 0;
                    }
                    gap++;
                }

                // The trailing gap accounts for the last stored item plus any zeros after it.
                _length = covered + gap;

                if (_length != length)
                    throw Contracts.Except("Length provided to sparse vector is inconsistent with value enumeration");

                // A 4-bit dense array is deliberately avoided for the stored values: it is slower,
                // and the memory saved is minor because only the sparse entries are stored.
                // TODO: Implement a special iterator for 4-bit array, and change this code to use the iterator, which
                //          may be faster
                if (bitsPerItem == IntArrayBits.Bits0)
                    throw Contracts.Except("Use dense arrays for 0 bits");
                if (bitsPerItem <= IntArrayBits.Bits8)
                    bitsPerItem = IntArrayBits.Bits8;

                _values = IntArray.New(storedValues.Count, IntArrayType.Dense, bitsPerItem, storedValues) as DenseIntArray;
                _deltas = storedDeltas.ToArray();
            }
        }
 
        /// <summary>
        /// Reconstructs an instance from a byte buffer. Fields are read in the same order that
        /// <see cref="ToByteArray(byte[], ref int)"/> writes them after the base-class portion:
        /// length, then deltas, then values.
        /// </summary>
        /// <param name="buffer">The buffer holding the binary representation.</param>
        /// <param name="position">The read position in <paramref name="buffer"/>; passed by
        /// reference to each reader (presumably advanced past what was consumed).</param>
        public DeltaSparseIntArray(byte[] buffer, ref int position)
        {
            _length = buffer.ToInt(ref position);
            // REVIEW: The two commented-out lines below are reproduced as they originally appeared.
            // Their purpose is unclear: they appear to function as a no-op and would not have left
            // a valid result inside _deltasActualLength.
            //_deltasActualLength = position;
            //_deltasActualLength = buffer.ToInt(ref _deltasActualLength);
            _deltas = buffer.ToByteArray(ref position);
            // NOTE(review): the 'as' cast yields null if the deserialized array is not a DenseIntArray.
            _values = IntArray.New(buffer, ref position) as DenseIntArray;
        }
 
        /// <summary>
        /// Wraps already-built parallel arrays of values and deltas. The arrays are stored as
        /// given (not copied).
        /// </summary>
        /// <param name="values">The explicitly stored values, one per delta.</param>
        /// <param name="deltas">For each stored value, the distance from the previous stored value.</param>
        /// <param name="length">The total number of (virtual) items in the array.</param>
        public DeltaSparseIntArray(DenseIntArray values, byte[] deltas, int length)
        {
            Contracts.AssertValue(values);
            Contracts.AssertValue(deltas);
            Contracts.Assert(values.Length == deltas.Length);
            // The position of the last explicit entry (the sum of all deltas) must lie inside the
            // array. With no explicit entries there is nothing to check; in particular an empty
            // array of length 0, which the other constructors can produce, is valid and must not
            // trip the assertion.
            Contracts.Assert(deltas.Length == 0 || deltas.Sum(d => (long)d) < length);

            _values = values;
            _deltas = deltas;
            _length = length;

            if (BitsPerItem == IntArrayBits.Bits0)
                throw Contracts.Except("Use dense arrays for 0 bits");
        }
 
        /// <summary>
        /// Builds a sparse int array from (index, value) pairs. Gaps wider than
        /// <see cref="byte.MaxValue"/> are bridged with explicit zero-valued filler entries so
        /// that every delta fits in a byte.
        /// </summary>
        /// <param name="length">The total length of the constructed array.</param>
        /// <param name="bitsPerItem">The number of bits required to store the values; widened to at
        /// least <see cref="IntArrayBits.Bits8"/>. <see cref="IntArrayBits.Bits0"/> is rejected.</param>
        /// <param name="nonZeroValues">An ordered enumerable of (index,value) pairs.
        /// Each index should be strictly increasing as the iterable proceeds.</param>
        public DeltaSparseIntArray(int length, IntArrayBits bitsPerItem, IEnumerable<KeyValuePair<int, int>> nonZeroValues)
        {
            using (Timer.Time(TimerEvent.SparseConstruction))
            {
                var storedValues = new List<int>();
                var storedDeltas = new List<byte>();

                // Position of the most recently stored entry (0 before anything is stored).
                int lastIndex = 0;
                foreach (KeyValuePair<int, int> entry in nonZeroValues)
                {
                    int index = entry.Key;
                    bool nothingStoredYet = storedValues.Count == 0;

                    // Index 0 is acceptable only as the very first entry; negatives never are.
                    if (index <= lastIndex && (index < 0 || !nothingStoredYet))
                        throw Contracts.Except("index {0} occurred after {1}", index, lastIndex);

                    // Bridge gaps that do not fit in a single byte delta with zero-valued fillers.
                    for (int gap = index - lastIndex; gap > byte.MaxValue; gap -= byte.MaxValue)
                    {
                        storedDeltas.Add(byte.MaxValue);
                        storedValues.Add(0);
                        lastIndex += byte.MaxValue;
                    }

                    storedDeltas.Add((byte)(index - lastIndex));
                    storedValues.Add(entry.Value);
                    lastIndex = index;
                }

                // Add the final chunks of 0's if the pairs ended well before the end of the array.
                for (int remaining = length - lastIndex; remaining > byte.MaxValue; remaining -= byte.MaxValue)
                {
                    storedDeltas.Add(byte.MaxValue);
                    storedValues.Add(0);
                    lastIndex += byte.MaxValue;
                }

                if (lastIndex > 0 && lastIndex >= length)
                    throw Contracts.Except("Index {0} inconsistent with length {1}", lastIndex, length);
                _length = length;

                // A 4-bit dense array is deliberately avoided for the stored values: it is slower,
                // and the memory saved is minor because only the sparse entries are stored.
                // TODO: Implement a special iterator for 4-bit array, and change this code to use the iterator, which
                //          may be faster
                if (bitsPerItem == IntArrayBits.Bits0)
                    throw Contracts.Except("Use dense arrays for 0 bits");
                if (bitsPerItem <= IntArrayBits.Bits8)
                    bitsPerItem = IntArrayBits.Bits8;

                _values = IntArray.New(storedValues.Count, IntArrayType.Dense, bitsPerItem, storedValues) as DenseIntArray;
                _deltas = storedDeltas.ToArray();
            }
        }
 
        /// <summary>
        /// An (index, value) pair, as produced by <see cref="GetIndexPairs"/>.
        /// </summary>
        public struct IndexPair
        {
            // Position of the item within the (virtual) full-length array.
            public int Index;
            // The value stored at that position.
            public int Value;
        }
 
        /// <summary>
        /// Lazily enumerates the explicitly stored entries whose value is non-zero, each paired
        /// with its position in the full-length array. Zero-valued filler entries are skipped.
        /// </summary>
        /// <returns>The non-zero entries in storage order.</returns>
        public IEnumerable<IndexPair> GetIndexPairs()
        {
            int position = 0;
            int entry = 0;
            foreach (int value in _values)
            {
                // Each delta moves from the previous stored entry to this one.
                position += _deltas[entry];
                entry++;

                if (value == 0)
                    continue;

                yield return new IndexPair { Index = position, Value = value };
            }
        }
 
        /// <summary>
        /// The number of explicitly stored entries (the length of the delta array).
        /// </summary>
        public int DeltaLength
        {
            get { return _deltas.Length; }
        }
 
        #region IIntArray Members
 
        /// <summary>
        /// The total number of (virtual) items in the array, including implicit zeros.
        /// </summary>
        public override int Length
        {
            get { return _length; }
        }
 
        /// <summary>
        /// Produces one new sparse array per row of <paramref name="assignment"/>, each holding
        /// the values of this array at the item indices listed in that row.
        /// </summary>
        /// <param name="assignment">For each part, the item indices to place in it. The values
        /// are read through a forward indexer, so each row is expected to be in ascending order.</param>
        /// <returns>The parts, in the same order as the rows of <paramref name="assignment"/>.</returns>
        public override IntArray[] Split(int[][] assignment)
        {
            int partCount = assignment.Length;
            var result = new IntArray[partCount];

            for (int part = 0; part < partCount; part++)
            {
                int[] rows = assignment[part];
                // A fresh indexer per part: forward indexers cannot be rewound.
                IIntArrayForwardIndexer forward = GetIndexer();
                result[part] = IntArray.New(rows.Length, IntArrayType.Sparse, BitsPerItem, rows.Select(row => forward[row]));
            }

            return result;
        }
 
        /// <summary>
        /// Creates a new sparse IntArray holding only the items selected by <paramref name="itemIndices"/>.
        /// </summary>
        /// <param name="itemIndices">The indices of the items to copy. The values are read through
        /// a forward indexer, so the indices are expected to be in ascending order.</param>
        /// <returns>The new IntArray, with one item per entry of <paramref name="itemIndices"/>.</returns>
        public override IntArray Clone(int[] itemIndices)
        {
            IIntArrayForwardIndexer forward = GetIndexer();

            int count = itemIndices.Length;
            IntArrayBits bits = BitsPerItem;
            IEnumerable<int> selected = itemIndices.Select(item => forward[item]);

            return IntArray.New(count, IntArrayType.Sparse, bits, selected);
        }
 
        /// <summary>
        /// Returns the number of bytes written by <see cref="ToByteArray(byte[], ref int)"/>:
        /// the base-class portion, the length field, the deltas and the values.
        /// </summary>
        public override int SizeInBytes()
        {
            int valueBytes = _values.SizeInBytes();
            int deltaBytes = _deltas.SizeInBytes();
            const int lengthFieldBytes = sizeof(int);

            return valueBytes + deltaBytes + lengthFieldBytes + base.SizeInBytes();
        }
 
        /// <summary>
        /// Writes a binary representation of this class to a byte buffer, at a given position.
        /// The position is incremented to the end of the representation.
        /// </summary>
        /// <param name="buffer">a byte array where the binary representation is written</param>
        /// <param name="position">the position in the byte array</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            // The base-class portion goes first; the rest must stay in the order the
            // deserializing constructor reads it back: length, deltas, values.
            base.ToByteArray(buffer, ref position);
            _length.ToByteArray(buffer, ref position);
            _deltas.ToByteArray(buffer, ref position);
            _values.ToByteArray(buffer, ref position);
        }
 
        /// <summary>
        /// The bit width of the stored values, as reported by the underlying dense value array.
        /// </summary>
        public override IntArrayBits BitsPerItem => _values.BitsPerItem;
 
        /// <summary>
        /// Always <see cref="IntArrayType.Sparse"/> for this implementation.
        /// </summary>
        public override IntArrayType Type => IntArrayType.Sparse;
 
        /// <summary>
        /// Clones this array into the requested representation.
        /// </summary>
        /// <param name="bitsPerItem">The bit width requested for the clone. For a sparse clone this
        /// is widened to at least <see cref="IntArrayBits.Bits8"/>.</param>
        /// <param name="type"><see cref="IntArrayType.Sparse"/> or <see cref="IntArrayType.Current"/>
        /// yields another sparse array; any other value yields a dense array.</param>
        /// <returns>The cloned array. A sparse clone shares this instance's delta array.</returns>
        public override IntArray Clone(IntArrayBits bitsPerItem, IntArrayType type)
        {
            if (type == IntArrayType.Sparse || type == IntArrayType.Current)
            {
                if (bitsPerItem <= IntArrayBits.Bits8)
                    bitsPerItem = IntArrayBits.Bits8;
                DenseIntArray newValues = _values.Clone(bitsPerItem, IntArrayType.Dense) as DenseIntArray;
                return new DeltaSparseIntArray(newValues, _deltas, _length);
            }

            // Honor the requested bit width. This used to pass the BitsPerItem property (the
            // current width of this array), silently ignoring the caller's bitsPerItem argument.
            DenseIntArray dense = IntArray.New(Length, IntArrayType.Dense, bitsPerItem) as DenseIntArray;

            // Scatter the explicit entries; positions never written keep the value the freshly
            // created dense array starts with.
            int index = 0;
            for (int i = 0; i < _values.Length; ++i)
            {
                index += _deltas[i];
                dense[index] = _values[i];
            }
            return dense;
        }
 
        /// <summary>
        /// Accumulates the histogram when no document subset is given (outputs are addressed
        /// directly by item position). Explicit entries are summed per bin; everything else is
        /// attributed to bin 0 by subtracting from the totals carried in <paramref name="input"/>.
        /// Delegates to <see cref="SumupRootWeighted"/> when the histogram tracks weights.
        /// </summary>
        /// <param name="input">The outputs and their precomputed totals.</param>
        /// <param name="histogram">The histogram to accumulate into.</param>
        private void SumupRoot(SumupInputData input, FeatureHistogram histogram)
        {
            if (histogram.SumWeightsByBin != null)
            {
                SumupRootWeighted(input, histogram);
                return;
            }

            int entryCount = _values.Length;
            double explicitTargets = 0;
            int row = 0;

            for (int entry = 0; entry < entryCount; entry++)
            {
                row += _deltas[entry];

                int bin = _values[entry];
                FloatType target = (FloatType)input.Outputs[row];

                histogram.SumTargetsByBin[bin] = (FloatType)(histogram.SumTargetsByBin[bin] + target);
                histogram.CountByBin[bin]++;
                explicitTargets += target;
            }

            // Whatever was not covered by an explicit entry belongs to the zero bin. Explicit
            // zero-valued entries were already added to bin 0 above, so only the remainder is added.
            histogram.SumTargetsByBin[0] += (FloatType)(input.SumTargets - explicitTargets);
            histogram.CountByBin[0] += input.TotalCount - entryCount;
        }
 
        /// <summary>
        /// Weighted counterpart of <see cref="SumupRoot"/>: additionally accumulates per-bin
        /// weights. Explicit entries are summed per bin; the remainder of the totals carried in
        /// <paramref name="input"/> is attributed to bin 0.
        /// </summary>
        /// <param name="input">The outputs, weights and their precomputed totals.</param>
        /// <param name="histogram">The histogram to accumulate into; must track weights.</param>
        private void SumupRootWeighted(SumupInputData input, FeatureHistogram histogram)
        {
            Contracts.Assert(histogram.SumWeightsByBin != null);
            Contracts.Assert(input.Weights != null);

            int entryCount = _values.Length;
            double explicitTargets = 0;
            double explicitWeights = 0;
            int row = 0;

            for (int entry = 0; entry < entryCount; entry++)
            {
                row += _deltas[entry];

                int bin = _values[entry];
                FloatType target = (FloatType)input.Outputs[row];
                FloatType weight = (FloatType)input.Weights[row];

                histogram.SumTargetsByBin[bin] = (FloatType)(histogram.SumTargetsByBin[bin] + target);
                histogram.SumWeightsByBin[bin] = (FloatType)(histogram.SumWeightsByBin[bin] + weight);
                histogram.CountByBin[bin]++;
                explicitTargets += target;
                explicitWeights += weight;
            }

            // Whatever was not covered by an explicit entry belongs to the zero bin. Explicit
            // zero-valued entries were already added to bin 0 above, so only the remainder is added.
            histogram.SumTargetsByBin[0] += (FloatType)(input.SumTargets - explicitTargets);
            histogram.SumWeightsByBin[0] += (FloatType)(input.SumWeights - explicitWeights);
            histogram.CountByBin[0] += input.TotalCount - entryCount;
        }
 
        // Fixing the arrays and using unsafe accesses may give a slight speedup, but it is hard to tell.
        // OPTIMIZE: Another two methods would be doing binary search or using a hashtable -- binary search
        //  when there are very few docs in the leaf
        /// <summary>
        /// Accumulates the histogram for the documents listed in <c>input.DocIndices</c> by walking
        /// the explicit entries and the document indices in lockstep (a merge; it relies on the
        /// document indices being in ascending order). Matches are summed per bin; the remainder of
        /// the totals carried in <paramref name="input"/> is attributed to bin 0.
        /// Delegates to <see cref="SumupLeafWeighted"/> when the histogram tracks weights.
        /// </summary>
        /// <param name="input">The document indices, the outputs (addressed by position within the
        /// document index list) and their precomputed totals.</param>
        /// <param name="histogram">The histogram to accumulate into.</param>
        private unsafe void SumupLeaf(SumupInputData input, FeatureHistogram histogram)
        {
            if (histogram.SumWeightsByBin != null)
            {
                SumupLeafWeighted(input, histogram);
                return;
            }
            int iDocIndices = 0;
            int iSparse = 0;
            int totalCount = 0;
            FloatType totalOutput = 0;
            // With no explicit entries, start past every valid position so nothing ever matches.
            int currentPos = _deltas.Length > 0 ? _deltas[iSparse] : _length;

            // The merge loop reads pDocIndices[0] before checking any bound. For an empty leaf
            // that is either an out-of-bounds pointer read or a read of stale buffer contents
            // (which could even produce a spurious match), so skip the loop entirely.
            if (input.TotalCount > 0)
            {
                fixed (int* pDocIndices = input.DocIndices)
                fixed (byte* pDeltas = _deltas)
                fixed (FloatType* pOutputs = input.Outputs)
                {
                    for (; ; )
                    {
                        if (currentPos < pDocIndices[iDocIndices])
                        {
                            // The explicit entry is behind the current doc: advance the entry.
                            if (++iSparse >= _deltas.Length)
                                break;
                            currentPos += pDeltas[iSparse];
                        }
                        else if (currentPos > pDocIndices[iDocIndices])
                        {
                            // The current doc has no explicit entry: advance the doc.
                            if (++iDocIndices >= input.TotalCount)
                                break;
                        }
                        else
                        {
                            // A nonzero entry matched one of the docs in the leaf, add it to the histogram.
                            int featureBin = _values[iSparse];
                            FloatType output = pOutputs[iDocIndices];
                            histogram.SumTargetsByBin[featureBin] += output;
                            totalOutput += output;
                            ++histogram.CountByBin[featureBin];

                            totalCount++;

                            if (++iSparse >= _deltas.Length)
                                break;

                            // Note that if the delta is 0, we will "stay" on this document, thus
                            // allowing the sumup to work to accumulate multiple bins per document.
                            if (pDeltas[iSparse] > 0)
                            {
                                currentPos += pDeltas[iSparse];
                                if (++iDocIndices >= input.TotalCount)
                                    break;
                            }
                        }
                    }
                }
            }
            // Fixup the zeros. There were some zero items already placed in the zero-th entry, just add the remainder
            histogram.SumTargetsByBin[0] += (FloatType)(input.SumTargets - totalOutput);
            histogram.CountByBin[0] += input.TotalCount - totalCount;
        }
 
        /// <summary>
        /// Weighted counterpart of <see cref="SumupLeaf"/>: walks the explicit entries and the
        /// document indices in lockstep (relying on ascending document indices), accumulating
        /// targets, weights and counts per bin for matches, then attributes the remainder of the
        /// totals carried in <paramref name="input"/> to bin 0.
        /// </summary>
        /// <param name="input">The document indices, the outputs and weights (addressed by position
        /// within the document index list) and their precomputed totals.</param>
        /// <param name="histogram">The histogram to accumulate into; must track weights.</param>
        private unsafe void SumupLeafWeighted(SumupInputData input, FeatureHistogram histogram)
        {
            Contracts.Assert(histogram.SumWeightsByBin != null);
            Contracts.Assert(input.Weights != null);

            int iDocIndices = 0;
            int iSparse = 0;
            int totalCount = 0;
            FloatType totalOutput = 0;
            double totalWeights = 0;
            // With no explicit entries, start past every valid position so nothing ever matches.
            int currentPos = _deltas.Length > 0 ? _deltas[iSparse] : _length;

            // The merge loop reads pDocIndices[0] before checking any bound. For an empty leaf
            // that is either an out-of-bounds pointer read or a read of stale buffer contents
            // (which could even produce a spurious match), so skip the loop entirely.
            if (input.TotalCount > 0)
            {
                fixed (int* pDocIndices = input.DocIndices)
                fixed (byte* pDeltas = _deltas)
                fixed (FloatType* pOutputs = input.Outputs)
                fixed (double* pWeights = input.Weights)
                {
                    while (true)
                    {
                        if (currentPos < pDocIndices[iDocIndices])
                        {
                            // The explicit entry is behind the current doc: advance the entry.
                            if (++iSparse >= _deltas.Length)
                                break;
                            currentPos += pDeltas[iSparse];
                        }
                        else if (currentPos > pDocIndices[iDocIndices])
                        {
                            // The current doc has no explicit entry: advance the doc.
                            if (++iDocIndices >= input.TotalCount)
                                break;
                        }
                        else
                        {
                            // a nonzero entry matched one of the docs in the leaf, add it to the histogram
                            int featureBin = _values[iSparse];
                            FloatType output = pOutputs[iDocIndices];
                            histogram.SumTargetsByBin[featureBin] += output;
                            totalOutput += output;
                            double weights = pWeights[iDocIndices];
                            histogram.SumWeightsByBin[featureBin] += weights;
                            totalWeights += weights;
                            ++histogram.CountByBin[featureBin];

                            totalCount++;

                            if (++iSparse >= _deltas.Length)
                                break;

                            // A zero delta keeps us on the same document, so several bins can
                            // be accumulated for one document.
                            if (pDeltas[iSparse] > 0)
                            {
                                currentPos += pDeltas[iSparse];
                                if (++iDocIndices >= input.TotalCount)
                                    break;
                            }
                        }
                    }
                }
            }
            // Fixup the zeros. There were some zero items already placed in the zero-th entry, just add the remainder
            histogram.SumTargetsByBin[0] += (FloatType)(input.SumTargets - totalOutput);
            histogram.SumWeightsByBin[0] += (FloatType)(input.SumWeights - totalWeights);
            histogram.CountByBin[0] += input.TotalCount - totalCount;
        }
 
        /// <summary>
        /// Accumulates this feature's histogram. Uses the native implementation when
        /// <c>UseFastTreeNative</c> is set and the values are held in a
        /// <see cref="DenseDataCallbackIntArray"/>; otherwise uses the managed root sumup (no
        /// document subset) or leaf sumup (with <c>input.DocIndices</c>).
        /// </summary>
        /// <param name="input">The outputs, optional weights and optional document subset.</param>
        /// <param name="histogram">The histogram to accumulate into.</param>
        public override void Sumup(SumupInputData input, FeatureHistogram histogram)
        {
            using (Timer.Time(TimerEvent.SumupSparse))
            {
                DenseDataCallbackIntArray nativeValues =
                    UseFastTreeNative ? (_values as DenseDataCallbackIntArray) : null;

                if (nativeValues == null)
                {
                    // Managed path.
                    if (input.DocIndices != null)
                        SumupLeaf(input, histogram);
                    else
                        SumupRoot(input, histogram);
                    return;
                }

                unsafe
                {
                    fixed (byte* pinnedDeltas = _deltas)
                    {
                        // A fixed local cannot be captured by a lambda, so copy the pointer first.
                        byte* deltasForCallback = pinnedDeltas;
                        nativeValues.Callback(pValues =>
                        {
                            SumupCPlusPlusSparse(input, histogram, (byte*)pValues, deltasForCallback, _deltas.Length,
                                _values.BitsPerItem);
                        });
                    }
                }
            }
        }
 
        // Name of the native library that provides the sparse sumup entry points below.
        internal const string NativePath = "FastTreeNative";

        // Native delta-sparse sumup for single-precision outputs. The managed caller
        // (SumupCPlusPlusSparse) treats a negative return value as failure.
        // NOTE(review): the managed caller passes the per-bin weight sums for pSumTargets2ByBin,
        // despite the parameter's name; confirm against the native signature.
        [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
        private static extern unsafe int C_SumupDeltaSparse_float(int numBits, byte* pValues, byte* pDeltas, int numDeltas, int* pIndices, float* pSampleOutputs, double* pSampleOutputWeights,
                                  float* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin,
                                  int totalCount, double totalSampleOutputs, double totalSampleOutputWeights);

        // Native delta-sparse sumup for double-precision outputs; otherwise identical in shape
        // to the single-precision entry point above.
        [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
        private static extern unsafe int C_SumupDeltaSparse_double(int numBits, byte* pValues, byte* pDeltas, int numDeltas, int* pIndices, double* pSampleOutputs, double* pSampleOutputWeights,
                                  double* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin,
                                  int totalCount, double totalSampleOutputs, double totalSampleOutputWeights);
 
        /// <summary>
        /// Pins the managed histogram and input arrays and forwards the sumup to the native
        /// implementation matching the compiled floating-point precision.
        /// </summary>
        /// <param name="input">The outputs, optional weights, optional document indices and totals.</param>
        /// <param name="histogram">The histogram whose arrays the native code accumulates into.</param>
        /// <param name="pValues">Pointer to the raw stored values.</param>
        /// <param name="pDeltas">Pointer to the delta array.</param>
        /// <param name="numDeltas">The number of entries in the delta array.</param>
        /// <param name="bitsPerItem">The bit width of the values behind <paramref name="pValues"/>.</param>
        private unsafe void SumupCPlusPlusSparse(SumupInputData input, FeatureHistogram histogram, byte* pValues, byte* pDeltas, int numDeltas, IntArrayBits bitsPerItem)
        {
            // Use the width the caller says pValues has. The parameter used to be ignored in
            // favor of the BitsPerItem property, which the only caller passes in anyway.
            int numBits = (int)bitsPerItem;

            // Pinning a null array (e.g. no weights, or no document indices) yields a null pointer.
            fixed (FloatType* pSumTargetsByBin = histogram.SumTargetsByBin)
            fixed (FloatType* pSampleOutputs = input.Outputs)
            fixed (double* pSumWeightsByBin = histogram.SumWeightsByBin)
            fixed (double* pSampleWeights = input.Weights)
            fixed (int* pIndices = input.DocIndices)
            fixed (int* pCountByBin = histogram.CountByBin)
            {
                int rv =
#if USE_SINGLE_PRECISION
                    C_SumupDeltaSparse_float
#else
                    C_SumupDeltaSparse_double
#endif
                        (numBits, pValues, pDeltas, numDeltas, pIndices, pSampleOutputs, pSampleWeights,
                            pSumTargetsByBin, pSumWeightsByBin, pCountByBin,
                            input.TotalCount, input.SumTargets, input.SumWeights);
                // A negative return value signals failure in the native code.
                if (rv < 0)
                    throw Contracts.Except("CSumup sumupdeltasparse {0}", rv);
            }
        }
 
        /// <summary>
        /// Forward-only indexer over a <see cref="DeltaSparseIntArray"/>. It keeps a cursor into
        /// the stored entries that only ever moves forward, so callers must query indices in
        /// non-decreasing order.
        /// </summary>
        private class DeltaSparseIntArrayIndexer : IIntArrayForwardIndexer
        {
            private readonly DeltaSparseIntArray _array;
            private int _pos; // Cursor into the stored entries.
            private int _nextIndex; // Next non-zero index.

            public DeltaSparseIntArrayIndexer(DeltaSparseIntArray array)
            {
                Contracts.AssertValue(array);

                _array = array;
                // With no stored entries, the "next" explicit position lies past the end.
                _nextIndex = _array._deltas.Length > 0 ? _array._deltas[0] : _array.Length;
            }

            #region IIntArrayForwardIndexer Members

            /// <summary>
            /// Returns the value at <paramref name="virtualIndex"/>: the stored value if an explicit
            /// entry sits at that position, otherwise 0.
            /// </summary>
            public unsafe int this[int virtualIndex]
            {
                get
                {
                    // At or before the entry under the cursor: answer without moving.
                    if (virtualIndex <= _nextIndex)
                        return virtualIndex == _nextIndex ? _array._values[_pos] : 0;

                    fixed (byte* deltas = _array._deltas)
                    {
                        // Advance the cursor until it reaches or passes the requested position.
                        for (++_pos; _pos < _array._values.Length; ++_pos)
                        {
                            _nextIndex += deltas[_pos];
                            if (virtualIndex <= _nextIndex)
                                return virtualIndex == _nextIndex ? _array._values[_pos] : 0;
                        }
                    }

                    // Ran out of stored entries: everything from here on is an implicit zero.
                    _nextIndex = _array._length;
                    return 0;
                }
            }

            #endregion
        }
 
        /// <summary>
        /// Creates a new forward-only indexer positioned at the start of this array.
        /// </summary>
        public override IIntArrayForwardIndexer GetIndexer() => new DeltaSparseIntArrayIndexer(this);
 
        #endregion
 
        #region IEnumerable<int> Members
 
        /// <summary>
        /// Enumerates the array as a full-length sequence: implicit zeros between the stored
        /// entries, each stored value at its position, and trailing zeros up to the length.
        /// </summary>
        public override IEnumerator<int> GetEnumerator()
        {
            // Index of the last item yielded so far (-1 before anything has been yielded).
            int position = -1;

            for (int entry = 0; entry < _deltas.Length; ++entry)
            {
                // The first delta is an absolute position; later ones are relative to the
                // previously yielded stored entry.
                int target = (entry == 0 ? 0 : position) + _deltas[entry];

                for (++position; position < target; ++position)
                    yield return 0;

                yield return _values[entry];
            }

            for (++position; position < _length; ++position)
                yield return 0;
        }
 
        #endregion
 
    }
}