// File: BoostingFastTree.cs
// Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// Base class for gradient-boosted tree trainers. Extends <see cref="FastTreeTrainerBase{TOptions, TTransformer, TModel}"/>
    /// with boosting-specific option validation, tree-learner construction, optimization-algorithm
    /// selection, gradient wrapping, and early-stopping/pruning iteration selection.
    /// </summary>
    public abstract class BoostingFastTreeTrainerBase<TOptions, TTransformer, TModel> : FastTreeTrainerBase<TOptions, TTransformer, TModel>
        where TTransformer : ISingleFeaturePredictionTransformer<TModel>
        where TOptions : BoostedTreeOptions, new()
        where TModel : class
    {
        private protected BoostingFastTreeTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Column label) : base(env, options, label)
        {
        }

        private protected BoostingFastTreeTrainerBase(IHostEnvironment env,
            SchemaShape.Column label,
            string featureColumnName,
            string exampleWeightColumnName,
            string rowGroupColumnName,
            int numberOfLeaves,
            int numberOfTrees,
            int minimumExampleCountPerLeaf,
            double learningRate)
            : base(env, label, featureColumnName, exampleWeightColumnName, rowGroupColumnName, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf)
        {
            FastTreeTrainerOptions.LearningRate = learningRate;
        }

        /// <summary>
        /// Validates boosting-specific options, then defers to the base class.
        /// </summary>
        /// <exception cref="System.Exception">Thrown (via <paramref name="ch"/>) when options are mutually inconsistent.</exception>
        private protected override void CheckOptions(IChannel ch)
        {
            // Accelerated and conjugate gradient descent both require line search to converge properly,
            // so it is forced on regardless of the user's setting.
            if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeOptions.OptimizationAlgorithmType.AcceleratedGradientDescent ||
                FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeOptions.OptimizationAlgorithmType.ConjugateGradientDescent)
                FastTreeTrainerOptions.UseLineSearch = true;

            if (FastTreeTrainerOptions.CompressEnsemble && FastTreeTrainerOptions.WriteLastEnsemble)
                throw ch.Except("Ensemble compression cannot be done when forcing to write last ensemble (hl)");

            // BUGFIX: this condition previously duplicated the "at most numLeaves - 1" check below,
            // so an undersized histogram pool was never rejected. The lower bound is 2.
            if (FastTreeTrainerOptions.NumberOfLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize < 2)
                throw ch.Except("Histogram pool size (ps) must be at least 2.");

            if (FastTreeTrainerOptions.NumberOfLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumberOfLeaves - 1)
                throw ch.Except("Histogram pool size (ps) must be at most numLeaves - 1.");

            if (FastTreeTrainerOptions.EnablePruning && !HasValidSet)
                throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid).");

            bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null;
            if (doEarlyStop && !HasValidSet)
                throw ch.Except("Cannot perform early stopping without a validation set (valid).");

            if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet))
                throw ch.Except("Cannot perform tolerant pruning (prtol) without pruning (pruning) and a validation set (valid)");

            base.CheckOptions(ch);
        }

        /// <summary>
        /// Creates the least-squares regression tree learner configured from the trainer options.
        /// </summary>
        private protected override TreeLearner ConstructTreeLearner(IChannel ch)
        {
            return new LeastSquaresRegressionTreeLearner(
                TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.EntropyCoefficient,
                FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature,
                FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, FastTreeTrainerOptions.FilterZeroLambdas,
                FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode,
                FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, BsrMaxTreeOutput(), ParallelTraining,
                FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit,
                FastTreeTrainerOptions.Bias, Host);
        }

        /// <summary>
        /// Builds the boosting optimization algorithm selected by
        /// <see cref="BoostedTreeOptions.OptimizationAlgorithm"/> and wires in the tree learner,
        /// objective function, smoothing, dropout, and scoring callbacks.
        /// </summary>
        private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
        {
            Contracts.CheckValue(ch, nameof(ch));
            IGradientAdjuster gradientWrapper = MakeGradientWrapper(ch);

            OptimizationAlgorithm optimizationAlgorithm;
            switch (FastTreeTrainerOptions.OptimizationAlgorithm)
            {
                case BoostedTreeOptions.OptimizationAlgorithmType.GradientDescent:
                    optimizationAlgorithm = new GradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
                    break;
                case BoostedTreeOptions.OptimizationAlgorithmType.AcceleratedGradientDescent:
                    optimizationAlgorithm = new AcceleratedGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
                    break;
                case BoostedTreeOptions.OptimizationAlgorithmType.ConjugateGradientDescent:
                    optimizationAlgorithm = new ConjugateGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
                    break;
                default:
                    throw ch.Except("Unknown optimization algorithm '{0}'", FastTreeTrainerOptions.OptimizationAlgorithm);
            }

            optimizationAlgorithm.TreeLearner = ConstructTreeLearner(ch);
            optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch);
            optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing;
            optimizationAlgorithm.DropoutRate = FastTreeTrainerOptions.DropoutRate;
            // Seeded so that dropout tree selection is reproducible across runs with the same options.
            optimizationAlgorithm.DropoutRng = new Random(FastTreeTrainerOptions.Seed);
            optimizationAlgorithm.PreScoreUpdateEvent += PrintTestGraph;

            return optimizationAlgorithm;
        }

        /// <summary>
        /// Returns the gradient adjuster: the base implementation unless best-step regression
        /// trees are requested, in which case a (query-weighted, if applicable) best-step wrapper is used.
        /// </summary>
        private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
        {
            if (!FastTreeTrainerOptions.BestStepRankingRegressionTrees)
                return base.MakeGradientWrapper(ch);

            // REVIEW: If this is ranking specific than cmd.bestStepRankingRegressionTrees and
            // this code should be part of Ranking application (and not application).
            if (AreSamplesWeighted(ch))
                return new QueryWeightsBestResressionStepGradientWrapper();
            else
                return new BestStepRegressionGradientWrapper();
        }

        /// <summary>
        /// Evaluates the early-stopping rule on the latest validation/training metrics.
        /// </summary>
        /// <param name="ch">Channel used for assertions and messages.</param>
        /// <param name="earlyStoppingRule">Rule instance; created lazily on first call.</param>
        /// <param name="bestIteration">Updated to the current tree count when this iteration is the best so far.</param>
        /// <returns>True when training should stop.</returns>
        private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration)
        {
            // No factory configured means early stopping is disabled.
            if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null)
                return false;

            ch.AssertValue(ValidTest);
            ch.AssertValue(TrainTest);

            var validationResult = ValidTest.ComputeTests().First();
            ch.Assert(validationResult.FinalValue >= 0);
            bool lowerIsBetter = validationResult.LowerIsBetter;

            var trainingResult = TrainTest.ComputeTests().First();
            ch.Assert(trainingResult.FinalValue >= 0);

            // Lazily create the rule on first use; the factory is known non-null after the guard above.
            earlyStoppingRule ??= FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter);
            ch.Assert(earlyStoppingRule != null);

            bool isBestCandidate;
            bool shouldStop = earlyStoppingRule.CheckScore((float)validationResult.FinalValue,
                (float)trainingResult.FinalValue, out isBestCandidate);

            if (isBestCandidate)
                bestIteration = Ensemble.NumTrees;

            return shouldStop;
        }

        /// <summary>
        /// Picks the iteration to keep: the pruning test's best iteration when pruning is active
        /// and the last ensemble is not forced, otherwise the full ensemble size.
        /// </summary>
        private protected override int GetBestIteration(IChannel ch)
        {
            int bestIteration = Ensemble.NumTrees;
            if (!FastTreeTrainerOptions.WriteLastEnsemble && PruningTest != null)
            {
                bestIteration = PruningTest.BestIteration;
                ch.Info("Pruning picked iteration {0}", bestIteration);
            }
            return bestIteration;
        }

        /// <summary>
        /// Retrieves max tree output if best regression step option is active or returns negative value otherwise.
        /// </summary>
        private protected double BsrMaxTreeOutput()
        {
            if (FastTreeTrainerOptions.BestStepRankingRegressionTrees)
                return FastTreeTrainerOptions.MaximumTreeOutput;
            else
                return -1;
        }

        /// <summary>
        /// Whether the optimizer should start from a random point, per the <c>RandomStart</c> option.
        /// </summary>
        private protected override bool ShouldRandomStartOptimizer()
        {
            return FastTreeTrainerOptions.RandomStart;
        }
    }
}