File: Dataset\SingletonFeatureFlock.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Linq;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// The singleton feature flock is the simplest possible sort of flock, that is, a flock
    /// over one feature.
    /// </summary>
    /// <summary>
    /// A flock that covers exactly one feature, making it the simplest kind of
    /// <see cref="FeatureFlockBase"/>. The bin index of every example lives in a single
    /// <see cref="IntArray"/>, and a parallel array holds the upper bound of each bin.
    /// </summary>
    internal sealed class SingletonFeatureFlock : FeatureFlockBase
    {
        // Bin index of the single feature, one entry per example.
        private readonly IntArray _bins;
        // Upper bound of each bin; its length is the number of bins.
        private readonly double[] _binUpperBounds;

        /// <summary>Number of examples, taken from the length of the bin array.</summary>
        public override int Examples => _bins.Length;

        /// <summary>
        /// Builds a flock over one feature.
        /// </summary>
        /// <param name="bins">Per-example bin indices. Every index is asserted to address an entry of
        /// <paramref name="binUpperBounds"/>.</param>
        /// <param name="binUpperBounds">Upper bound of each bin. It is stored by reference, not copied.</param>
        public SingletonFeatureFlock(IntArray bins, double[] binUpperBounds)
            : base(1)
        {
            Contracts.AssertValue(bins);
            Contracts.AssertValue(binUpperBounds);
            // An empty bin array needs no range scan; otherwise the largest stored index must be in range.
            Contracts.Assert(bins.Length == 0 || bins.Max() < binUpperBounds.Length);

            _binUpperBounds = binUpperBounds;
            _bins = bins;
        }

        /// <summary>
        /// Estimated memory footprint. It is the bin array's own estimate plus one
        /// <see cref="double"/> per bin upper bound.
        /// </summary>
        public override long SizeInBytes()
        {
            long binBytes = _bins.SizeInBytes();
            return binBytes + sizeof(double) * _binUpperBounds.Length;
        }

        /// <summary>Creates histogram-backed sufficient statistics for this flock.</summary>
        internal override SufficientStatsBase CreateSufficientStats(bool hasWeights) =>
            new SufficientStats(this, hasWeights);

        /// <summary>Forward indexer over the bin array. Only feature index 0 is valid.</summary>
        public override IIntArrayForwardIndexer GetIndexer(int featureIndex)
        {
            // A singleton flock has exactly one feature, at index 0.
            Contracts.Assert(featureIndex == 0);
            return _bins.GetIndexer();
        }

        /// <summary>Number of bins, which equals the number of upper bounds.</summary>
        public override int BinCount(int featureIndex)
        {
            Contracts.Assert(featureIndex == 0);
            return _binUpperBounds.Length;
        }

        /// <summary>Creates a flock-level indexer that wraps this flock's single bin indexer.</summary>
        public override FlockForwardIndexerBase GetFlockIndexer() => new Indexer(this);

        /// <summary>
        /// Partitions the bin array according to <paramref name="assignment"/>. Each part is wrapped in
        /// a new singleton flock, and all of the new flocks share this flock's upper-bound array.
        /// </summary>
        public override FeatureFlockBase[] Split(int[][] assignment)
        {
            var flocks = from part in _bins.Split(assignment)
                         select new SingletonFeatureFlock(part, _binUpperBounds);
            return flocks.ToArray();
        }

        /// <summary>Always 1 for the single feature of this flock.</summary>
        public override double Trust(int featureIndex)
        {
            Contracts.Assert(featureIndex == 0);
            return 1;
        }

        /// <summary>
        /// Returns the stored upper-bound array itself, not a copy. Mutating it would affect this flock
        /// and any flocks produced by <see cref="Split"/>.
        /// </summary>
        public override double[] BinUpperBounds(int featureIndex)
        {
            Contracts.Assert(featureIndex == 0);
            return _binUpperBounds;
        }

        /// <summary>
        /// Flock-level forward indexer. Row lookups are forwarded to the feature's bin indexer.
        /// </summary>
        private sealed class Indexer : FlockForwardIndexerBase
        {
            private readonly SingletonFeatureFlock _flock;
            private readonly IIntArrayForwardIndexer _indexer;

            public Indexer(SingletonFeatureFlock flock)
            {
                Contracts.AssertValue(flock);
                _flock = flock;
                _indexer = flock.GetIndexer(0);
            }

            public override FeatureFlockBase Flock => _flock;

            /// <summary>Bin of <paramref name="rowIndex"/> for feature 0 (the only valid feature).</summary>
            public override int this[int featureIndex, int rowIndex]
            {
                get
                {
                    Contracts.Assert(featureIndex == 0);
                    return _indexer[rowIndex];
                }
            }
        }

        /// <summary>
        /// Sufficient statistics stored in one <see cref="FeatureHistogram"/>, which is built over the
        /// flock's bins with one slot per upper bound.
        /// </summary>
        private sealed class SufficientStats : SufficientStatsBase<SufficientStats>
        {
            private readonly SingletonFeatureFlock _flock;
            private readonly FeatureHistogram _hist;

            public SufficientStats(SingletonFeatureFlock flock, bool hasWeights)
                : base(flock.Count)
            {
                Contracts.AssertValue(flock);
                _flock = flock;
                _hist = new FeatureHistogram(flock._bins, flock._binUpperBounds.Length, hasWeights);
            }

            public override FeatureFlockBase Flock => _flock;

            /// <summary>Subtracts the other statistics' histogram from this one.</summary>
            protected override void SubtractCore(SufficientStats other) => _hist.Subtract(other._hist);

            /// <summary>
            /// Accumulates the leaf's documents into the histogram. If an activity mask is supplied and
            /// the entry at <paramref name="featureOffset"/> is false, the histogram is left untouched.
            /// </summary>
            protected override void SumupCore(int featureOffset, bool[] active,
                int numDocsInLeaf, double sumTargets, double sumWeights,
                double[] outputs, double[] weights, int[] docIndices)
            {
                Contracts.AssertValueOrNull(active);
                Contracts.Assert(active == null || (0 <= featureOffset && featureOffset <= Utils.Size(active) - Flock.Count));
                bool featureIsInactive = active != null && !active[featureOffset];
                if (!featureIsInactive)
                    _hist.SumupWeighted(numDocsInLeaf, sumTargets, sumWeights, outputs, weights, docIndices);
            }

            /// <summary>Estimated histogram size, accounting for the weight array when present.</summary>
            public override long SizeInBytes()
            {
                var valueCount = _hist.NumFeatureValues;
                bool hasWeightSums = _hist.SumWeightsByBin != null;
                return FeatureHistogram.EstimateMemoryUsedForFeatureHistogram(valueCount, hasWeightSums);
            }

            /// <summary>Largest border index: one less than the histogram's number of feature values.</summary>
            protected override int GetMaxBorder(int featureIndex) => _hist.NumFeatureValues - 1;

            /// <summary>Smallest border index: always 1.</summary>
            protected override int GetMinBorder(int featureIndex) => 1;

            /// <summary>
            /// Statistics for one bin. Despite its name, <paramref name="featureIndex"/> is used here as an
            /// index into the histogram's per-bin arrays. The weight sum is reported as 0 when the histogram
            /// carries no weights.
            /// </summary>
            protected override PerBinStats GetBinStats(int featureIndex)
            {
                return _hist.SumWeightsByBin == null
                    ? new PerBinStats(_hist.SumTargetsByBin[featureIndex], 0, _hist.CountByBin[featureIndex])
                    : new PerBinStats(_hist.SumTargetsByBin[featureIndex], _hist.SumWeightsByBin[featureIndex], _hist.CountByBin[featureIndex]);
            }

            /// <summary>
            /// Per-bin target sum divided by (bin weight sum + <paramref name="bias"/>). When the histogram
            /// has no weights, the divisor is (bin count + <paramref name="bias"/>) instead. As in
            /// <see cref="GetBinStats"/>, <paramref name="featureIndex"/> indexes the per-bin arrays.
            /// </summary>
            protected override double GetBinGradient(int featureIndex, double bias)
            {
                var numerator = _hist.SumTargetsByBin[featureIndex];
                double denominator = _hist.SumWeightsByBin == null
                    ? _hist.CountByBin[featureIndex] + bias
                    : _hist.SumWeightsByBin[featureIndex] + bias;
                return numerator / denominator;
            }
        }
    }
}