File: LightGbmArguments.cs
Web Access
Project: src\src\Microsoft.ML.LightGbm\Microsoft.ML.LightGbm.csproj (Microsoft.ML.LightGbm)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Reflection;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.LightGbm;
 
// Register each booster as a loadable class: booster type, its options type, the
// SignatureLightGBMBooster signature delegate, then the friendly name and the short name.
[assembly: LoadableClass(typeof(GradientBooster), typeof(GradientBooster.Options),
    typeof(SignatureLightGBMBooster), GradientBooster.FriendlyName, GradientBooster.Name)]
[assembly: LoadableClass(typeof(DartBooster), typeof(DartBooster.Options),
    typeof(SignatureLightGBMBooster), DartBooster.FriendlyName, DartBooster.Name)]
[assembly: LoadableClass(typeof(GossBooster), typeof(GossBooster.Options),
    typeof(SignatureLightGBMBooster), GossBooster.FriendlyName, GossBooster.Name)]
 
// Mark each booster's options type as an entry-point module.
[assembly: EntryPointModule(typeof(GradientBooster.Options))]
[assembly: EntryPointModule(typeof(DartBooster.Options))]
[assembly: EntryPointModule(typeof(GossBooster.Options))]
 
namespace Microsoft.ML.Trainers.LightGbm
{
    /// <summary>
    /// Signature delegate referenced by the <c>LoadableClass</c> registrations of the boosters declared in this file.
    /// </summary>
    internal delegate void SignatureLightGBMBooster();
 
    /// <summary>
    /// Component factory that produces a <see cref="BoosterParameterBase"/>.
    /// In this file it is implemented by <see cref="BoosterParameterBase.OptionsBase"/>.
    /// </summary>
    [TlcModule.ComponentKind("BoosterParameterFunction")]
    internal interface IBoosterParameterFactory : IComponentFactory<BoosterParameterBase>
    {
        /// <summary>
        /// Creates the booster parameter object. Declared with <c>new</c>, so it hides the
        /// member of the same signature inherited from the generic component factory interface.
        /// </summary>
        /// <param name="env">The host environment handed to the factory.</param>
        /// <returns>The created <see cref="BoosterParameterBase"/>.</returns>
        new BoosterParameterBase CreateComponent(IHostEnvironment env);
    }
 
    public abstract class BoosterParameterBase
    {
        /// <summary>
        /// Lookup table from an option field name to the parameter name written by <c>UpdateParameters</c>.
        /// Field names that have no entry here are converted with <c>LightGbmInterfaceUtils.GetOptionName</c> instead.
        /// <see cref="DartBooster.Options"/> appends its own entries from its static constructor.
        /// </summary>
        private protected static Dictionary<string, string> NameMapping = new Dictionary<string, string>()
        {
            [nameof(OptionsBase.MinimumSplitGain)] = "min_split_gain",
            [nameof(OptionsBase.MaximumTreeDepth)] = "max_depth",
            [nameof(OptionsBase.MinimumChildWeight)] = "min_child_weight",
            [nameof(OptionsBase.SubsampleFraction)] = "subsample",
            [nameof(OptionsBase.SubsampleFrequency)] = "subsample_freq",
            [nameof(OptionsBase.L1Regularization)] = "reg_alpha",
            [nameof(OptionsBase.L2Regularization)] = "reg_lambda",
        };
        /// <summary>
        /// Validates the options shared by all boosters and keeps them in <see cref="BoosterOptions"/>.
        /// </summary>
        /// <param name="options">
        /// The options to validate and store. The checks run in the order written below and each
        /// one is reported through <c>Contracts.CheckUserArg</c> with the name of the offending field.
        /// </param>
        public BoosterParameterBase(OptionsBase options)
        {
            const string nonNegative = "must be >= 0.";
            const string halfOpenUnit = "must be in (0,1].";

            double splitGain = options.MinimumSplitGain;
            Contracts.CheckUserArg(splitGain >= 0, nameof(OptionsBase.MinimumSplitGain), nonNegative);

            double childWeight = options.MinimumChildWeight;
            Contracts.CheckUserArg(childWeight >= 0, nameof(OptionsBase.MinimumChildWeight), nonNegative);

            // Both fractions must lie in (0,1].
            double rowFraction = options.SubsampleFraction;
            Contracts.CheckUserArg(rowFraction > 0 && rowFraction <= 1, nameof(OptionsBase.SubsampleFraction), halfOpenUnit);

            double columnFraction = options.FeatureFraction;
            Contracts.CheckUserArg(columnFraction > 0 && columnFraction <= 1, nameof(OptionsBase.FeatureFraction), halfOpenUnit);

            // L2 is checked before L1, as in the original ordering.
            double l2 = options.L2Regularization;
            Contracts.CheckUserArg(l2 >= 0, nameof(OptionsBase.L2Regularization), nonNegative);

            double l1 = options.L1Regularization;
            Contracts.CheckUserArg(l1 >= 0, nameof(OptionsBase.L1Regularization), nonNegative);

            BoosterOptions = options;
        }
 
