File: Dataset\SingletonFeatureFlock.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Linq;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// The singleton feature flock is the simplest possible sort of flock, that is, a flock
    /// over one feature.
    /// </summary>
    internal sealed class SingletonFeatureFlock : FeatureFlockBase
    {
        // Binned values of the single feature, one entry per example.
        private readonly IntArray _bins;
        // Upper bound of each bin; its length is the number of bins.
        private readonly double[] _binUpperBounds;

        /// <summary>The number of examples, taken from the length of the bin array.</summary>
        public override int Examples => _bins.Length;

        /// <summary>
        /// Creates a flock over a single feature.
        /// </summary>
        /// <param name="bins">The binned feature values.</param>
        /// <param name="binUpperBounds">The upper bounds of the bins. Every value in
        /// <paramref name="bins"/> is asserted to be a valid index into this array.</param>
        public SingletonFeatureFlock(IntArray bins, double[] binUpperBounds)
            : base(1) // presumably the feature count of the flock; verify against FeatureFlockBase
        {
            Contracts.Assert(bins.Length == 0 || bins.Max() < binUpperBounds.Length);
            _bins = bins;
            _binUpperBounds = binUpperBounds;
        }

        /// <summary>
        /// Returns the size reported by the bin array plus the size of the upper-bounds array.
        /// </summary>
        public override long SizeInBytes()
        {
            return _bins.SizeInBytes() + sizeof(double) * _binUpperBounds.Length;
        }

        /// <summary>
        /// Creates a new sufficient-statistics object bound to this flock.
        /// </summary>
        /// <param name="hasWeights">Forwarded to the histogram backing the statistics.</param>
        internal override SufficientStatsBase CreateSufficientStats(bool hasWeights)
        {
            return new SufficientStats(this, hasWeights);
        }

        /// <summary>
        /// Returns a forward indexer over the bins of the only feature.
        /// </summary>
        /// <param name="featureIndex">Must be 0 (asserted).</param>
        public override IIntArrayForwardIndexer GetIndexer(int featureIndex)
        {
            Contracts.Assert(featureIndex == 0);
            return _bins.GetIndexer();
        }

        /// <summary>
        /// Returns the number of bins, that is, the length of the upper-bounds array.
        /// </summary>
        /// <param name="featureIndex">Must be 0 (asserted).</param>
        public override int BinCount(int featureIndex)
        {
            Contracts.Assert(featureIndex == 0);
            return _binUpperBounds.Length;
        }

        /// <summary>
        /// Returns a flock-level indexer wrapping this flock.
        /// </summary>
        public override FlockForwardIndexerBase GetFlockIndexer()
        {
            return new Indexer(this);
        }

        /// <summary>
        /// Splits the bin array according to <paramref name="assignment"/> and wraps each
        /// resulting part in a new flock. All parts share this flock's upper-bounds array.
        /// </summary>
        public override FeatureFlockBase[] Split(int[][] assignment)
        {
            return _bins.Split(assignment)
                .Select(bins => new SingletonFeatureFlock(bins, _binUpperBounds)).ToArray();
        }

        /// <summary>
        /// Always returns 1.
        /// </summary>
        /// <param name="featureIndex">Must be 0 (asserted).</param>
        public override double Trust(int featureIndex)
        {
            Contracts.Assert(featureIndex == 0);
            return 1;
        }

        /// <summary>
        /// Returns the upper-bounds array itself (not a copy).
        /// </summary>
        /// <param name="featureIndex">Must be 0 (asserted).</param>
        public override double[] BinUpperBounds(int featureIndex)
        {
            Contracts.Assert(featureIndex == 0);
            return _binUpperBounds;
        }

        /// <summary>
        /// Flock indexer that delegates every lookup to the forward indexer of the single feature.
        /// </summary>
        private sealed class Indexer : FlockForwardIndexerBase
        {
            private readonly SingletonFeatureFlock _flock;
            private readonly IIntArrayForwardIndexer _indexer;

            public override FeatureFlockBase Flock { get { return _flock; } }

            /// <summary>
            /// Gets the bin of the given row.
            /// </summary>
            /// <param name="featureIndex">Must be 0 (asserted).</param>
            /// <param name="rowIndex">Row passed through to the underlying indexer.</param>
            public override int this[int featureIndex, int rowIndex]
            {
                get
                {
                    Contracts.Assert(featureIndex == 0);
                    return _indexer[rowIndex];
                }
            }

            public Indexer(SingletonFeatureFlock flock)
            {
                _flock = flock;
                _indexer = _flock.GetIndexer(0);
            }
        }

        /// <summary>
        /// Sufficient statistics for the singleton flock, backed by one <see cref="FeatureHistogram"/>.
        /// </summary>
        private sealed class SufficientStats : SufficientStatsBase<SufficientStats>
        {
            private readonly SingletonFeatureFlock _flock;
            private readonly FeatureHistogram _hist;

            public override FeatureFlockBase Flock
            {
                get { return _flock; }
            }

            /// <summary>
            /// Builds the histogram over the flock's bins, sized to its bin count.
            /// </summary>
            public SufficientStats(SingletonFeatureFlock flock, bool hasWeights)
                : base(flock.Count)
            {
                _flock = flock;
                _hist = new FeatureHistogram(_flock._bins, _flock._binUpperBounds.Length, hasWeights);
            }

            /// <summary>
            /// Subtracts the statistics of <paramref name="other"/> from this instance.
            /// </summary>
            protected override void SubtractCore(SufficientStats other)
            {
                // NOTE(review): the body of this method was absent from the reviewed text.
                // Delegating to FeatureHistogram.Subtract is the presumed original behavior;
                // confirm that member exists with this signature before merging.
                _hist.Subtract(other._hist);
            }

            /// <summary>
            /// Accumulates the histogram for the documents of a leaf, unless the flock's
            /// feature is flagged inactive.
            /// </summary>
            /// <param name="featureOffset">Position of this flock's feature inside <paramref name="active"/>.</param>
            /// <param name="active">Optional per-feature activity flags; <c>null</c> means no filtering.</param>
            protected override void SumupCore(int featureOffset, bool[] active,
                int numDocsInLeaf, double sumTargets, double sumWeights,
                double[] outputs, double[] weights, int[] docIndices)
            {
                Contracts.Assert(active == null || (0 <= featureOffset && featureOffset <= Utils.Size(active) - Flock.Count));
                // An inactive feature contributes nothing, so skip the summation entirely.
                if (active != null && !active[featureOffset])
                {
                    return;
                }

                _hist.SumupWeighted(numDocsInLeaf, sumTargets, sumWeights, outputs, weights, docIndices);
            }

            /// <summary>
            /// Returns the histogram memory estimate for the current number of feature values,
            /// taking into account whether per-bin weights are tracked.
            /// </summary>
            public override long SizeInBytes()
            {
                return FeatureHistogram.EstimateMemoryUsedForFeatureHistogram(_hist.NumFeatureValues,
                    _hist.SumWeightsByBin != null);
            }

            /// <summary>Returns the histogram's number of feature values minus one.</summary>
            protected override int GetMaxBorder(int featureIndex)
            {
                return _hist.NumFeatureValues - 1;
            }

            /// <summary>Always returns 1.</summary>
            protected override int GetMinBorder(int featureIndex)
            {
                return 1;
            }

            /// <summary>
            /// Returns the per-bin statistics at the given position. The weight sum is reported
            /// as 0 when the histogram does not track weights.
            /// </summary>
            /// <param name="featureIndex">Used directly as the index into the histogram's per-bin arrays.</param>
            protected override PerBinStats GetBinStats(int featureIndex)
            {
                if (_hist.SumWeightsByBin != null)
                {
                    return new PerBinStats(_hist.SumTargetsByBin[featureIndex], _hist.SumWeightsByBin[featureIndex], _hist.CountByBin[featureIndex]);
                }
                else
                {
                    return new PerBinStats(_hist.SumTargetsByBin[featureIndex], 0, _hist.CountByBin[featureIndex]);
                }
            }

            /// <summary>
            /// Returns the target sum at the given position divided by the weight sum (when
            /// weights are tracked) or the count (otherwise), with <paramref name="bias"/>
            /// added to the denominator.
            /// </summary>
            /// <param name="featureIndex">Used directly as the index into the histogram's per-bin arrays.</param>
            /// <param name="bias">Value added to the denominator.</param>
            protected override double GetBinGradient(int featureIndex, double bias)
            {
                if (_hist.SumWeightsByBin != null)
                {
                    return _hist.SumTargetsByBin[featureIndex] / (_hist.SumWeightsByBin[featureIndex] + bias);
                }
                else
                {
                    return _hist.SumTargetsByBin[featureIndex] / (_hist.CountByBin[featureIndex] + bias);
                }
            }
        }
    }
}