// File: Dataset\FeatureHistogram.cs
// Web Access
// Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
#if USE_SINGLE_PRECISION
    using FloatType = System.Single;
#else
    using FloatType = System.Double;
#endif
 
    /// <summary>
    /// Class to represent statistics of the feature used by LeastSquaresRegressionTreeLearner.
    /// Holds three parallel per-bin arrays (target sums, optional weight sums, counts), each
    /// allocated with <see cref="NumFeatureValues"/> entries.
    /// </summary>
    internal sealed class FeatureHistogram
    {
        /// <summary>Per-bin sum of targets; always allocated.</summary>
        public readonly FloatType[] SumTargetsByBin;

        /// <summary>Per-bin sum of weights; null when the histogram was constructed with useWeights == false.</summary>
        public readonly double[] SumWeightsByBin;

        /// <summary>Per-bin count; always allocated.</summary>
        public readonly int[] CountByBin;

        /// <summary>The number of bins passed to the constructor, i.e. the length of every per-bin array.</summary>
        public readonly int NumFeatureValues;

        // The binned feature values; SumupWeighted delegates the actual accumulation to it.
        private readonly IntArray _bins;

        /// <summary>
        /// Make a new FeatureHistogram
        /// </summary>
        /// <param name="bins">The bins we will be calculating sumups over</param>
        /// <param name="numBins">The number of bins; must be strictly greater than the largest value stored in <paramref name="bins"/></param>
        /// <param name="useWeights">Allocates weights array when true</param>
        public FeatureHistogram(IntArray bins, int numBins, bool useWeights)
        {
            Contracts.AssertValue(bins);
            // Every bin value is used as an index into the per-bin arrays, so it must be < numBins.
            Contracts.Assert(bins.Length == 0 || (0 <= numBins && bins.Max() < numBins));
            _bins = bins;

            NumFeatureValues = numBins;
            SumTargetsByBin = new FloatType[NumFeatureValues];
            CountByBin = new int[NumFeatureValues];
            if (useWeights)
                SumWeightsByBin = new double[NumFeatureValues];
        }

        /// <summary>
        /// This function returns the estimated memory used for a FeatureHistogram object according to given
        /// number of bins.
        /// </summary>
        /// <param name="numBins">number of bins</param>
        /// <param name="hasWeights">weights array is counted when true</param>
        /// <returns>estimated size of memory used for a feature histogram object</returns>
        /// <remarks>
        /// The sum is computed in 32-bit arithmetic, so the estimate is only meaningful while
        /// it fits in an <see cref="int"/>.
        /// </remarks>
        public static int EstimateMemoryUsedForFeatureHistogram(int numBins, bool hasWeights)
        {
            return sizeof(int) // NumFeatureValues
                + sizeof(int) // allowance for one boolean flag (the original comment called it IsSplittable;
                              // no such field is declared in this class). sizeof(bool) is 1, but it is
                              // estimated as 4 for alignment
                + 8 // size of one object reference on 64 bit machines (the original comment named it
                    // _feature; the non-array reference field declared here is _bins)
                + sizeof(int) * numBins // CountByBin
                + sizeof(FloatType) * numBins // SumTargetsByBin
                + (hasWeights ? sizeof(double) * numBins : 0); // SumWeightsByBin
        }

        /// <summary>
        /// Subtract from myself the per-bin target sums, counts and (when this histogram has
        /// weights) weight sums of the child histogram.
        /// </summary>
        /// <param name="child">Another histogram to subtract; must have the same number of bins</param>
        /// <remarks>
        /// Throws when <paramref name="child"/> is null, when the bin counts differ, or when this
        /// histogram carries weights and <paramref name="child"/> does not. If this histogram has no
        /// weights, any weights on the child are ignored.
        /// </remarks>
        public unsafe void Subtract(FeatureHistogram child)
        {
            Contracts.CheckValue(child, nameof(child));
            if (child.NumFeatureValues != NumFeatureValues)
                throw Contracts.Except("cannot subtract FeatureHistograms of different lengths");

            fixed (FloatType* pSumTargetsByBin = SumTargetsByBin)
            fixed (FloatType* pChildSumTargetsByBin = child.SumTargetsByBin)
            fixed (double* pSumWeightsByBin = SumWeightsByBin)
            fixed (double* pChildSumWeightsByBin = child.SumWeightsByBin)
            fixed (int* pTotalCountByBin = CountByBin)
            fixed (int* pChildTotalCountByBin = child.CountByBin)
            {
                // fixed yields a null pointer for a null (or zero-length) array, so a null
                // pSumWeightsByBin means there are no weights to update.
                if (pSumWeightsByBin == null)
                {
                    for (int i = 0; i < NumFeatureValues; i++)
                    {
                        pSumTargetsByBin[i] -= pChildSumTargetsByBin[i];
                        pTotalCountByBin[i] -= pChildTotalCountByBin[i];
                    }
                }
                else
                {
                    // Validate with a real check instead of an assert: reading through a null
                    // child pointer below would be an invalid memory access, not a managed exception.
                    if (pChildSumWeightsByBin == null)
                        throw Contracts.Except("cannot subtract a FeatureHistogram without weights from one with weights");

                    for (int i = 0; i < NumFeatureValues; i++)
                    {
                        pSumTargetsByBin[i] -= pChildSumTargetsByBin[i];
                        pSumWeightsByBin[i] -= pChildSumWeightsByBin[i];
                        pTotalCountByBin[i] -= pChildTotalCountByBin[i];
                    }
                }
            }
        }

        /// <summary>
        /// Unweighted sumup: forwards to <see cref="SumupWeighted"/> with a weight sum of 0.0 and
        /// a null weights array.
        /// </summary>
        /// <param name="numDocsInLeaf">Number of documents in the leaf</param>
        /// <param name="sumTargets">Forwarded unchanged to the sumup input</param>
        /// <param name="outputs">Forwarded unchanged to the sumup input</param>
        /// <param name="docIndices">Forwarded unchanged to the sumup input</param>
        public void Sumup(int numDocsInLeaf, double sumTargets, FloatType[] outputs, int[] docIndices)
        {
            SumupWeighted(numDocsInLeaf, sumTargets, 0.0, outputs, null, docIndices);
        }

        /// <summary>
        /// Clears every per-bin array and then, when <paramref name="numDocsInLeaf"/> is positive,
        /// asks the underlying bins to perform the sumup into this histogram. The whole operation
        /// is timed under <c>TimerEvent.Sumup</c>.
        /// </summary>
        /// <param name="numDocsInLeaf">Number of documents in the leaf; when not positive the histogram is only cleared</param>
        /// <param name="sumTargets">Packaged into the <see cref="SumupInputData"/> as SumTargets</param>
        /// <param name="sumWeights">Packaged into the <see cref="SumupInputData"/> as SumWeights</param>
        /// <param name="outputs">Packaged into the <see cref="SumupInputData"/> as Outputs</param>
        /// <param name="weights">Packaged into the <see cref="SumupInputData"/> as Weights; may be null</param>
        /// <param name="docIndices">Packaged into the <see cref="SumupInputData"/> as DocIndices</param>
        public void SumupWeighted(int numDocsInLeaf, double sumTargets, double sumWeights, FloatType[] outputs, double[] weights, int[] docIndices)
        {
            using (Timer.Time(TimerEvent.Sumup))
            {
                // Reset all accumulators so the histogram reflects only this call.
                Array.Clear(SumTargetsByBin, 0, SumTargetsByBin.Length);

                if (SumWeightsByBin != null)
                {
                    Array.Clear(SumWeightsByBin, 0, SumWeightsByBin.Length);
                }

                Array.Clear(CountByBin, 0, CountByBin.Length);

                if (numDocsInLeaf > 0)
                {
                    SumupInputData input = new SumupInputData(
                        numDocsInLeaf,
                        sumTargets,
                        sumWeights,
                        outputs,
                        weights,
                        docIndices);

                    _bins.Sumup(input, this);
                }
            }
        }
    }
 
    /// <summary>
    /// Bundle of arguments handed to a sumup call: three scalar values that remain
    /// assignable after construction and three array references that are fixed at construction.
    /// </summary>
    internal sealed class SumupInputData
    {
        // Scalars: left mutable, as in the public surface callers already rely on.
        public int TotalCount;
        public double SumTargets;
        public double SumWeights;

        // Array references: stored as given (no copy is made) and never reassigned.
        public readonly FloatType[] Outputs;
        public readonly double[] Weights;
        public readonly int[] DocIndices;

        /// <summary>
        /// Stores each argument in the field of the same name.
        /// </summary>
        /// <param name="totalCount">Initial value of <see cref="TotalCount"/></param>
        /// <param name="sumTargets">Initial value of <see cref="SumTargets"/></param>
        /// <param name="sumWeights">Initial value of <see cref="SumWeights"/></param>
        /// <param name="outputs">Reference kept in <see cref="Outputs"/></param>
        /// <param name="weights">Reference kept in <see cref="Weights"/>; not validated, so null is stored as-is</param>
        /// <param name="docIndices">Reference kept in <see cref="DocIndices"/></param>
        public SumupInputData(int totalCount, double sumTargets, double sumWeights,
            FloatType[] outputs, double[] weights, int[] docIndices)
        {
            this.TotalCount = totalCount;
            this.SumTargets = sumTargets;
            this.SumWeights = sumWeights;

            this.Outputs = outputs;
            this.Weights = weights;
            this.DocIndices = docIndices;
        }
    }
 
}