File: FastTreeArguments.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
 
[assembly: EntryPointModule(typeof(FastTreeBinaryTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeTweedieTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeRankingTrainer.Options))]
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// Component-factory contract, tagged with the "FastTreeTrainer" component kind, for factories
    /// that create an <see cref="ITrainer"/>. It declares no members of its own; the trainer
    /// <c>Options</c> classes in this file implement it.
    /// </summary>
    [TlcModule.ComponentKind("FastTreeTrainer")]
    internal interface IFastTreeTrainerFactory : IComponentFactory<ITrainer>
    {
    }
 
    /// <summary>
    /// Measurements that classification and regression trainers can use as their stopping metric.
    /// </summary>
    public enum EarlyStoppingMetric
    {
        /// <summary>
        /// The L1-norm of the gradient.
        /// </summary>
        L1Norm = 1,

        /// <summary>
        /// The L2-norm of the gradient.
        /// </summary>
        L2Norm = 2,
    }
 
    /// <summary>
    /// Measurements that ranking trainers can use as their stopping metric.
    /// </summary>
    public enum EarlyStoppingRankingMetric
    {
        /// <summary>
        /// NDCG@1.
        /// </summary>
        NdcgAt1 = 1,

        /// <summary>
        /// NDCG@3.
        /// </summary>
        NdcgAt3 = 3,
    }
 
    // XML docs are provided in the other part of this partial class. No need to duplicate the content here.
    public sealed partial class FastTreeBinaryTrainer
    {
        /// <summary>
        /// Options for the <see cref="FastTreeBinaryTrainer"/> as used in
        /// [FastTree(Options)](xref:Microsoft.ML.TreeExtensions.FastTree(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer.Options)).
        /// </summary>
        [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
        public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
        {

            /// <summary>
            /// Whether to use derivatives optimized for unbalanced training data.
            /// </summary>
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Option for using derivatives optimized for unbalanced sets", ShortName = "us")]
            [TGUI(Label = "Optimize for unbalanced")]
            public bool UnbalancedSets = false;

            /// <summary>
            /// Internal state of <see cref="EarlyStoppingMetric"/>. It should always be synced with
            /// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
            /// </summary>
            // Only written through the EarlyStoppingMetric setter, which is what keeps the two values in sync.
            private EarlyStoppingMetric _earlyStoppingMetric;

            /// <summary>
            /// Early stopping metrics.
            /// </summary>
            public EarlyStoppingMetric EarlyStoppingMetric
            {
                get { return _earlyStoppingMetric; }

                set
                {
                    // Update the state of the user-facing stopping metric.
                    _earlyStoppingMetric = value;
                    // Mirror the public value into the integer-coded member of the base options.
                    EarlyStoppingMetrics = (int)_earlyStoppingMetric;
                }
            }

            /// <summary>
            /// Create a new <see cref="Options"/> object with default values.
            /// </summary>
            public Options()
            {
                // Use L1 by default. Assigning through the property also sets EarlyStoppingMetrics.
                EarlyStoppingMetric = EarlyStoppingMetric.L1Norm;
            }

            /// <summary>
            /// Creates a <see cref="FastTreeBinaryTrainer"/> configured with these options.
            /// </summary>
            ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeBinaryTrainer(env, this);
        }
    }
 
    // XML docs are provided in the other part of this partial class. No need to duplicate the content here.
    public sealed partial class FastTreeRegressionTrainer
    {
        /// <summary>
        /// Options for the <see cref="FastTreeRegressionTrainer"/> as used in
        /// [FastTree(Options)](xref:Microsoft.ML.TreeExtensions.FastTree(Microsoft.ML.RegressionCatalog.RegressionTrainers,Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.Options)).
        /// </summary>
        [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
        public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
        {
            /// <summary>
            /// Backing state of <see cref="EarlyStoppingMetric"/>; it must stay in sync with
            /// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
            /// </summary>
            private EarlyStoppingMetric _earlyStoppingMetric;

            /// <summary>
            /// The metric used for early stopping.
            /// </summary>
            public EarlyStoppingMetric EarlyStoppingMetric
            {
                get
                {
                    return _earlyStoppingMetric;
                }

                set
                {
                    // Remember the user-facing choice...
                    _earlyStoppingMetric = value;
                    // ...and mirror it, as its integer code, into the base options.
                    EarlyStoppingMetrics = (int)value;
                }
            }

            /// <summary>
            /// Initializes a new <see cref="Options"/> instance with default values.
            /// </summary>
            public Options()
            {
                // L1-norm is the default stopping metric.
                EarlyStoppingMetric = EarlyStoppingMetric.L1Norm;
            }

            /// <summary>
            /// Creates a <see cref="FastTreeRegressionTrainer"/> configured with these options.
            /// </summary>
            ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env)
            {
                return new FastTreeRegressionTrainer(env, this);
            }
        }
    }
 
    // XML docs are provided in the other part of this partial class. No need to duplicate the content here.
    public sealed partial class FastTreeTweedieTrainer
    {
        /// <summary>
        /// Options for the <see cref="FastTreeTweedieTrainer"/> as used in
        /// [FastTreeTweedie(Options)](xref:Microsoft.ML.TreeExtensions.FastTreeTweedie(Microsoft.ML.RegressionCatalog.RegressionTrainers,Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer.Options)).
        /// </summary>
        [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
        public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
        {
            // REVIEW: It is possible to estimate this index parameter from the distribution of data, using
            // a combination of univariate optimization and grid search, following section 4.2 of the paper. However
            // it is probably not worth doing unless and until explicitly asked for.
            /// <summary>
            /// The index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss,
            /// and intermediate values are compound Poisson loss.
            /// </summary>
            [Argument(ArgumentType.LastOccurrenceWins, HelpText =
                "Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, " +
                "and intermediate values are compound Poisson loss.")]
            public Double Index = 1.5;

            /// <summary>
            /// Internal state of <see cref="EarlyStoppingMetric"/>. It should always be synced with
            /// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
            /// </summary>
            // Only written through the EarlyStoppingMetric setter, which is what keeps the two values in sync.
            private EarlyStoppingMetric _earlyStoppingMetric;

            /// <summary>
            /// Early stopping metrics.
            /// </summary>
            public EarlyStoppingMetric EarlyStoppingMetric
            {
                get { return _earlyStoppingMetric; }

                set
                {
                    // Update the state of the user-facing stopping metric.
                    _earlyStoppingMetric = value;
                    // Mirror the public value into the integer-coded member of the base options.
                    EarlyStoppingMetrics = (int)_earlyStoppingMetric;
                }
            }

            /// <summary>
            /// Create a new <see cref="Options"/> object with default values.
            /// </summary>
            public Options()
            {
                // Use L1 by default. Assigning through the property also sets EarlyStoppingMetrics.
                EarlyStoppingMetric = EarlyStoppingMetric.L1Norm;
            }

            /// <summary>
            /// Creates a <see cref="FastTreeTweedieTrainer"/> configured with these options.
            /// </summary>
            ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeTweedieTrainer(env, this);
        }
    }
 
    // XML docs are provided in the other part of this partial class. No need to duplicate the content here.
    public sealed partial class FastTreeRankingTrainer
    {
        /// <summary>
        /// Options for the <see cref="FastTreeRankingTrainer"/> as used in
        /// [FastTree(Options)](xref:Microsoft.ML.TreeExtensions.FastTree(Microsoft.ML.RankingCatalog.RankingTrainers,Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer.Options)).
        /// </summary>
        [TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
        public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
        {
            /// <summary>
            /// Comma-separated list of gains associated with each relevance label.
            /// </summary>
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Comma-separated list of gains associated to each relevance label.", ShortName = "gains")]
            [TGUI(NoSweep = true)]
            public double[] CustomGains = new double[] { 0, 3, 7, 15, 31 };

            /// <summary>
            /// Whether to train using discounted cumulative gain (DCG) instead of normalized DCG (NDCG).
            /// </summary>
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Train DCG instead of NDCG", ShortName = "dcg")]
            public bool UseDcg;

            // REVIEW: Hiding sorting for now. Should be an enum or component factory.
            /// <summary>
            /// Name of the sorting algorithm to use for DCG and LambdaMart calculations.
            /// <see cref="Check"/> accepts only "DescendingStablePessimistic", "DescendingStable",
            /// "DescendingReverse" and "DescendingDotNet".
            /// </summary>
            [BestFriend]
            [Argument(ArgumentType.LastOccurrenceWins,
                HelpText = "The sorting algorithm to use for DCG and LambdaMart calculations [DescendingStablePessimistic/DescendingStable/DescendingReverse/DescendingDotNet]",
                ShortName = "sort",
                Hide = true)]
            [TGUI(NotGui = true)]
            internal string SortingAlgorithm = "DescendingStablePessimistic";

            /// <summary>
            /// The maximum NDCG truncation to use in the
            /// <a href="https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf">LambdaMART algorithm</a>.
            /// Must be positive.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "max-NDCG truncation to use in the LambdaMART algorithm", ShortName = "n", Hide = true)]
            [TGUI(NotGui = true)]
            public int NdcgTruncationLevel = 100;

            /// <summary>
            /// Whether to use shifted NDCG.
            /// </summary>
            [BestFriend]
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Use shifted NDCG", Hide = true)]
            [TGUI(NotGui = true)]
            internal bool ShiftedNdcg;

            /// <summary>
            /// Cost function parameter; the help text lists 'w' and 'c' as the values.
            /// </summary>
            [BestFriend]
            [Argument(ArgumentType.AtMostOnce, HelpText = "Cost function parameter (w/c)", ShortName = "cf", Hide = true)]
            [TGUI(NotGui = true)]
            internal char CostFunctionParam = 'w';

            /// <summary>
            /// Whether to apply the "distance weight 2" adjustment to the cost.
            /// </summary>
            [BestFriend]
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Distance weight 2 adjustment to cost", ShortName = "dw", Hide = true)]
            [TGUI(NotGui = true)]
            internal bool DistanceWeight2;

            /// <summary>
            /// Whether to normalize query lambdas.
            /// </summary>
            [BestFriend]
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Normalize query lambdas", ShortName = "nql", Hide = true)]
            [TGUI(NotGui = true)]
            internal bool NormalizeQueryLambdas;

            /// <summary>
            /// Internal state of <see cref="EarlyStoppingMetric"/>. It should always be synced with
            /// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
            /// </summary>
            // Only written through the EarlyStoppingMetric setter, which is what keeps the two values in sync.
            private EarlyStoppingRankingMetric _earlyStoppingMetric;

            /// <summary>
            /// Early stopping metrics.
            /// </summary>
            public EarlyStoppingRankingMetric EarlyStoppingMetric
            {
                get { return _earlyStoppingMetric; }

                set
                {
                    // Update the state of the user-facing stopping metric.
                    _earlyStoppingMetric = value;
                    // Mirror the public value into the integer-coded member of the base options.
                    EarlyStoppingMetrics = (int)_earlyStoppingMetric;
                }
            }

            /// <summary>
            /// Create a new <see cref="Options"/> object with default values.
            /// </summary>
            public Options()
            {
                EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt1; // Use NDCG@1 by default.
                RowGroupColumnName = DefaultColumnNames.GroupId; // Use GroupId as default for ranking options.
            }

            /// <summary>
            /// Creates a <see cref="FastTreeRankingTrainer"/> configured with these options.
            /// </summary>
            ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);

            /// <summary>
            /// Runs the base-class checks, then validates the ranking-specific options
            /// (<see cref="SortingAlgorithm"/> and <see cref="NdcgTruncationLevel"/>) through
            /// <paramref name="ectx"/>.
            /// </summary>
            /// <param name="ectx">Exception context used to report invalid user arguments.</param>
            internal override void Check(IExceptionContext ectx)
            {
                base.Check(ectx);

                ectx.CheckUserArg(SortingAlgorithm == "DescendingStable"
                    || SortingAlgorithm == "DescendingReverse"
                    || SortingAlgorithm == "DescendingDotNet"
                    || SortingAlgorithm == "DescendingStablePessimistic",
                    nameof(SortingAlgorithm),
                    "The specified sorting algorithm is invalid. Only 'DescendingStable', 'DescendingReverse', " +
                    "'DescendingDotNet', and 'DescendingStablePessimistic' are supported.");
#if OLD_DATALOAD
                ectx.CheckUserArg(0 <= secondaryMetricShare && secondaryMetricShare <= 1, "secondaryMetricShare", "secondaryMetricShare must be between 0 and 1.");
#endif
                ectx.CheckUserArg(0 < NdcgTruncationLevel, nameof(NdcgTruncationLevel), "must be positive.");
            }
        }
    }
 
    /// <summary>
    /// Strategies for bundling low population bins.
    /// </summary>
    public enum Bundle : byte
    {
        /// <summary>
        /// No bundling.
        /// </summary>
        None = 0,

        /// <summary>
        /// Bundle low population bins.
        /// </summary>
        AggregateLowPopulation = 1,

        /// <summary>
        /// Neighbor low population bundle.
        /// </summary>
        Adjacent = 2,
    }
 
    /// <summary>
    /// Shared default values for the tree trainer options.
    /// </summary>
    [BestFriend]
    internal static class Defaults
    {
        /// <summary>Default number of trees in the ensemble.</summary>
        public const int NumberOfTrees = 100;

        /// <summary>Default maximum number of leaves per tree.</summary>
        public const int NumberOfLeaves = 20;

        /// <summary>Default minimal number of examples required to form a leaf.</summary>
        public const int MinimumExampleCountPerLeaf = 10;

        // NOTE(review): not referenced in the visible part of this file; presumably the default
        // learning rate of the boosted-tree options. Confirm against its consumers.
        /// <summary>Default learning rate.</summary>
        public const double LearningRate = 0.2;
    }
 
    /// <summary>
    /// Options for tree trainers.
    /// </summary>
    public abstract class TreeOptions : TrainerInputBaseWithGroupId
    {
        /// <summary>
        /// Selects the parallel FastTree learning algorithm.
        /// </summary>
        [Argument(ArgumentType.Multiple, HelpText = "Allows to choose Parallel FastTree Learning Algorithm", ShortName = "parag")]
        internal ISupportParallelTraining ParallelTrainer = new SingleTrainerFactory();

        /// <summary>
        /// The number of threads to use. Must be positive when specified.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The number of threads to use", ShortName = "t", NullName = "<Auto>")]
        public int? NumberOfThreads = null;

        // this random seed is used for:
        // 1. example sampling for feature binning
        // 2. init Randomize Score
        // 3. grad Sampling Rate in Objective Function
        // 4. tree learner
        // 5. bagging provider
        // 6. ensemble compressor
        /// <summary>
        /// The seed of the random number generator.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The seed of the random number generator", ShortName = "r1")]
        public int Seed = 123;

        // this random seed is only for active feature selection
        /// <summary>
        /// The seed of the active feature selection.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The seed of the active feature selection", ShortName = "r3", Hide = true)]
        [TGUI(NotGui = true)]
        public int FeatureSelectionSeed = 123;

        /// <summary>
        /// The entropy (regularization) coefficient between 0 and 1.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The entropy (regularization) coefficient between 0 and 1", ShortName = "e")]
        public Double EntropyCoefficient;

        // REVIEW: Different short name from TLC FR arguments.
        /// <summary>
        /// The number of histograms in the pool (between 2 and numLeaves).
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The number of histograms in the pool (between 2 and numLeaves)", ShortName = "ps")]
        public int HistogramPoolSize = -1;

        /// <summary>
        /// Whether to utilize the disk or the data's native transposition facilities (where applicable) when performing the transpose.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether to utilize the disk or the data's native transposition facilities (where applicable) when performing the transpose", ShortName = "dt")]
        public bool? DiskTranspose;

        /// <summary>
        /// Whether to collectivize features during dataset preparation to speed up training.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether to collectivize features during dataset preparation to speed up training", ShortName = "flocks", Hide = true)]
        public bool FeatureFlocks = true;

        /// <summary>
        /// Whether to do split based on multiple categorical feature values.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether to do split based on multiple categorical feature values.", ShortName = "cat")]
        public bool CategoricalSplit = false;

        /// <summary>
        /// Maximum categorical split groups to consider when splitting on a categorical feature. Split groups are a collection of split points.
        /// This is used to reduce overfitting when there are many categorical features.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Maximum categorical split groups to consider when splitting on a categorical feature. " +
                                                             "Split groups are a collection of split points. This is used to reduce overfitting when " +
                                                             "there many categorical features.", ShortName = "mcg")]
        public int MaximumCategoricalGroupCountPerNode = 64;

        /// <summary>
        /// Maximum categorical split points to consider when splitting on a categorical feature.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Maximum categorical split points to consider when splitting on a categorical feature.", ShortName = "maxcat")]
        public int MaximumCategoricalSplitPointCount = 64;

        /// <summary>
        /// Minimum categorical example percentage in a bin to consider for a split. Default is 0.1% of all training examples.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Minimum categorical example percentage in a bin to consider for a split.", ShortName = "mdop")]
        public double MinimumExampleFractionForCategoricalSplit = 0.001;

        /// <summary>
        /// Minimum categorical example count in a bin to consider for a split.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Minimum categorical example count in a bin to consider for a split.", ShortName = "mdo")]
        public int MinimumExamplesForCategoricalSplit = 100;

        /// <summary>
        /// Bias for calculating gradient for each feature bin for a categorical feature. Must be non-negative.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Bias for calculating gradient for each feature bin for a categorical feature.", ShortName = "bias")]
        public double Bias = 0;

        /// <summary>
        /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Bundle low population bins. " +
                                                             "Bundle.None(0): no bundling, " +
                                                             "Bundle.AggregateLowPopulation(1): Bundle low population, " +
                                                             "Bundle.Adjacent(2): Neighbor low population bundle.", ShortName = "bundle")]
        public Bundle Bundling = Bundle.None;

        // REVIEW: Different default from TLC FR. I prefer the TLC FR default of 255.
        // REVIEW: Reverting back to 255 to make the same defaults of FR.
        /// <summary>
        /// Maximum number of distinct values (bins) per feature. Must be greater than 0.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Maximum number of distinct values (bins) per feature", ShortName = "mb")]
        public int MaximumBinCountPerFeature = 255;  // save one for undefs

        /// <summary>
        /// Sparsity level needed to use sparse feature representation.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Sparsity level needed to use sparse feature representation", ShortName = "sp")]
        public Double SparsifyThreshold = 0.7;

        /// <summary>
        /// The feature first use penalty coefficient.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The feature first use penalty coefficient", ShortName = "ffup")]
        public Double FeatureFirstUsePenalty;

        /// <summary>
        /// The feature re-use penalty (regularization) coefficient.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The feature re-use penalty (regularization) coefficient", ShortName = "frup")]
        public Double FeatureReusePenalty;

        /// <summary>
        /// Tree fitting gain confidence requirement. Only consider a gain if its likelihood versus a random choice gain is above this value.
        /// </summary>
        /// <value>
        /// Value of 0.95 would mean restricting to gains that have less than a 0.05 chance of being generated randomly through choice of a random split.
        /// Valid range is [0,1).
        /// </value>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Tree fitting gain confidence requirement (should be in the range [0,1) ).", ShortName = "gainconf")]
        public Double GainConfidenceLevel;

        /// <summary>
        /// The temperature of the randomized softmax distribution for choosing the feature.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The temperature of the randomized softmax distribution for choosing the feature", ShortName = "smtemp")]
        public Double SoftmaxTemperature;

        /// <summary>
        /// Print execution time breakdown to ML.NET channel.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Print execution time breakdown to stdout", ShortName = "et")]
        public bool ExecutionTime;

        /// <summary>
        /// Print memory statistics to ML.NET channel.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Print memory statistics to stdout", ShortName = "memstats")]
        public bool MemoryStatistics = true;

        // REVIEW: Different from original FastRank arguments (shortname l vs. nl). Different default from TLC FR Wrapper (20 vs. 20).
        /// <summary>
        /// The max number of leaves in each regression tree. Must be at least 2.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The max number of leaves in each regression tree", ShortName = "nl", SortOrder = 2)]
        [TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")]
        [TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)]
        public int NumberOfLeaves = Defaults.NumberOfLeaves;

        /// <summary>
        /// The minimal number of data points required to form a new tree leaf.
        /// </summary>
        // REVIEW: Arrays not supported in GUI
        // REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The minimal number of examples allowed in a leaf of a regression tree, out of the subsampled data", ShortName = "mil", SortOrder = 3)]
        [TGUI(Description = "Minimum number of training instances required to form a leaf", SuggestedSweeps = "1,10,50")]
        [TlcModule.SweepableDiscreteParamAttribute("MinDocumentsInLeafs", new object[] { 1, 10, 50 })]
        public int MinimumExampleCountPerLeaf = Defaults.MinimumExampleCountPerLeaf;

        /// <summary>
        /// Total number of decision trees to create in the ensemble. Must be positive.
        /// </summary>
        // REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Total number of decision trees to create in the ensemble", ShortName = "iter", SortOrder = 1)]
        [TGUI(Description = "Total number of trees constructed", SuggestedSweeps = "20,100,500")]
        [TlcModule.SweepableDiscreteParamAttribute("NumTrees", new object[] { 20, 100, 500 })]
        public int NumberOfTrees = Defaults.NumberOfTrees;

        /// <summary>
        /// The fraction of features (chosen randomly) to use on each iteration. Use 0.9 if only 90% of the features are needed.
        /// Lower numbers help reduce over-fitting.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The fraction of features (chosen randomly) to use on each iteration", ShortName = "ff")]
        public Double FeatureFraction = 1;

        /// <summary>
        /// Number of trees in each bag (0 for disabling bagging).
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Number of trees in each bag (0 for disabling bagging)", ShortName = "bag")]
        public int BaggingSize;

        /// <summary>
        /// Percentage of training examples used in each bag. Default is 0.7 (70%).
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Percentage of training examples used in each bag", ShortName = "bagfrac")]
        // REVIEW: sweeping bagfrac doesn't make sense unless 'baggingSize' is non-zero. The 'SuggestedSweeps' here
        // are used to denote 'sensible range', but the GUI will interpret this as 'you must sweep these values'. So, I'm keeping
        // the values there for the future, when we have an appropriate way to encode this information.
        // [TGUI(SuggestedSweeps = "0.5,0.7,0.9")]
        public Double BaggingExampleFraction = 0.7;

        /// <summary>
        /// The fraction of features (chosen randomly) to use on each split. With a value of 0.9, about 90% of the features
        /// are used at each split and the remaining 10% are skipped.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The fraction of features (chosen randomly) to use on each split", ShortName = "sf")]
        public Double FeatureFractionPerSplit = 1;

        /// <summary>
        /// Smoothing parameter for tree regularization. Must be between 0 and 1.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing paramter for tree regularization", ShortName = "s")]
        public Double Smoothing;

        /// <summary>
        /// When a root split is impossible, allow training to proceed.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "When a root split is impossible, allow training to proceed", ShortName = "allowempty,dummies", Hide = true)]
        [TGUI(NotGui = true)]
        public bool AllowEmptyTrees = true;

        /// <summary>
        /// The level of feature compression to use.
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The level of feature compression to use", ShortName = "fcomp", Hide = true)]
        [TGUI(NotGui = true)]
        internal int FeatureCompressionLevel = 1;

        /// <summary>
        /// Compress the tree Ensemble.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Compress the tree Ensemble", ShortName = "cmp", Hide = true)]
        [TGUI(NotGui = true)]
        public bool CompressEnsemble;

        /// <summary>
        /// Print metrics graph for the first test set.
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Print metrics graph for the first test set", ShortName = "graph", Hide = true)]
        [TGUI(NotGui = true)]
        internal bool PrintTestGraph;

        /// <summary>
        /// Print Train and Validation metrics in graph.
        /// </summary>
        // It is only enabled if PrintTestGraph is also set.
        [BestFriend]
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Print Train and Validation metrics in graph", ShortName = "graphtv", Hide = true)]
        [TGUI(NotGui = true)]
        internal bool PrintTrainValidGraph;

        /// <summary>
        /// Calculate metric values for train/valid/test every k rounds.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Calculate metric values for train/valid/test every k rounds", ShortName = "tf")]
        public int TestFrequency = int.MaxValue;

        /// <summary>
        /// Validates the ranges of the option values, reporting violations as user-argument
        /// errors through <paramref name="ectx"/>. Derived options override this to add their own checks.
        /// </summary>
        /// <param name="ectx">Exception context used to report invalid user arguments.</param>
        internal virtual void Check(IExceptionContext ectx)
        {
            Contracts.AssertValue(ectx);
            ectx.CheckUserArg(NumberOfThreads == null || NumberOfThreads > 0, nameof(NumberOfThreads), "Must be positive.");
            ectx.CheckUserArg(NumberOfLeaves >= 2, nameof(NumberOfLeaves), "Must be at least 2.");
            ectx.CheckUserArg(0 <= EntropyCoefficient && EntropyCoefficient <= 1, nameof(EntropyCoefficient), "Must be between 0 and 1.");
            ectx.CheckUserArg(0 <= GainConfidenceLevel && GainConfidenceLevel < 1, nameof(GainConfidenceLevel), "Must be in [0, 1).");
            ectx.CheckUserArg(0 <= FeatureFraction && FeatureFraction <= 1, nameof(FeatureFraction), "Must be between 0 and 1.");
            ectx.CheckUserArg(0 <= FeatureFractionPerSplit && FeatureFractionPerSplit <= 1, nameof(FeatureFractionPerSplit), "Must be between 0 and 1.");
            ectx.CheckUserArg(0 <= SoftmaxTemperature, nameof(SoftmaxTemperature), "Must be non-negative.");
            ectx.CheckUserArg(0 < MaximumBinCountPerFeature, nameof(MaximumBinCountPerFeature), "Must be greater than 0.");
            ectx.CheckUserArg(0 <= SparsifyThreshold && SparsifyThreshold <= 1, nameof(SparsifyThreshold), "Must be between 0 and 1.");
            ectx.CheckUserArg(0 < NumberOfTrees, nameof(NumberOfTrees), "Must be positive.");
            ectx.CheckUserArg(0 <= Smoothing && Smoothing <= 1, nameof(Smoothing), "Must be between 0 and 1.");
            ectx.CheckUserArg(0 <= BaggingSize, nameof(BaggingSize), "Must be non-negative.");
            ectx.CheckUserArg(0 <= BaggingExampleFraction && BaggingExampleFraction <= 1, nameof(BaggingExampleFraction), "Must be between 0 and 1.");
            ectx.CheckUserArg(0 <= FeatureFirstUsePenalty, nameof(FeatureFirstUsePenalty), "Must be non-negative.");
            ectx.CheckUserArg(0 <= FeatureReusePenalty, nameof(FeatureReusePenalty), "Must be non-negative.");
            ectx.CheckUserArg(0 <= MaximumCategoricalGroupCountPerNode, nameof(MaximumCategoricalGroupCountPerNode), "Must be non-negative.");
            ectx.CheckUserArg(0 <= MaximumCategoricalSplitPointCount, nameof(MaximumCategoricalSplitPointCount), "Must be non-negative.");
            ectx.CheckUserArg(0 <= MinimumExampleFractionForCategoricalSplit, nameof(MinimumExampleFractionForCategoricalSplit), "Must be non-negative.");
            ectx.CheckUserArg(0 <= MinimumExamplesForCategoricalSplit, nameof(MinimumExamplesForCategoricalSplit), "Must be non-negative.");
            ectx.CheckUserArg(Bundle.None <= Bundling && Bundling <= Bundle.Adjacent, nameof(Bundling), "Must be between 0 and 2.");
            ectx.CheckUserArg(Bias >= 0, nameof(Bias), "Must be greater than or equal to zero.");
        }
    }
 
    /// <summary>
    /// Options for boosting tree trainers.
    /// </summary>
    public abstract class BoostedTreeOptions : TreeOptions
    {
        // REVIEW: TLC FR likes to call it bestStepRegressionTrees which might be more appropriate.
        //Use the second derivative for split gains (not just outputs). Use MaxTreeOutput to "clip" cases where the second derivative is too close to zero.
        //Turning BSR on makes larger steps in initial stages and converges to better results with fewer trees (though in the end, it asymptotes to the same results).
        /// <summary>
        /// Option for using best regression step trees.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Option for using best regression step trees", ShortName = "bsr")]
        public bool BestStepRankingRegressionTrees = false;

        /// <summary>
        /// Determines whether to use line search for a step size.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Should we use line search for a step size", ShortName = "ls")]
        public bool UseLineSearch;

        /// <summary>
        /// Number of post-bracket line search steps.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Number of post-bracket line search steps", ShortName = "lssteps")]
        public int MaximumNumberOfLineSearchSteps;

        /// <summary>
        /// Minimum line search step size.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Minimum line search step size", ShortName = "minstep")]
        public double MinimumStepSize;

        /// <summary>
        /// Types of optimization algorithms.
        /// </summary>
        public enum OptimizationAlgorithmType
        {
            GradientDescent,
            AcceleratedGradientDescent,
            ConjugateGradientDescent
        }

        /// <summary>
        /// Optimization algorithm to be used.
        /// </summary>
        /// <value>
        /// See <see cref="OptimizationAlgorithmType"/> for available optimizers.
        /// </value>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent)", ShortName = "oa")]
        public OptimizationAlgorithmType OptimizationAlgorithm = OptimizationAlgorithmType.GradientDescent;

        /// <summary>
        /// Early stopping rule. (Validation set (/valid) is required).
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", Name = "EarlyStoppingRule", ShortName = "esr", NullName = "<Disable>")]
        [TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")]
        internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory;

        /// <summary>
        /// The underlying state of <see cref="EarlyStoppingRuleFactory"/> and <see cref="EarlyStoppingRule"/>.
        /// </summary>
        private EarlyStoppingRuleBase _earlyStoppingRuleBase;

        /// <summary>
        /// Early stopping rule used to terminate training process once meeting a specified criterion. Possible choices are
        /// <see cref="EarlyStoppingRuleBase"/>'s implementations such as <see cref="TolerantEarlyStoppingRule"/> and <see cref="GeneralityLossRule"/>.
        /// </summary>
        /// <remarks>
        /// Assigning a rule also rebuilds <see cref="EarlyStoppingRuleFactory"/> from it.
        /// Assigning <see langword="null"/> clears the factory instead of throwing.
        /// </remarks>
        public EarlyStoppingRuleBase EarlyStoppingRule
        {
            get { return _earlyStoppingRuleBase; }
            set
            {
                _earlyStoppingRuleBase = value;
                // The factory argument advertises a null value as "<Disable>", so a null rule
                // must clear the factory rather than dereference null and throw.
                EarlyStoppingRuleFactory = value?.BuildFactory();
            }
        }

        /// <summary>
        /// Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3).
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3)", ShortName = "esmt")]
        [TGUI(Description = "Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3)")]
        internal int EarlyStoppingMetrics;

        /// <summary>
        /// Enable post-training tree pruning to avoid overfitting. It requires a validation set.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Enable post-training pruning to avoid overfitting. (a validation set is required)", ShortName = "pruning")]
        public bool EnablePruning;

        /// <summary>
        /// Use window and tolerance for pruning.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Use window and tolerance for pruning", ShortName = "prtol")]
        public bool UseTolerantPruning;

        /// <summary>
        /// The tolerance threshold for pruning.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The tolerance threshold for pruning", ShortName = "prth")]
        [TGUI(Description = "Pruning threshold")]
        public double PruningThreshold = 0.004;

        /// <summary>
        /// The moving window size for pruning.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The moving window size for pruning", ShortName = "prws")]
        [TGUI(Description = "Pruning window size")]
        public int PruningWindowSize = 5;

        /// <summary>
        /// The learning rate.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The learning rate", ShortName = "lr", SortOrder = 4)]
        [TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")]
        [TlcModule.SweepableFloatParamAttribute("LearningRates", 0.025f, 0.4f, isLogScale: true)]
        public double LearningRate = Defaults.LearningRate;

        /// <summary>
        /// Shrinkage.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Shrinkage", ShortName = "shrk")]
        [TGUI(Label = "Shrinkage", SuggestedSweeps = "0.25-4;log")]
        [TlcModule.SweepableFloatParamAttribute("Shrinkage", 0.025f, 4f, isLogScale: true)]
        public double Shrinkage = 1;

        /// <summary>
        /// Dropout rate for tree regularization.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Dropout rate for tree regularization", ShortName = "tdrop")]
        [TGUI(SuggestedSweeps = "0,0.000000001,0.05,0.1,0.2")]
        [TlcModule.SweepableDiscreteParamAttribute("DropoutRate", new object[] { 0.0f, 1E-9f, 0.05f, 0.1f, 0.2f })]
        public double DropoutRate = 0;

        /// <summary>
        /// Sample each query 1 in k times in the GetDerivatives function.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Sample each query 1 in k times in the GetDerivatives function", ShortName = "sr")]
        public int GetDerivativesSampleRate = 1;

        /// <summary>
        /// Write the last ensemble instead of the one determined by early stopping.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Write the last ensemble instead of the one determined by early stopping", ShortName = "hl")]
        public bool WriteLastEnsemble;

        /// <summary>
        /// Upper bound on absolute value of single tree output.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Upper bound on absolute value of single tree output", ShortName = "mo")]
        public double MaximumTreeOutput = 100;

        /// <summary>
        /// Training starts from random ordering (determined by /r1).
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Training starts from random ordering (determined by /r1)", ShortName = "rs", Hide = true)]
        [TGUI(NotGui = true)]
        public bool RandomStart;

        /// <summary>
        /// Filter zero lambdas during training.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Filter zero lambdas during training", ShortName = "fzl", Hide = true)]
        [TGUI(NotGui = true)]
        public bool FilterZeroLambdas;

#if OLD_DATALOAD
        // These options are compiled only when OLD_DATALOAD is defined.

        /// <summary>
        /// The proportion of the lambdas that should be secondary metrics.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The proportion of the lambdas that should be secondary metrics", ShortName = "sfrac", Hide = true)]
        [TGUI(NotGui = true)]
        public double secondaryMetricShare;

        /// <summary>
        /// Calculate secondary lambdas only for pairs with identical labels instead of for all pairs.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Secondary lambdas by default are calculated for all pairs; this makes them calculated only for those pairs with identical labels",
            ShortName = "secondexclusive", Hide = true)]
        [TGUI(NotGui = true)]
        public bool secondaryIsolabelExclusive;

        /// <summary>
        /// Force garbage collection during feature extraction each time this many features are read.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Force garbage collection during feature extraction each time this many features are read", ShortName = "gcfe", Hide = true)]
        [TGUI(NotGui = true)]
        public int forceGCFeatureExtraction = 100;
