File: Utilities\SupervisedBinFinder.cs
Web Access
Project: src\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj (Microsoft.ML.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Internal.Utilities
{
    /// <summary>
    /// This class performs discretization of (value, label) pairs into bins in a way that minimizes
    /// the target function "minimum description length".
    /// The algorithm is outlined in an article
    /// "Multi-Interval Discretization of Continuous-Valued Attributes for Classification Learning"
    /// [Fayyad, Usama M.; Irani, Keki B. (1993)] https://ijcai.org/Past%20Proceedings/IJCAI-93-VOL2/PDF/022.pdf
    ///
    /// The class can be used several times sequentially; it is stateful and not thread-safe.
    /// Both Single and Double precision processing are implemented, and the two implementations are identical.
    /// </summary>
    [BestFriend]
    internal sealed class SupervisedBinFinder
    {
        /// <summary>
        /// A feature value together with the label observed for it. Ordering looks at
        /// <see cref="Value"/> only; <see cref="Label"/> never takes part in a comparison.
        /// </summary>
        private readonly struct ValuePair<T> : IComparable<ValuePair<T>>
            where T : IComparable<T>
        {
            /// <summary>The feature value; the only sort key.</summary>
            public readonly T Value;

            /// <summary>The label that accompanies <see cref="Value"/>.</summary>
            public readonly int Label;

            public ValuePair(T value, int label)
            {
                Label = label;
                Value = value;
            }

            /// <summary>
            /// Compares this pair with <paramref name="other"/> by delegating to the comparison of the two values.
            /// </summary>
            public int CompareTo(ValuePair<T> other) => Value.CompareTo(other.Value);
        }
 
        private int _valueCount;
        private int _distinctValueCount;
        private int _labelCardinality;
        private int _maxBins;
        private int _minBinSize;
 
        // cumulative counts for distinct values. Dimensions: _distinctValueCount X (_labelCardinality + 1) (last column is the total counts)
        // REVIEW: optimize memory allocation in sequential use case (don't re-allocate if we have a large enough array already)
        private int[,] _cumulativeCounts;
 
        /// <summary>
        /// Finds the bins for Single values (and integer labels)
        /// </summary>
        /// <param name="maxBins">Maximum number of bins</param>
        /// <param name="minBinSize">Minimum number of values per bin (stopping condition for greedy bin splitting)</param>
        /// <param name="nLabels">Cardinality of the labels</param>
        /// <param name="values">The feature values</param>
        /// <param name="labels">The corresponding label values; each one is used as a column index into the
        /// cumulative counts, so it has to lie in [0, <paramref name="nLabels"/>)</param>
        /// <returns>An array of split points, no more than <paramref name="maxBins"/> total (but maybe less), ending with PositiveInfinity</returns>
        public Single[] FindBins(int maxBins, int minBinSize, int nLabels, IList<Single> values, IList<int> labels)
        {
            // remember the parameters of this run; FindBinsCore and the split search read them from the fields
            _valueCount = values.Count;
            _labelCardinality = nLabels;
            _maxBins = maxBins;
            _minBinSize = minBinSize;
            Contracts.Assert(_valueCount == labels.Count);

            // pair every value with its label and order the pairs by value
            var valuePairs = new ValuePair<Single>[_valueCount];
            for (int i = 0; i < _valueCount; i++)
                valuePairs[i] = new ValuePair<Single>(values[i], labels[i]);
            Array.Sort(valuePairs);

            // Count the distinct values on the sorted array with exactly the same test the grouping loop
            // below applies. Deriving the count from hash-set equality instead could disagree with the
            // ordered comparison (e.g. for NaN) and leave trailing rows of the arrays below unfilled.
            _distinctValueCount = 0;
            Single lastDistinct = Single.NegativeInfinity;
            foreach (var pair in valuePairs)
            {
                if (pair.Value > lastDistinct || _distinctValueCount == 0)
                {
                    lastDistinct = pair.Value;
                    _distinctValueCount++;
                }
            }

            // populate the cumulative counts with unique values
            _cumulativeCounts = new int[_distinctValueCount, _labelCardinality + 1];
            var distinctValues = new Single[_distinctValueCount];
            Single curValue = Single.NegativeInfinity;
            int curIndex = -1;
            foreach (var pair in valuePairs)
            {
                Contracts.Assert(pair.Value >= curValue);
                // "curIndex < 0" lets the very first pair open a row even when its value is NegativeInfinity
                if (pair.Value > curValue || curIndex < 0)
                {
                    curValue = pair.Value;
                    curIndex++;
                    distinctValues[curIndex] = curValue;
                    if (curIndex > 0)
                    {
                        // a new row starts from the running totals of the previous distinct value
                        for (int i = 0; i < _labelCardinality + 1; i++)
                            _cumulativeCounts[curIndex, i] = _cumulativeCounts[curIndex - 1, i];
                    }
                }
                _cumulativeCounts[curIndex, pair.Label]++;
                // the last column holds the total over all labels
                _cumulativeCounts[curIndex, _labelCardinality]++;
            }

            Contracts.Assert(curIndex == _distinctValueCount - 1);

            var boundaries = FindBinsCore();
            Contracts.Assert(Utils.Size(boundaries) > 0);
            Contracts.Assert(boundaries.Length == 1 && boundaries[0] == 0 || boundaries[0] > 0, "boundaries are exclusive, can't have 0");
            Contracts.Assert(boundaries[boundaries.Length - 1] == _distinctValueCount);

            // transform boundary indices back into bin upper bounds
            var numUpperBounds = boundaries.Length;
            Single[] result = new Single[numUpperBounds];
            for (int i = 0; i < numUpperBounds - 1; i++)
            {
                var split = boundaries[i];
                result[i] = BinFinderBase.GetSplitValue(distinctValues[split - 1], distinctValues[split]);

                // Even though distinctValues may contain infinities, the boundaries may not be infinite:
                // GetSplitValue(a,b) only returns +-inf if a==b==+-inf,
                // and distinctValues won't contain more than one +inf or -inf.
                Contracts.Assert(FloatUtils.IsFinite(result[i]));
            }

            // the last bin is open-ended
            result[numUpperBounds - 1] = Single.PositiveInfinity;
            AssertStrictlyIncreasing(result);

            return result;
        }
 
        /// <summary>
        /// Finds the bins for Double values (and integer labels)
        /// </summary>
        /// <param name="maxBins">Maximum number of bins</param>
        /// <param name="minBinSize">Minimum number of values per bin (stopping condition for greedy bin splitting)</param>
        /// <param name="nLabels">Cardinality of the labels</param>
        /// <param name="values">The feature values</param>
        /// <param name="labels">The corresponding label values; each one is used as a column index into the
        /// cumulative counts, so it has to lie in [0, <paramref name="nLabels"/>)</param>
        /// <returns>An array of split points, no more than <paramref name="maxBins"/> total (but maybe less), ending with PositiveInfinity</returns>
        public Double[] FindBins(int maxBins, int minBinSize, int nLabels, IList<Double> values, IList<int> labels)
        {
            // remember the parameters of this run; FindBinsCore and the split search read them from the fields
            _valueCount = values.Count;
            _labelCardinality = nLabels;
            _maxBins = maxBins;
            _minBinSize = minBinSize;
            Contracts.Assert(_valueCount == labels.Count);

            // pair every value with its label and order the pairs by value
            var valuePairs = new ValuePair<Double>[_valueCount];
            for (int i = 0; i < _valueCount; i++)
                valuePairs[i] = new ValuePair<Double>(values[i], labels[i]);
            Array.Sort(valuePairs);

            // Count the distinct values on the sorted array with exactly the same test the grouping loop
            // below applies. Deriving the count from hash-set equality instead could disagree with the
            // ordered comparison (e.g. for NaN) and leave trailing rows of the arrays below unfilled.
            _distinctValueCount = 0;
            Double lastDistinct = Double.NegativeInfinity;
            foreach (var pair in valuePairs)
            {
                if (pair.Value > lastDistinct || _distinctValueCount == 0)
                {
                    lastDistinct = pair.Value;
                    _distinctValueCount++;
                }
            }

            // populate the cumulative counts with unique values
            _cumulativeCounts = new int[_distinctValueCount, _labelCardinality + 1];
            var distinctValues = new Double[_distinctValueCount];
            Double curValue = Double.NegativeInfinity;
            int curIndex = -1;
            foreach (var pair in valuePairs)
            {
                Contracts.Assert(pair.Value >= curValue);
                // "curIndex < 0" lets the very first pair open a row even when its value is NegativeInfinity
                if (pair.Value > curValue || curIndex < 0)
                {
                    curValue = pair.Value;
                    curIndex++;
                    distinctValues[curIndex] = curValue;
                    if (curIndex > 0)
                    {
                        // a new row starts from the running totals of the previous distinct value
                        for (int i = 0; i < _labelCardinality + 1; i++)
                            _cumulativeCounts[curIndex, i] = _cumulativeCounts[curIndex - 1, i];
                    }
                }
                _cumulativeCounts[curIndex, pair.Label]++;
                // the last column holds the total over all labels
                _cumulativeCounts[curIndex, _labelCardinality]++;
            }

            Contracts.Assert(curIndex == _distinctValueCount - 1);

            var boundaries = FindBinsCore();
            Contracts.Assert(Utils.Size(boundaries) > 0);
            Contracts.Assert(boundaries.Length == 1 && boundaries[0] == 0 || boundaries[0] > 0, "boundaries are exclusive, can't have 0");
            Contracts.Assert(boundaries[boundaries.Length - 1] == _distinctValueCount);

            // transform boundary indices back into bin upper bounds
            var numUpperBounds = boundaries.Length;
            Double[] result = new Double[numUpperBounds];
            for (int i = 0; i < numUpperBounds - 1; i++)
            {
                var split = boundaries[i];
                result[i] = BinFinderBase.GetSplitValue(distinctValues[split - 1], distinctValues[split]);

                // Even though distinctValues may contain infinities, the boundaries may not be infinite:
                // GetSplitValue(a,b) only returns +-inf if a==b==+-inf,
                // and distinctValues won't contain more than one +inf or -inf.
                Contracts.Assert(FloatUtils.IsFinite(result[i]));
            }

            // the last bin is open-ended
            result[numUpperBounds - 1] = Double.PositiveInfinity;
            AssertStrictlyIncreasing(result);

            return result;
        }
 
        /// <summary>
        /// Debug-only check that every element of <paramref name="result"/> is strictly greater
        /// than the element before it. Calls are compiled away when DEBUG is not defined.
        /// </summary>
        [Conditional("DEBUG")]
        private void AssertStrictlyIncreasing(Single[] result)
        {
#if DEBUG
            int next = 0;
            while (++next < result.Length)
            {
                Single previous = result[next - 1];
                Contracts.Assert(previous < result[next]);
            }
#endif
        }
 
        /// <summary>
        /// Debug-only check that every element of <paramref name="result"/> is strictly greater
        /// than the element before it. Calls are compiled away when DEBUG is not defined.
        /// </summary>
        [Conditional("DEBUG")]
        private void AssertStrictlyIncreasing(Double[] result)
        {
#if DEBUG
            int next = 0;
            while (++next < result.Length)
            {
                Double previous = result[next - 1];
                Contracts.Assert(previous < result[next]);
            }
#endif
        }
 
        /// <summary>
        /// A half-open run [Min, Lim) of distinct-value indices, together with the most profitable
        /// position at which to cut it in two and the gain that cut achieves.
        /// </summary>
        private class SplitInterval
        {
            /// <summary>Inclusive lower index of the interval.</summary>
            public readonly int Min;
            /// <summary>Exclusive upper index of the interval.</summary>
            public readonly int Lim;

            /// <summary>Gain of the best split found; remains -1 when no split was evaluated or none scored above -1.</summary>
            public readonly Double Gain;
            /// <summary>Exclusive upper index of the left part for the best split; 0 when no split was recorded.</summary>
            public readonly int SplitLim;

            /// <summary>
            /// Creates the interval [<paramref name="min"/>, <paramref name="lim"/>) and, unless
            /// <paramref name="skipSplitCalculation"/> is set, searches it for the best split.
            /// </summary>
            /// <param name="binFinder">The finder whose cumulative counts, minimum bin size and label cardinality are consulted</param>
            /// <param name="min">Inclusive lower index</param>
            /// <param name="lim">Exclusive upper index</param>
            /// <param name="skipSplitCalculation">When true, the split search is skipped and the gain stays -1</param>
            public SplitInterval(SupervisedBinFinder binFinder, int min, int lim, bool skipSplitCalculation)
            {
                Min = min;
                Lim = lim;
                Gain = -1;

                if (!skipSplitCalculation)
                {
                    Gain = SearchBestSplit(binFinder, min, lim, out int bestSplit);
                    SplitLim = bestSplit;
                }
            }

            /// <summary>
            /// Tries every interior cut of [min, lim) and returns the highest gain seen (or -1 if none beats it),
            /// reporting the matching cut through <paramref name="bestSplit"/>.
            /// </summary>
            private static Double SearchBestSplit(SupervisedBinFinder finder, int min, int lim, out int bestSplit)
            {
                bestSplit = 0;
                Double bestGain = -1;

                Double wholeEntropy = finder.GetEntropy(min, lim, out int wholeCount);

                // Nothing to do for a bin that is already too small, or one whose entropy is already zero.
                if (wholeCount < finder._minBinSize || wholeEntropy <= 0)
                    return bestGain;

                Double logSize = Math.Log(lim - min);
                for (int cut = min + 1; cut < lim; cut++)
                {
                    Double lowEntropy = finder.GetEntropy(min, cut, out int lowCount);
                    Double highEntropy = finder.GetEntropy(cut, lim, out int highCount);
                    Contracts.Assert(lowCount + highCount == wholeCount);

                    // Fixed cost associated with a split: a simplification of the
                    // Delta(A,T;S) term calculated in the paper.
                    Double fixedCost = logSize - finder._labelCardinality * (wholeEntropy - lowEntropy - highEntropy);

                    // Cost of transmitting the unsplit content, minus the cost of each half,
                    // minus the fixed cost of the additional codebook.
                    Double gain = wholeCount * wholeEntropy - lowCount * lowEntropy - highCount * highEntropy - fixedCost;
                    if (gain > bestGain)
                    {
                        bestGain = gain;
                        bestSplit = cut;
                    }
                }

                return bestGain;
            }
        }
 
        /// <summary>
        /// Computes the entropy (natural logarithm) of the label distribution over the distinct-value index
        /// range [<paramref name="min"/>, <paramref name="lim"/>) and reports how many examples fall into it.
        /// </summary>
        /// <param name="min">Inclusive lower index into the distinct values</param>
        /// <param name="lim">Exclusive upper index into the distinct values</param>
        /// <param name="totalCount">Number of examples in the range, read from the totals column of the cumulative counts</param>
        /// <returns>The sum over labels of -p * ln(p); labels with no examples, or with all of them, contribute nothing</returns>
        private Double GetEntropy(int min, int lim, out int totalCount)
        {
            int lastRow = lim - 1;
            // Row 0 has no predecessor, so there is nothing to subtract when the range starts at index 0.
            bool hasPrefix = min > 0;

            int rangeTotal = _cumulativeCounts[lastRow, _labelCardinality];
            if (hasPrefix)
                rangeTotal -= _cumulativeCounts[min - 1, _labelCardinality];
            totalCount = rangeTotal;

            Double result = 0;
            for (int label = 0; label < _labelCardinality; label++)
            {
                int labelCount = _cumulativeCounts[lastRow, label];
                if (hasPrefix)
                    labelCount -= _cumulativeCounts[min - 1, label];

                if (labelCount != 0 && labelCount != rangeTotal)
                {
                    Double fraction = (Double)labelCount / rangeTotal;
                    result -= fraction * Math.Log(fraction);
                }
            }

            return result;
        }
 
        /// <summary>
        /// Finds the optimum bins with respect to <see cref="_cumulativeCounts"/>: starting from a single interval
        /// covering all distinct values, repeatedly splits the interval with the best gain until
        /// <see cref="_maxBins"/> intervals exist or the best available gain is no longer positive.
        /// </summary>
        /// <returns>The sorted array of indices that are exclusive upper bounds of the respective bins</returns>
        private int[] FindBinsCore()
        {
            if (_distinctValueCount == 0)
                return new int[] { _distinctValueCount };

            // we will put intervals into a heap so that the one with maximum gain is at the top
            var intervals = new Heap<SplitInterval>((x, y) => x.Gain < y.Gain);

            // start with a single interval covering all points
            intervals.Add(new SplitInterval(this, 0, _distinctValueCount, false));

            // while we haven't reached max # of bins and there's still gain in splitting (best interval's gain is positive)
            while (intervals.Count < _maxBins && intervals.Top.Gain > 0)
            {
                // take the interval with the best split gain
                var toSplit = intervals.Pop();

                // The count is read after the pop, so adding the two halves brings the heap to Count + 2.
                // Once that reaches _maxBins the loop stops, and the halves' own best splits would never be
                // looked at, so their (expensive) split search can be skipped.
                bool isLastSplit = intervals.Count + 2 >= _maxBins;
                var left = new SplitInterval(this, toSplit.Min, toSplit.SplitLim, isLastSplit);
                var right = new SplitInterval(this, toSplit.SplitLim, toSplit.Lim, isLastSplit);

                // put the results back into the heap
                intervals.Add(left);
                intervals.Add(right);
            }

            // only the upper bounds matter from here on; the heap yields them in gain order, so sort afterwards
            var binCount = intervals.Count;
            var results = new int[binCount];
            for (int i = 0; i < binCount; i++)
                results[i] = intervals.Pop().Lim;

            Contracts.Assert(intervals.Count == 0);

            Array.Sort(results);
            return results;
        }
    }
}