File: RandomForest.cs
Project: src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
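    /// <summary>
    /// Base class for the random forest trainers. Random forests reuse the FastTree
    /// infrastructure, but each tree is grown independently on bagged data, with no
    /// boosting, learning rate, or shrinkage.
    /// </summary>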
    public abstract class RandomForestTrainerBase<TOptions, TTransformer, TModel> : FastTreeTrainerBase<TOptions, TTransformer, TModel>
        where TOptions : FastForestOptionsBase, new()
        where TModel : class
        where TTransformer : ISingleFeaturePredictionTransformer<TModel>
    {
        private readonly bool _quantileEnabled;
 
        /// <summary>
        /// Constructor invoked by the maml code-path.
        /// </summary>
        private protected RandomForestTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Column label, bool quantileEnabled = false)
            : base(env, options, label)
        {
            _quantileEnabled = quantileEnabled;
        }
 
        /// <summary>
        /// Constructor invoked by the API code-path.
        /// </summary>
        private protected RandomForestTrainerBase(IHostEnvironment env,
            SchemaShape.Column label,
            string featureColumnName,
            string exampleWeightColumnName,
            string rowGroupColumnName,
            int numberOfLeaves,
            int numberOfTrees,
            int minimumExampleCountPerLeaf,
            bool quantileEnabled = false)
            // Note: rowGroupColumnName is accepted for signature symmetry but is not forwarded
            // to the base trainer; the random forest trainers do not use grouping information.
            : base(env, label, featureColumnName, exampleWeightColumnName, null, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf)
        {
            _quantileEnabled = quantileEnabled;
        }
 
        private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
        {
            Host.CheckValue(ch, nameof(ch));
            IGradientAdjuster gradientWrapper = MakeGradientWrapper(ch);
            // Random forests reuse the boosting infrastructure, but with an optimizer that grows
            // each tree independently instead of fitting successive residuals.
            var optimizationAlgorithm = new RandomForestOptimizer(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
 
            optimizationAlgorithm.TreeLearner = ConstructTreeLearner(ch);
            optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch);
            optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing;
            // No notion of dropout for non-boosting applications.
            optimizationAlgorithm.DropoutRate = 0;
            optimizationAlgorithm.DropoutRng = null;
            optimizationAlgorithm.PreScoreUpdateEvent += PrintTestGraph;
 
            return optimizationAlgorithm;
        }
 
        /// <summary>
        /// Intentionally a no-op: random forest training does not set up per-iteration tests.
        /// </summary>
        private protected override void InitializeTests()
        {
        }
 
        private protected override TreeLearner ConstructTreeLearner(IChannel ch)
        {
            // Each tree is fit with a least-squares tree learner configured entirely from the
            // trainer options; _quantileEnabled additionally makes leaves retain sampled labels
            // so that quantiles can be estimated at prediction time.
            return new RandomForestLeastSquaresTreeLearner(
                       TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.EntropyCoefficient,
                       FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature,
                       FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit,
                       FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode,
                       FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, _quantileEnabled, FastTreeTrainerOptions.NumberOfQuantileSamples, ParallelTraining,
                       FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit,
                       FastTreeTrainerOptions.Bias, Host);
        }
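
        // For context (an illustrative sketch of the public surface, not code used by this
        // class directly): the values threaded through ConstructTreeLearner come from the
        // FastForestOptionsBase-derived options that callers configure, e.g.:
        //
        //     var options = new FastForestRegressionTrainer.Options
        //     {
        //         NumberOfTrees = 100,
        //         NumberOfLeaves = 20,
        //         NumberOfQuantileSamples = 100
        //     };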
 
        /// <summary>
        /// Base objective function for random forests: the boosting-specific knobs (learning
        /// rate, shrinkage, derivative sampling) are pinned to neutral values.
        /// </summary>
        internal abstract class RandomForestObjectiveFunction : ObjectiveFunctionBase
        {
            protected RandomForestObjectiveFunction(Dataset trainData, TOptions options, double maxStepSize)
                : base(trainData,
                    1, // No learning rate in random forests.
                    1, // No shrinkage in random forests.
                    maxStepSize,
                    1, // No derivative sampling in random forests.
                    false, // Improvements to quasi-newton step not relevant to RF.
                    options.Seed)
            {
            }
        }
    }
}
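
// Usage sketch (illustrative, not part of the original file): concrete trainers derived from
// RandomForestTrainerBase, such as FastForestRegressionTrainer, are normally reached through
// the MLContext extension methods. "trainingData" below is an assumed IDataView with "Label"
// and "Features" columns.
//
//     var mlContext = new MLContext(seed: 0);
//     var trainer = mlContext.Regression.Trainers.FastForest(
//         labelColumnName: "Label",
//         featureColumnName: "Features",
//         numberOfLeaves: 20,
//         numberOfTrees: 100,
//         minimumExampleCountPerLeaf: 10);
//     var model = trainer.Fit(trainingData);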