#endif

        /// <summary>
        /// Freeform defining the scores that should be used as the baseline ranker.
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Freeform defining the scores that should be used as the baseline ranker", ShortName = "basescores", Hide = true)]
        [TGUI(NotGui = true)]
        internal string BaselineScoresFormula;

        /// <summary>
        /// Baseline alpha for tradeoffs of risk (0 is normal training).
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Baseline alpha for tradeoffs of risk (0 is normal training)", ShortName = "basealpha", Hide = true)]
        [TGUI(NotGui = true)]
        internal string BaselineAlphaRisk;

        /// <summary>
        /// The discount freeform which specifies the per position discounts of examples in a query (uses a single variable P for position where P=0 is first position).
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "The discount freeform which specifies the per position discounts of examples in a query (uses a single variable P for position where P=0 is first position)",
            ShortName = "pdff", Hide = true)]
        [TGUI(NotGui = true)]
        internal string PositionDiscountFreeform;

#if !NO_STORE
        // These options are compiled only when NO_STORE is not defined.

        /// <summary>
        /// Offload feature bins to a file store.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Offload feature bins to a file store", ShortName = "fbsopt", Hide = true)]
        [TGUI(NotGui = true)]
        public bool offloadBinsToFileStore;

        /// <summary>
        /// Directory used to offload feature bins.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Directory used to offload feature bins", ShortName = "fbsoptdir", Hide = true)]
        [TGUI(NotGui = true)]
        public string offloadBinsDirectory = string.Empty;

        /// <summary>
        /// Preloads feature bins needed for the next iteration when bins file store is used.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Preloads feature bins needed for the next iteration when bins file store is used", ShortName = "fbsoptpreload", Hide = true)]
        [TGUI(NotGui = true)]
        public bool preloadFeatureBinsBeforeTraining;