        /// <summary>
        /// Options shared by every booster declared in this file. The type also implements
        /// <see cref="IBoosterParameterFactory"/>: both <c>CreateComponent</c> implementations forward to
        /// <see cref="BuildOptions"/>.
        /// </summary>
        /// <remarks>
        /// The <see cref="BoosterParameterBase"/> constructor checks <see cref="MinimumSplitGain"/>,
        /// <see cref="MinimumChildWeight"/>, <see cref="SubsampleFraction"/>, <see cref="FeatureFraction"/>,
        /// <see cref="L2Regularization"/> and <see cref="L1Regularization"/>. It performs no check on
        /// <see cref="MaximumTreeDepth"/> or <see cref="SubsampleFrequency"/>.
        /// </remarks>
        public abstract class OptionsBase : IBoosterParameterFactory
        {
            /// <summary>
            /// Always returns <see langword="null"/>.
            /// </summary>
            internal BoosterParameterBase GetBooster() { return null; }
            /// <summary>
            /// The minimum loss reduction required to make a further partition on a leaf node of the tree.
            /// </summary>
            /// <value>
            /// Larger values make the algorithm more conservative. Must be &gt;= 0.
            /// </value>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Minimum loss reduction required to make a further partition on a leaf node of the tree. the larger, " +
                    "the more conservative the algorithm will be.")]
            [TlcModule.Range(Min = 0.0)]
            public double MinimumSplitGain = 0;
 
            /// <summary>
            /// The maximum depth of a tree.
            /// </summary>
            /// <value>
            /// 0 means no limit.
            /// </value>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Maximum depth of a tree. 0 means no limit. However, tree still grows by best-first.")]
            [TlcModule.Range(Min = 0, Max = int.MaxValue)]
            public int MaximumTreeDepth = 0;
 
