// File: Dataset\RepeatIntArray.cs
// Web Access
// Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
    using FloatType = Single;
#else
    using FloatType = Double;
#endif
 
    internal sealed class DeltaRepeatIntArray : IntArray
    {
        // Value of each run; during enumeration entry i is yielded _deltas[i] times.
        private readonly DenseIntArray _values;
        // Number of virtual (expanded) items; this is what the Length property returns.
        private readonly int _length;
        // Run lengths, indexed in parallel with _values. The enumerating constructor closes a run
        // once it reaches byte.MaxValue items, so a long stretch of one value spans several entries.
        private readonly byte[] _deltas;
        // Not read anywhere in this class. NOTE(review): presumably the number of meaningful
        // entries in _deltas — confirm whether it is still needed.
        private readonly int _deltasActualLength;
 
        /// <summary>
        /// Builds a run-length encoded array by enumerating <paramref name="values"/> exactly once.
        /// Consecutive equal values are collapsed into one entry of <c>_values</c>, with the number
        /// of repetitions stored at the same index of <c>_deltas</c>. A run is closed once it holds
        /// <see cref="byte.MaxValue"/> items, so longer stretches occupy several entries.
        /// </summary>
        /// <param name="length">The number of items <paramref name="values"/> is expected to produce.
        /// An exception is thrown if the enumeration yields a different count.</param>
        /// <param name="bitsPerItem">Requested bits per stored value. <c>Bits0</c> is rejected with an
        /// exception; any other value up to <c>Bits8</c> is widened to <c>Bits8</c>.</param>
        /// <param name="values">The uncompressed sequence of values.</param>
        public DeltaRepeatIntArray(int length, IntArrayBits bitsPerItem, IEnumerable<int> values)
        {
            using (Timer.Time(TimerEvent.SparseConstruction))
            {
                List<int> tempValueList = new List<int>();
                List<byte> tempDeltaList = new List<byte>();

                _length = 0;

                byte delta = 0;
                int lastVal = -1;

                foreach (int val in values)
                {
                    // Open a new run on the very first item, whenever the value changes, or when the
                    // current run has reached the largest count a byte can hold. The explicit
                    // _length == 0 test matters when the first item equals lastVal's initial -1:
                    // without it no value would be recorded for the first run and the value list
                    // would end up shorter than the delta list.
                    if (_length == 0 || val != lastVal || delta == byte.MaxValue)
                    {
                        tempValueList.Add(val);
                        lastVal = val;
                        // The first run has no predecessor whose length needs to be flushed.
                        if (_length != 0)
                            tempDeltaList.Add(delta);
                        delta = 0;
                    }
                    ++delta;
                    ++_length;
                }
                // Flush the length of the final run (absent only when the enumeration was empty).
                if (delta > 0)
                    tempDeltaList.Add(delta);

                if (_length != length)
                    throw Contracts.Except("Length provided to repeat vector is inconsistent with value enumeration");

                // It is faster not to use a 4-bit dense array here. The memory difference is minor, since it's just
                //  the sparse values that are saved on.
                // TODO: Implement a special iterator for 4-bit array, and change this code to use the iterator, which
                //          may be faster
                if (bitsPerItem == IntArrayBits.Bits0)
                    throw Contracts.Except("Use dense arrays for 0 bits");
                if (bitsPerItem <= IntArrayBits.Bits8)
                    bitsPerItem = IntArrayBits.Bits8;

                _values = IntArray.New(tempValueList.Count, IntArrayType.Dense, bitsPerItem, tempValueList) as DenseIntArray;
                _deltas = tempDeltaList.ToArray();
                // Record the number of run-length entries instead of leaving the field at its default of 0.
                _deltasActualLength = _deltas.Length;
            }
        }
 
        /// <summary>
        /// Wraps already-encoded data; the supplied arrays are stored as-is, not copied.
        /// </summary>
        /// <param name="values">The value of each run. An exception is thrown if its
        /// <c>BitsPerItem</c> is <c>Bits0</c>.</param>
        /// <param name="deltas">The length of each run, indexed in parallel with <paramref name="values"/>.</param>
        /// <param name="length">The number of virtual (expanded) items.</param>
        private DeltaRepeatIntArray(DenseIntArray values, byte[] deltas, int length)
        {
            _values = values;
            _deltas = deltas;
            _length = length;
            // The number of delta entries is the size of the deltas array, not the virtual item
            // count: one delta entry can stand for up to byte.MaxValue items. A null array is
            // tolerated here, as it was before, and counts as zero entries.
            _deltasActualLength = deltas == null ? 0 : deltas.Length;

            if (BitsPerItem == IntArrayBits.Bits0)
                throw Contracts.Except("Use dense arrays for 0 bits");
        }
 
        /// <summary>
        /// Reads an instance from <paramref name="buffer"/>, starting at <paramref name="position"/>.
        /// The fields are read in the order <c>ToByteArray</c> writes them after its base-class
        /// portion: the length, the deltas, then the values.
        /// </summary>
        /// <param name="buffer">The byte array holding the binary representation.</param>
        /// <param name="position">The read position; advanced past everything consumed here.</param>
        public DeltaRepeatIntArray(byte[] buffer, ref int position)
        {
            _length = buffer.ToInt(ref position);

            // Peek at the int located where the deltas are about to be read. A throwaway cursor is
            // used so that position itself does not move.
            // NOTE(review): presumably this int is the element count that prefixes the serialized
            // deltas — confirm against the byte[] serialization helper.
            int peekCursor = position;
            _deltasActualLength = buffer.ToInt(ref peekCursor);

            _deltas = buffer.ToByteArray(ref position);
            _values = IntArray.New(buffer, ref position) as DenseIntArray;
        }
 
        /// <summary>
        /// Writes a binary representation of this class to a byte buffer, at a given position.
        /// The position is incremented to the end of the representation.
        /// </summary>
        /// <param name="buffer">A byte array where the binary representation is written.</param>
        /// <param name="position">The position in the byte array; passed on by reference to every write below.</param>
        public override void ToByteArray(byte[] buffer, ref int position)
        {
            // The base-class portion goes first, followed by the three fields in the same order in
            // which the (byte[], ref int) constructor reads them back: length, deltas, values.
            base.ToByteArray(buffer, ref position);
            _length.ToByteArray(buffer, ref position);
            _deltas.ToByteArray(buffer, ref position);
            _values.ToByteArray(buffer, ref position);
        }
 
        /// <summary>
        /// Sums the sizes of the four pieces that <c>ToByteArray</c> writes: the values, the deltas,
        /// one <c>int</c> (matching the single <c>_length</c> write) and the base-class portion.
        /// </summary>
        /// <returns>The total size in bytes.</returns>
        public override int SizeInBytes()
        {
            int valuesBytes = _values.SizeInBytes();
            int deltasBytes = _deltas.SizeInBytes();
            const int lengthBytes = sizeof(int);

            return valuesBytes + deltasBytes + lengthBytes + base.SizeInBytes();
        }
 
        /// <summary>The number of virtual (expanded) items in this array.</summary>
        public override int Length => _length;
 
        /// <summary>
        /// Creates a new array of the requested representation through <c>IntArray.New</c>,
        /// passing this instance as the source of values.
        /// </summary>
        /// <param name="bitsPerItem">Bits per item requested for the new array.</param>
        /// <param name="type">Representation requested for the new array.</param>
        /// <returns>The array produced by <c>IntArray.New</c>.</returns>
        public override IntArray Clone(IntArrayBits bitsPerItem, IntArrayType type)
            => IntArray.New(_length, type, bitsPerItem, this);
 
        /// <summary>The bits-per-item setting of the dense array that stores the run values.</summary>
        public override IntArrayBits BitsPerItem => _values.BitsPerItem;
 
        /// <summary>Always <c>IntArrayType.Repeat</c>.</summary>
        public override IntArrayType Type => IntArrayType.Repeat;
 
        /// <summary>
        /// Enumerates the expanded sequence: for every run, the run's value is yielded as many
        /// times as its delta entry says.
        /// </summary>
        public override IEnumerator<int> GetEnumerator()
        {
            int run = 0;
            while (run < _deltas.Length)
            {
                int repeatedValue = _values[run];
                int remaining = _deltas[run];
                while (remaining > 0)
                {
                    yield return repeatedValue;
                    remaining--;
                }
                run++;
            }
        }
 
        /// <summary>
        /// Forward-only indexer over a <see cref="DeltaRepeatIntArray"/>. It remembers the run it
        /// last stopped in and only ever moves towards later runs.
        /// </summary>
        private class DeltaRepeatIntArrayIndexer : IIntArrayForwardIndexer
        {
            private readonly DeltaRepeatIntArray _array;
            private int _pos; // index of the current run (-1 when the array has no runs)
            private int _nextIndex; // first virtual index that lies beyond the current run
            private int _lastVal; // value of the current run

            public DeltaRepeatIntArrayIndexer(DeltaRepeatIntArray array)
            {
                _array = array;

                if (array._deltas.Length <= 0)
                {
                    // No runs at all: every lookup falls through to the scan, which finds nothing.
                    _pos = -1;
                    _nextIndex = 0;
                    return;
                }

                // Start positioned on the first run.
                _pos = 0;
                _nextIndex = array._deltas[0];
                _lastVal = array._values[0];
            }

            #region IIntArrayForwardIndexer Members

            /// <summary>
            /// Returns the value at <paramref name="virtualIndex"/>. Any index below the end of the
            /// current run yields the current run's value; otherwise the scan moves forward until a
            /// run covering the index is found. Returns -1 when the runs are exhausted.
            /// </summary>
            public unsafe int this[int virtualIndex]
            {
                get
                {
                    // Still inside (or before the end of) the current run.
                    if (virtualIndex < _nextIndex)
                        return _lastVal;

                    fixed (byte* runLengths = _array._deltas)
                    {
                        for (++_pos; _pos < _array._values.Length; ++_pos)
                        {
                            _nextIndex += runLengths[_pos];
                            if (virtualIndex < _nextIndex)
                            {
                                _lastVal = _array._values[_pos];
                                return _lastVal;
                            }
                        }
                    }

                    // Ran past the last run without covering the requested index.
                    return -1;
                }
            }

            #endregion
        }
 
        /// <summary>
        /// Builds a new <see cref="DeltaRepeatIntArray"/> holding only the items selected by
        /// <paramref name="itemIndices"/>, looked up through a forward indexer over this array.
        /// </summary>
        /// <param name="itemIndices">Indices of the items to keep, in the order they should appear.</param>
        /// <returns>The newly encoded array.</returns>
        public override IntArray Clone(int[] itemIndices)
        {
            IIntArrayForwardIndexer forward = GetIndexer();

            int itemCount = itemIndices.Length;
            IntArrayBits bits = BitsPerItem;
            // Lazy projection: the lookups happen while the constructor enumerates it.
            IEnumerable<int> selectedValues = itemIndices.Select(index => forward[index]);

            return new DeltaRepeatIntArray(itemCount, bits, selectedValues);
        }
 
        /// <summary>Creates a new forward indexer positioned at the start of this array.</summary>
        public override IIntArrayForwardIndexer GetIndexer()
            => new DeltaRepeatIntArrayIndexer(this);
 
        /// <summary>
        /// Produces one new <see cref="DeltaRepeatIntArray"/> per entry of <paramref name="assignment"/>,
        /// each holding the items of this array selected by that entry's indices.
        /// </summary>
        /// <param name="assignment">One index list per output array.</param>
        /// <returns>The output arrays, in the order of <paramref name="assignment"/>.</returns>
        public override IntArray[] Split(int[][] assignment)
        {
            // Every part gets a fresh indexer, because an indexer only moves forward.
            Func<int[], DeltaRepeatIntArray> buildPart = partIndices =>
            {
                IIntArrayForwardIndexer forward = GetIndexer();
                return new DeltaRepeatIntArray(partIndices.Length, BitsPerItem, partIndices.Select(index => forward[index]));
            };

            return assignment.Select(buildPart).ToArray();
        }
 
        /// <summary>
        /// Accumulates <paramref name="input"/> into <paramref name="histogram"/>, timed under
        /// <c>TimerEvent.SumupRepeat</c>. Dispatches to <c>SumupLeaf</c> when
        /// <c>input.DocIndices</c> is present and to <c>SumupRoot</c> when it is null.
        /// </summary>
        public override void Sumup(SumupInputData input, FeatureHistogram histogram)
        {
            using (Timer.Time(TimerEvent.SumupRepeat))
            {
                if (input.DocIndices != null)
                    SumupLeaf(input, histogram);
                else
                    SumupRoot(input, histogram);
            }
        }
 
        /// <summary>
        /// Accumulates the histogram when no document index list is given: runs are walked in
        /// order, and each run consumes the next <c>_deltas[run]</c> consecutive entries of
        /// <c>input.Outputs</c> (and of <c>input.Weights</c>, when present) into the run's bin.
        /// </summary>
        private unsafe void SumupRoot(SumupInputData input, FeatureHistogram histogram)
        {
            fixed (FloatType* pOutputsFixed = input.Outputs)
            fixed (FloatType* pSumTargetsFixed = histogram.SumTargetsByBin)
            fixed (double* pWeightsFixed = input.Weights)
            fixed (double* pSumWeightsFixed = histogram.SumWeightsByBin)
            {
                // Cursors pointing at the first output/weight belonging to the current run.
                FloatType* runOutputs = pOutputsFixed;
                double* runWeights = pWeightsFixed;

                for (int run = 0; run < _values.Length; run++)
                {
                    int bin = _values[run];

                    // Sum into a local and store once per run rather than once per item.
                    FloatType targetSum = pSumTargetsFixed[bin];
                    int runLength = _deltas[run];
                    for (int k = 0; k < runLength; ++k)
                        targetSum += runOutputs[k];
                    pSumTargetsFixed[bin] = targetSum;

                    if (pWeightsFixed != null)
                    {
                        double weightSum = pSumWeightsFixed[bin];
                        for (int k = 0; k < runLength; ++k)
                            weightSum += runWeights[k];
                        pSumWeightsFixed[bin] = weightSum;
                        runWeights += runLength;
                    }

                    runOutputs += runLength;
                    histogram.CountByBin[bin] += runLength;
                }
            }
        }
 
        /// <summary>
        /// Accumulates the histogram for the documents listed in <c>input.DocIndices</c>. For each
        /// listed document the run containing it is located, and the next entry of
        /// <c>input.Outputs</c> (and of <c>input.Weights</c>, when present) is added to that run's bin.
        /// </summary>
        private unsafe void SumupLeaf(SumupInputData input, FeatureHistogram histogram)
        {
            // Nothing to accumulate for an empty array; this also protects the _deltas[0] read below.
            if (_length == 0)
                return;
            // nextStep is the first virtual index beyond run `pos`.
            int nextStep = _deltas[0];
            int pos = 0;

            // NOTE(review): pSumWeightsFixed is pinned but never dereferenced below; the sums are
            // read and written through the managed histogram arrays.
            fixed (int* pDocIndicesFixed = input.DocIndices)
            fixed (FloatType* pOutputsFixed = input.Outputs)
            fixed (double* pWeightsFixed = input.Weights)
            fixed (double* pSumWeightsFixed = histogram.SumWeightsByBin)
            {
                int* pdoc = pDocIndicesFixed;
                int* end = pDocIndicesFixed + input.TotalCount;
                // Outputs and weights are consumed sequentially, one per visited entry of
                // DocIndices; they are not indexed by the document index itself.
                FloatType* pOutputs = pOutputsFixed;
                double* pWeights = pWeightsFixed;
                while (pdoc < end)
                {
                    // Advance to the run that contains document *pdoc. pos only ever moves forward,
                    // so this presumably relies on DocIndices being in ascending order and below
                    // _length — confirm against the callers.
                    while (nextStep <= *pdoc)
                        nextStep += _deltas[++pos];
                    int bin = _values[pos];
                    int count = 0;
                    FloatType subsum = histogram.SumTargetsByBin[bin];
                    if (pWeightsFixed != null)
                    {
                        double subweightsum = histogram.SumWeightsByBin[bin];
                        // Consume every remaining listed document that still falls inside this run.
                        while (pdoc < end && nextStep > *pdoc)
                        {
                            subsum += *(pOutputs++);
                            subweightsum += *(pWeights++);
                            count++;
                            pdoc++;
                        }
                        histogram.SumWeightsByBin[bin] = subweightsum;
                    }
                    else
                    {
                        // Same walk as above, without the weight accumulation.
                        while (pdoc < end && nextStep > *pdoc)
                        {
                            subsum += *(pOutputs++);
                            count++;
                            pdoc++;
                        }
                    }
                    // Write the per-run totals back once, after the run has been consumed.
                    histogram.SumTargetsByBin[bin] = subsum;
                    histogram.CountByBin[bin] += count;
                }
            }
        }
    }
}