#endif

        /// <summary>
        /// Runs the base class validation first and then validates the boosting-specific options
        /// declared on this class. Each violated constraint is reported through
        /// <c>ectx.CheckUserArg</c> together with the name of the offending option.
        /// </summary>
        /// <param name="ectx">The exception context used to report invalid user arguments.</param>
        internal override void Check(IExceptionContext ectx)
        {
            base.Check(ectx);

            ectx.CheckUserArg(0 <= MaximumTreeOutput, nameof(MaximumTreeOutput), "Must be non-negative.");
            ectx.CheckUserArg(0 <= PruningThreshold, nameof(PruningThreshold), "Must be non-negative.");
            ectx.CheckUserArg(0 < PruningWindowSize, nameof(PruningWindowSize), "Must be positive.");
            ectx.CheckUserArg(0 < Shrinkage, nameof(Shrinkage), "Must be positive.");
            ectx.CheckUserArg(0 <= DropoutRate && DropoutRate <= 1, nameof(DropoutRate), "Must be between 0 and 1.");
            ectx.CheckUserArg(0 < GetDerivativesSampleRate, nameof(GetDerivativesSampleRate), "Must be positive.");
            ectx.CheckUserArg(0 <= MaximumNumberOfLineSearchSteps, nameof(MaximumNumberOfLineSearchSteps), "Must be non-negative.");
            ectx.CheckUserArg(0 <= MinimumStepSize, nameof(MinimumStepSize), "Must be non-negative.");
        }
    }
}