            /// <summary>
            /// The minimum sum of instance weight needed to form a new node.
            /// </summary>
            /// <value>
            /// If the tree partition step results in a leaf node with the sum of instance weight less than <see cref="MinimumChildWeight"/>,
            /// the building process will give up further partitioning. In linear regression mode, this simply corresponds to minimum number
            /// of instances needed to be in each node. The larger, the more conservative the algorithm will be. Must be &gt;= 0.
            /// </value>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Minimum sum of instance weight(hessian) needed in a child. If the tree partition step results in a leaf " +
                    "node with the sum of instance weight less than min_child_weight, then the building process will give up further partitioning. In linear regression mode, " +
                    "this simply corresponds to minimum number of instances needed to be in each node. The larger, the more conservative the algorithm will be.")]
            [TlcModule.Range(Min = 0.0)]
            public double MinimumChildWeight = 0.1;
 
            /// <summary>
            /// The frequency of performing subsampling (bagging).
            /// </summary>
            /// <value>
            /// 0 means disable bagging. N means perform bagging at every N iterations.
            /// To enable bagging, <see cref="SubsampleFraction"/> should also be set to a value less than 1.0.
            /// </value>
            // NOTE(review): the concatenated HelpText below has no space between "iterations." and "This";
            // the string is left as-is because it is emitted at runtime - confirm before changing it.
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Subsample frequency for bagging. 0 means no subsample. "
                + "Specifies the frequency at which the bagging occurs, where if this is set to N, the subsampling will happen at every N iterations." +
                "This must be set with Subsample as this specifies the amount to subsample.")]
            [TlcModule.Range(Min = 0, Max = int.MaxValue)]
            public int SubsampleFrequency = 0;
 
            /// <summary>
            /// The fraction of training data used for creating trees.
            /// </summary>
            /// <value>
            /// Setting it to 0.5 means that LightGBM randomly picks half of the data points to grow trees.
            /// This can be used to speed up training and to reduce over-fitting. Valid range is (0,1].
            /// </value>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Subsample ratio of the training instance. Setting it to 0.5 means that LightGBM randomly collected " +
                    "half of the data instances to grow trees and this will prevent overfitting. Range: (0,1].")]
            [TlcModule.Range(Inf = 0.0, Max = 1.0)]
            public double SubsampleFraction = 1;
 
            /// <summary>
            /// The fraction of features used when creating trees.
            /// </summary>
            /// <value>
            /// If <see cref="FeatureFraction"/> is smaller than 1.0, LightGBM will randomly select fraction of features to train each tree.
            /// For example, if you set it to 0.8, LightGBM will select 80% of features before training each tree.
            /// This can be used to speed up training and to reduce over-fitting. Valid range is (0,1].
            /// </value>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Subsample ratio of columns when constructing each tree. Range: (0,1].",
                ShortName = "ff")]
            [TlcModule.Range(Inf = 0.0, Max = 1.0)]
            public double FeatureFraction = 1;
 
            /// <summary>
            /// The L2 regularization term on weights.
            /// </summary>
            /// <value>
            /// Increasing this value could help reduce over-fitting. Must be &gt;= 0.
            /// </value>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "L2 regularization term on weights, increasing this value will make model more conservative.",
                ShortName = "l2")]
            [TlcModule.Range(Min = 0.0)]
            [TGUI(Label = "Lambda(L2)", SuggestedSweeps = "0,0.5,1")]
            [TlcModule.SweepableDiscreteParam("RegLambda", new object[] { 0f, 0.5f, 1f })]
            public double L2Regularization = 0.01;
 
            /// <summary>
            /// The L1 regularization term on weights.
            /// </summary>
            /// <value>
            /// Increasing this value could help reduce over-fitting. Must be &gt;= 0.
            /// </value>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "L1 regularization term on weights, increase this value will make model more conservative.",
                ShortName = "l1")]
            [TlcModule.Range(Min = 0.0)]
            [TGUI(Label = "Alpha(L1)", SuggestedSweeps = "0,0.5,1")]
            [TlcModule.SweepableDiscreteParam("RegAlpha", new object[] { 0f, 0.5f, 1f })]
            public double L1Regularization = 0;
 
            /// <summary>
            /// Explicit implementation of the generic component factory member; forwards to <see cref="BuildOptions"/>
            /// and does not use <paramref name="env"/>.
            /// </summary>
            BoosterParameterBase IComponentFactory<BoosterParameterBase>.CreateComponent(IHostEnvironment env)
                => BuildOptions();
 
            /// <summary>
            /// Explicit implementation of <see cref="IBoosterParameterFactory"/>; forwards to <see cref="BuildOptions"/>
            /// and does not use <paramref name="env"/>.
            /// </summary>
            BoosterParameterBase IBoosterParameterFactory.CreateComponent(IHostEnvironment env)
                => BuildOptions();
 
            /// <summary>
            /// Builds the booster that corresponds to this options object. The overrides in this file
            /// construct <see cref="GradientBooster"/>, <see cref="DartBooster"/> or <see cref="GossBooster"/> from <c>this</c>.
            /// </summary>
            internal abstract BoosterParameterBase BuildOptions();
        }
 
        /// <summary>
        /// Copies the value of every instance field of <see cref="BoosterOptions"/> that carries an
        /// <see cref="ArgumentAttribute"/> into <paramref name="res"/>.
        /// </summary>
        /// <param name="res">
        /// Receives one entry per annotated field. The key is the <see cref="NameMapping"/> entry for the
        /// field name when one exists, otherwise the result of <c>LightGbmInterfaceUtils.GetOptionName</c>.
        /// An entry already present under the same key is overwritten.
        /// </param>
        internal void UpdateParameters(Dictionary<string, object> res)
        {
            FieldInfo[] fields = BoosterOptions.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
            foreach (var field in fields)
            {
                var attribute = field.GetCustomAttribute<ArgumentAttribute>(false);

                // Fields without an ArgumentAttribute are not forwarded.
                if (attribute == null)
                    continue;

                // Single dictionary lookup instead of ContainsKey followed by the indexer.
                if (!NameMapping.TryGetValue(field.Name, out string name))
                    name = LightGbmInterfaceUtils.GetOptionName(field.Name);

                res[name] = field.GetValue(BoosterOptions);
            }
        }
 
        /// <summary>
        /// Create <see cref="IBoosterParameterFactory"/> for supporting legacy infra built upon <see cref="IComponentFactory"/>.
        /// </summary>
        /// <returns>
        /// The factory for this booster. The overrides in this file return <see cref="BoosterOptions"/>.
        /// </returns>
        internal abstract IBoosterParameterFactory BuildFactory();
        /// <summary>
        /// Short identifier of the booster. The overrides in this file return their <c>Name</c> constant
        /// ("gbdt", "dart" or "goss").
        /// </summary>
        internal abstract string BoosterName { get; }
 
        /// <summary>
        /// The options instance passed to the constructor. Read by <see cref="UpdateParameters"/>.
        /// </summary>
        private protected OptionsBase BoosterOptions;
    }
 
    /// <summary>
    /// Booster for the traditional gradient boosting decision tree, identified by the name "gbdt".
    /// </summary>
    /// <remarks>
    /// For details, please see <a href="https://en.wikipedia.org/wiki/Gradient_boosting#Gradient_tree_boosting">gradient tree boosting</a>.
    /// </remarks>
    public sealed class GradientBooster : BoosterParameterBase
    {
        internal const string Name = "gbdt";
        internal const string FriendlyName = "Tree Booster";

        /// <summary>
        /// The options for <see cref="GradientBooster"/>, used for setting <see cref="Booster"/>.
        /// Only the options inherited from <see cref="BoosterParameterBase.OptionsBase"/> are available.
        /// </summary>
        [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Traditional Gradient Boosting Decision Tree.")]
        public sealed class Options : OptionsBase
        {
            /// <summary>
            /// Builds a <see cref="GradientBooster"/> from this options object.
            /// </summary>
            internal override BoosterParameterBase BuildOptions()
            {
                return new GradientBooster(this);
            }
        }

        /// <summary>
        /// Creates the booster; all validation is done by the base constructor.
        /// </summary>
        /// <param name="options">The options to validate and store.</param>
        internal GradientBooster(Options options)
            : base(options)
        {
        }

        /// <summary>
        /// Returns the stored options, which serve as the factory.
        /// </summary>
        internal override IBoosterParameterFactory BuildFactory()
        {
            return BoosterOptions;
        }

        /// <summary>
        /// Returns <see cref="Name"/>.
        /// </summary>
        internal override string BoosterName
        {
            get { return Name; }
        }
    }
 
    /// <summary>
    /// DART booster (Dropouts meet Multiple Additive Regression Trees).
    /// </summary>
    /// <remarks>
    /// For details, please see <a href="https://arxiv.org/abs/1505.01866">here</a>.
    /// </remarks>
    public sealed class DartBooster : BoosterParameterBase
    {
        internal const string Name = "dart";
        internal const string FriendlyName = "Tree Dropout Tree Booster";

        /// <summary>
        /// The options for <see cref="DartBooster"/>, used for setting <see cref="Booster"/>.
        /// </summary>
        // NOTE(review): Desc below spells "Regresion"; the string is kept unchanged because it is
        // emitted at runtime - confirm before correcting it.
        [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Dropouts meet Multiple Additive Regresion Trees. See https://arxiv.org/abs/1505.01866")]
        public sealed class Options : OptionsBase
        {
            static Options()
            {
                // Add the DART-specific field name -> parameter name entries to the shared mapping table.
                NameMapping.Add(nameof(TreeDropFraction), "drop_rate");
                NameMapping.Add(nameof(MaximumNumberOfDroppedTreesPerRound), "max_drop");
                NameMapping.Add(nameof(SkipDropFraction), "skip_drop");
            }

            /// <summary>
            /// The dropout rate, i.e. the fraction of previous trees to drop during the dropout.
            /// </summary>
            /// <value>
            /// Valid range is (0,1); the <see cref="DartBooster"/> constructor rejects both 0 and 1.
            /// </value>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The drop ratio for trees. Range:(0,1).")]
            [TlcModule.Range(Inf = 0.0, Max = 1.0)]
            public double TreeDropFraction = 0.1;

            /// <summary>
            /// The maximum number of dropped trees in a boosting round.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of dropped trees in a boosting round.")]
            [TlcModule.Range(Inf = 0, Max = int.MaxValue)]
            public int MaximumNumberOfDroppedTreesPerRound = 1;

            /// <summary>
            /// The probability of skipping the dropout procedure during a boosting iteration.
            /// </summary>
            /// <value>
            /// Valid range is [0,1); the <see cref="DartBooster"/> constructor accepts 0 and rejects 1.
            /// </value>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Probability for not dropping in a boosting round.")]
            [TlcModule.Range(Inf = 0.0, Max = 1.0)]
            public double SkipDropFraction = 0.5;

            /// <summary>
            /// Whether to enable xgboost dart mode.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "True will enable xgboost dart mode.")]
            public bool XgboostDartMode = false;

            /// <summary>
            /// Whether to enable uniform drop.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "True will enable uniform drop.")]
            public bool UniformDrop = false;

            /// <summary>
            /// Builds a <see cref="DartBooster"/> from this options object.
            /// </summary>
            internal override BoosterParameterBase BuildOptions() => new DartBooster(this);
        }

        /// <summary>
        /// Validates the DART-specific options after the base constructor has validated the shared
        /// ones and stored <paramref name="options"/> in <see cref="BoosterParameterBase.BoosterOptions"/>.
        /// </summary>
        /// <param name="options">
        /// The options to validate. <see cref="Options.TreeDropFraction"/> must be in (0,1) and
        /// <see cref="Options.SkipDropFraction"/> must be in [0,1).
        /// </param>
        internal DartBooster(Options options)
            : base(options)
        {
            Contracts.CheckUserArg(options.TreeDropFraction > 0 && options.TreeDropFraction < 1, nameof(Options.TreeDropFraction), "must be in (0,1).");
            Contracts.CheckUserArg(options.SkipDropFraction >= 0 && options.SkipDropFraction < 1, nameof(Options.SkipDropFraction), "must be in [0,1).");
            // BoosterOptions is already assigned by the base constructor, so it is not set again here.
        }

        /// <summary>
        /// Returns the stored options, which serve as the factory.
        /// </summary>
        internal override IBoosterParameterFactory BuildFactory() => BoosterOptions;

        /// <summary>
        /// Returns <see cref="Name"/>.
        /// </summary>
        internal override string BoosterName => Name;
    }
 
    /// <summary>
    /// Gradient-based One-Side Sampling booster.
    /// </summary>
    /// <remarks>
    /// For details, please see <a href="https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf">here</a>.
    /// </remarks>
    public sealed class GossBooster : BoosterParameterBase
    {
        internal const string Name = "goss";
        // NOTE(review): this reads "One-Size" while the component Desc below reads "One-Side". The value is
        // kept unchanged because it is passed to the LoadableClass and TlcModule.Component registrations -
        // confirm before correcting it.
        internal const string FriendlyName = "Gradient-based One-Size Sampling";

        /// <summary>
        /// The options for <see cref="GossBooster"/>, used for setting <see cref="Booster"/>.
        /// </summary>
        /// <remarks>
        /// The <see cref="GossBooster"/> constructor additionally requires that
        /// <see cref="TopRate"/> + <see cref="OtherRate"/> is not larger than 1.
        /// </remarks>
        [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Gradient-based One-Side Sampling.")]
        public sealed class Options : OptionsBase
        {
            /// <summary>
            /// The retain ratio of large gradient data.
            /// </summary>
            /// <value>
            /// Valid range is (0,1).
            /// </value>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Retain ratio for large gradient instances.")]
            [TlcModule.Range(Inf = 0.0, Max = 1.0)]
            public double TopRate = 0.2;

            /// <summary>
            /// The retain ratio of small gradient data.
            /// </summary>
            /// <value>
            /// Valid range is [0,1).
            /// </value>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Retain ratio for small gradient instances.")]
            [TlcModule.Range(Inf = 0.0, Max = 1.0)]
            public double OtherRate = 0.1;

            /// <summary>
            /// Builds a <see cref="GossBooster"/> from this options object.
            /// </summary>
            internal override BoosterParameterBase BuildOptions() => new GossBooster(this);
        }

        /// <summary>
        /// Validates the GOSS-specific options after the base constructor has validated the shared
        /// ones and stored <paramref name="options"/> in <see cref="BoosterParameterBase.BoosterOptions"/>.
        /// </summary>
        /// <param name="options">
        /// The options to validate. <see cref="Options.TopRate"/> must be in (0,1),
        /// <see cref="Options.OtherRate"/> must be in [0,1), and their sum must not exceed 1.
        /// </param>
        internal GossBooster(Options options)
            : base(options)
        {
            Contracts.CheckUserArg(options.TopRate > 0 && options.TopRate < 1, nameof(Options.TopRate), "must be in (0,1).");
            Contracts.CheckUserArg(options.OtherRate >= 0 && options.OtherRate < 1, nameof(Options.OtherRate), "must be in [0,1).");
            Contracts.Check(options.TopRate + options.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1.");
            // BoosterOptions is already assigned by the base constructor, so it is not set again here.
        }

        /// <summary>
        /// Returns the stored options, which serve as the factory.
        /// </summary>
        internal override IBoosterParameterFactory BuildFactory() => BoosterOptions;

        /// <summary>
        /// Returns <see cref="Name"/>.
        /// </summary>
        internal override string BoosterName => Name;
    }
}