File: LightGbmMulticlassTrainer.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers.LightGbm;
 
[assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(LightGbmMulticlassTrainer.Options),
    new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) },
    "LightGBM Multi-class Classifier", LightGbmMulticlassTrainer.LoadNameValue, LightGbmMulticlassTrainer.ShortName, DocName = "trainer/LightGBM.md")]
 
namespace Microsoft.ML.Trainers.LightGbm
{
    /// <summary>
    /// The <see cref="IEstimator{TTransformer}"/> for training a boosted decision tree multi-class classification model using LightGBM.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    /// To create this trainer, use [LightGbm](xref:Microsoft.ML.LightGbmExtensions.LightGbm(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,System.Nullable{System.Int32},System.Nullable{System.Int32},System.Nullable{System.Double},System.Int32))
    /// or [LightGbm(Options)](xref:Microsoft.ML.LightGbmExtensions.LightGbm(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,Microsoft.ML.Trainers.LightGbm.LightGbmMulticlassTrainer.Options)).
    ///
    /// [!include[io](~/../docs/samples/docs/api-reference/io-columns-multiclass-classification.md)]
    ///
    /// ### Trainer Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Machine learning task | Multiclass classification |
    /// | Is normalization required? | No |
    /// | Is caching required? | No |
    /// | Required NuGet in addition to Microsoft.ML | Microsoft.ML.LightGbm |
    /// | Exportable to ONNX | Yes |
    ///
    /// [!include[algorithm](~/../docs/samples/docs/api-reference/algo-details-lightgbm.md)]
    /// ]]>
    /// </format>
    /// </remarks>
    /// <seealso cref="LightGbmExtensions.LightGbm(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, int?, int?, double?, int)"/>
    /// <seealso cref="LightGbmExtensions.LightGbm(MulticlassClassificationCatalog.MulticlassClassificationTrainers, LightGbmMulticlassTrainer.Options)"/>
    /// <seealso cref="Options"/>
    public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<LightGbmMulticlassTrainer.Options,
                                                                        VBuffer<float>,
                                                                        MulticlassPredictionTransformer<OneVersusAllModelParameters>,
                                                                        OneVersusAllModelParameters>
    {
        internal const string Summary = "LightGBM Multi Class Classifier";
        internal const string LoadNameValue = "LightGBMMulticlass";
        internal const string ShortName = "LightGBMMC";
        private const int _minDataToUseSoftmax = 50000;
 
        private const double _maxNumClass = 1e6;
 
        // If there are NaN labels, they are converted to _numberOfClassesIncludingNan - 1.
        // This is done because NaN labels are treated as an extra class when training the model in the
        // WrappedLightGbmTraining class, but when creating the predictors only _numberOfClasses is considered,
        // ignoring the "extra class" of NaN labels.
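        // For example, with float labels {0, 1, 2} plus some NaN labels, _numberOfClasses == 3,
        // _numberOfClassesIncludingNan == 4, and every NaN label is rewritten to 3 (see ConvertNaNLabels).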
        private int _numberOfClassesIncludingNan;
        private int _numberOfClasses;
 
        private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;
 
        /// <summary>
        /// Options for the <see cref="LightGbmMulticlassTrainer"/> as used in
        ///  [LightGbm(Options)](xref:Microsoft.ML.LightGbmExtensions.LightGbm(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,Microsoft.ML.Trainers.LightGbm.LightGbmMulticlassTrainer.Options)).
        /// </summary>
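        /// <example>
        /// A minimal configuration sketch; the values below are illustrative assumptions, not recommendations:
        /// <code>
        /// var options = new LightGbmMulticlassTrainer.Options
        /// {
        ///     LabelColumnName = "Label",
        ///     FeatureColumnName = "Features",
        ///     UseSoftmax = true,
        ///     Sigmoid = 0.5,
        ///     EvaluationMetric = LightGbmMulticlassTrainer.Options.EvaluateMetricType.LogLoss
        /// };
        /// var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(options);
        /// </code>
        /// </example>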
        public sealed class Options : OptionsBase
        {
            public enum EvaluateMetricType
            {
                None,
                Default,
                Error,
                LogLoss,
            }
 
            /// <summary>
            /// Whether training data is unbalanced.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Use for multi-class classification when training data is not balanced", ShortName = "us")]
            public bool UnbalancedSets = false;
 
            /// <summary>
            /// Whether to use softmax loss.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Use softmax loss for the multi classification.")]
            [TlcModule.SweepableDiscreteParam("UseSoftmax", new object[] { true, false })]
            public bool? UseSoftmax;
 
            /// <summary>
            /// Parameter for the sigmoid function.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Parameter for the sigmoid function.", ShortName = "sigmoid")]
            [TGUI(Label = "Sigmoid", SuggestedSweeps = "0.5,1")]
            public double Sigmoid = 0.5;
 
            /// <summary>
            /// Determines what evaluation metric to use.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Evaluation metrics.",
                ShortName = "em")]
            public EvaluateMetricType EvaluationMetric = EvaluateMetricType.Error;
 
            static Options()
            {
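                // Map the managed option and enum names to the native LightGBM parameter names
                // that end up in the GbmOptions dictionary (e.g. EvaluateMetricType -> "metric").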
                NameMapping.Add(nameof(EvaluateMetricType), "metric");
                NameMapping.Add(nameof(EvaluateMetricType.None), "None");
                NameMapping.Add(nameof(EvaluateMetricType.Default), "");
                NameMapping.Add(nameof(EvaluateMetricType.Error), "multi_error");
                NameMapping.Add(nameof(EvaluateMetricType.LogLoss), "multi_logloss");
            }
 
            internal override Dictionary<string, object> ToDictionary(IHost host)
            {
                var res = base.ToDictionary(host);
 
                res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
                res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
                res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
 
                return res;
            }
        }
 
        internal LightGbmMulticlassTrainer(IHostEnvironment env, Options options)
             : base(env, LoadNameValue, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName))
        {
            Contracts.CheckUserArg(options.Sigmoid > 0, nameof(Options.Sigmoid), "must be > 0.");
            _numberOfClassesIncludingNan = -1;
        }
 
        /// <summary>
        /// Initializes a new instance of <see cref="LightGbmMulticlassTrainer"/>
        /// </summary>
        /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="labelColumnName">The name of The label column.</param>
        /// <param name="featureColumnName">The name of the feature column.</param>
        /// <param name="exampleWeightColumnName">The name of the example weight column (optional).</param>
        /// <param name="numberOfLeaves">The number of leaves to use.</param>
        /// <param name="minimumExampleCountPerLeaf">The minimal number of data points allowed in a leaf of the tree, out of the subsampled data.</param>
        /// <param name="learningRate">The learning rate.</param>
        /// <param name="numberOfIterations">The number of iterations to use.</param>
        internal LightGbmMulticlassTrainer(IHostEnvironment env,
            string labelColumnName = DefaultColumnNames.Label,
            string featureColumnName = DefaultColumnNames.Features,
            string exampleWeightColumnName = null,
            int? numberOfLeaves = null,
            int? minimumExampleCountPerLeaf = null,
            double? learningRate = null,
            int numberOfIterations = Defaults.NumberOfIterations)
            : this(env,
                  new Options()
                  {
                      LabelColumnName = labelColumnName,
                      FeatureColumnName = featureColumnName,
                      ExampleWeightColumnName = exampleWeightColumnName,
                      NumberOfLeaves = numberOfLeaves,
                      MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf,
                      LearningRate = learningRate,
                      NumberOfIterations = numberOfIterations
                  })
        {
        }
 
        /// <summary>
        /// Initializes a new instance of <see cref="LightGbmRankingTrainer"/>
        /// </summary>
        /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="lightGbmModel"> A pre-trained <see cref="System.IO.Stream"/> of a LightGBM model file inferencing</param>
        /// <param name="featureColumnName">The name of the feature column.</param>
        internal LightGbmMulticlassTrainer(IHostEnvironment env,
            Stream lightGbmModel,
            string featureColumnName = DefaultColumnNames.Features)
            : base(env,
                  LoadNameValue,
                  new Options()
                  {
                      FeatureColumnName = featureColumnName,
                      LightGbmModel = lightGbmModel
                  },
                  new SchemaShape.Column())
        {
        }
 
        private InternalTreeEnsemble GetBinaryEnsemble(int classID)
        {
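            // LightGBM stores a multi-class ensemble with trees interleaved round-robin by class:
            // tree i belongs to class (i % num_class), so the trees for classID sit at
            // classID, classID + num_class, classID + 2 * num_class, ...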
            var res = new InternalTreeEnsemble();
            for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numberOfClassesIncludingNan)
            {
                // Ignore dummy trees.
                if (TrainedEnsemble.GetTreeAt(i).NumLeaves > 1)
                    res.AddTree(TrainedEnsemble.GetTreeAt(i));
            }
            return res;
        }
 
        private LightGbmBinaryModelParameters CreateBinaryPredictor(int classID, string innerArgs)
        {
            return new LightGbmBinaryModelParameters(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs);
        }
 
        private protected override OneVersusAllModelParameters CreatePredictor()
        {
            Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
 
            Host.Assert(_numberOfClassesIncludingNan > 1, "Must know the number of classes before creating a predictor.");
            Host.Assert(TrainedEnsemble.NumTrees % _numberOfClassesIncludingNan == 0, "Number of trees should be a multiple of number of classes.");
 
            var innerArgs = LightGbmInterfaceUtils.JoinParameters(GbmOptions);
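            // Build one Platt-calibrated binary predictor per real class; the trees belonging to the
            // NaN "extra class" (if any) are never requested, since only _numberOfClasses predictors are created.
            // The calibrator slope mirrors LightGBM's sigmoid parameter so raw scores become probabilities.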
            IPredictorProducing<float>[] predictors = new IPredictorProducing<float>[_numberOfClasses];
            for (int i = 0; i < _numberOfClasses; ++i)
            {
                var pred = CreateBinaryPredictor(i, innerArgs);
                var cali = new PlattCalibrator(Host, -LightGbmTrainerOptions.Sigmoid, 0);
                predictors[i] = new FeatureWeightsCalibratedModelParameters<LightGbmBinaryModelParameters, PlattCalibrator>(Host, pred, cali);
            }
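            // The native objective picks the output formula: "multiclass" combines the per-class scores with a
            // softmax, while "multiclassova" keeps the independent one-versus-all probabilities.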
            string obj = (string)GetGbmParameters()["objective"];
            if (obj == "multiclass")
                return OneVersusAllModelParameters.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors);
            else
                return OneVersusAllModelParameters.Create(Host, predictors);
        }
 
        private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
        {
            Host.AssertValue(ch);
            base.CheckDataValid(ch, data);
            // If using a pre-trained model file we don't need a label or group column
            if (LightGbmTrainerOptions.LightGbmModel == null)
            {
                var labelType = data.Schema.Label.Value.Type;
                if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
                {
                    throw ch.ExceptParam(nameof(data),
                        $"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be of unsigned int, boolean or float.");
                }
            }
        }
 
        private protected override void InitializeBeforeTraining()
        {
            _numberOfClassesIncludingNan = -1;
            _numberOfClasses = 0;
        }
 
        private protected override void AdditionalLoadPreTrainedModel(string modelText)
        {
            string[] lines = modelText.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
            // Jump to the "objective" value in the file. It's at the beginning.
            int i = 0;
            while (!lines[i].StartsWith("objective"))
                i++;
 
            // Format in the file is objective=multiclass num_class:4
            var split = lines[i].Split(' ');
            _numberOfClassesIncludingNan = int.Parse(split[1].Split(':')[1]);
            _numberOfClasses = _numberOfClassesIncludingNan;
        }
 
 
        private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)
        {
            // Only initialize once.
            if (_numberOfClassesIncludingNan < 0)
            {
                float minLabel = float.MaxValue;
                float maxLabel = float.MinValue;
                bool hasNaNLabel = false;
                foreach (var labelColumn in labels)
                {
                    if (float.IsNaN(labelColumn))
                        hasNaNLabel = true;
                    else
                    {
                        minLabel = Math.Min(minLabel, labelColumn);
                        maxLabel = Math.Max(maxLabel, labelColumn);
                    }
                }
                ch.CheckParam(minLabel >= 0, nameof(data), "Minimum value in label column cannot be negative");
                if (maxLabel >= _maxNumClass)
                    throw ch.ExceptParam(nameof(data), $"Maximum value {maxLabel} in label column exceeds {_maxNumClass}");
 
                if (data.Schema.Label.Value.Type is KeyDataViewType keyType)
                {
                    if (hasNaNLabel)
                        _numberOfClassesIncludingNan = keyType.GetCountAsInt32(Host) + 1;
                    else
                        _numberOfClassesIncludingNan = keyType.GetCountAsInt32(Host);
                    _numberOfClasses = keyType.GetCountAsInt32(Host);
                }
                else
                {
                    if (hasNaNLabel)
                        _numberOfClassesIncludingNan = (int)maxLabel + 2;
                    else
                        _numberOfClassesIncludingNan = (int)maxLabel + 1;
                    _numberOfClasses = (int)maxLabel + 1;
                }
            }
 
            float defaultLabel = _numberOfClassesIncludingNan - 1;
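            // defaultLabel is the extra class index reserved for NaN labels (see the comment on _numberOfClassesIncludingNan).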
            for (int i = 0; i < labels.Length; ++i)
                if (float.IsNaN(labels[i]))
                    labels[i] = defaultLabel;
        }
 
        private protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool hiddenMsg = false)
        {
            base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true);
            int numberOfLeaves = (int)GbmOptions["num_leaves"];
            int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, _numberOfClassesIncludingNan);
            GbmOptions["min_data_per_leaf"] = minimumExampleCountPerLeaf;
            if (!hiddenMsg)
            {
                if (!LightGbmTrainerOptions.LearningRate.HasValue)
                    ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.LearningRate) + " = " + GbmOptions["learning_rate"]);
                if (!LightGbmTrainerOptions.NumberOfLeaves.HasValue)
                    ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.NumberOfLeaves) + " = " + numberOfLeaves);
                if (!LightGbmTrainerOptions.MinimumExampleCountPerLeaf.HasValue)
                    ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.MinimumExampleCountPerLeaf) + " = " + minimumExampleCountPerLeaf);
            }
        }
 
        private protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups)
        {
            Host.AssertValue(ch);
            ch.Assert(PredictionKind == PredictionKind.MulticlassClassification);
            ch.Assert(_numberOfClassesIncludingNan > 1);
            GbmOptions["num_class"] = _numberOfClassesIncludingNan;
            bool useSoftmax = false;
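            // Heuristic: use the true softmax objective only when the caller asked for it explicitly or when
            // there are at least _minDataToUseSoftmax rows; otherwise fall back to one-versus-all training.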
 
            if (LightGbmTrainerOptions.UseSoftmax.HasValue)
                useSoftmax = LightGbmTrainerOptions.UseSoftmax.Value;
            else
            {
                if (labels.Length >= _minDataToUseSoftmax)
                    useSoftmax = true;
 
                ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.UseSoftmax) + " = " + useSoftmax);
            }
 
            if (useSoftmax)
                GbmOptions["objective"] = "multiclass";
            else
                GbmOptions["objective"] = "multiclassova";
        }
 
        private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            SchemaShape.Column labelCol = default;
            if (LightGbmTrainerOptions.LightGbmModel == null)
            {
                bool success = inputSchema.TryFindColumn(LabelColumn.Name, out labelCol);
                Contracts.Assert(success);
            }
 
            var metadata = LightGbmTrainerOptions.LightGbmModel == null
                ? new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues)
                    .Concat(AnnotationUtils.GetTrainerOutputAnnotation()))
                : new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation());
            return new[]
            {
                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol))),
                new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, metadata)
            };
        }
 
        private protected override MulticlassPredictionTransformer<OneVersusAllModelParameters> MakeTransformer(OneVersusAllModelParameters model, DataViewSchema trainSchema)
            => new MulticlassPredictionTransformer<OneVersusAllModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
 
        /// <summary>
        /// Trains a <see cref="LightGbmMulticlassTrainer"/> using both training and validation data, returns
        /// a <see cref="MulticlassPredictionTransformer{OneVsAllModelParameters}"/>.
        /// </summary>
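        /// <example>
        /// A minimal sketch; the <c>mlContext</c>, <c>trainSet</c> and <c>validationSet</c> variables are assumptions:
        /// <code>
        /// var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(labelColumnName: "Label");
        /// var model = trainer.Fit(trainSet, validationSet);
        /// </code>
        /// </example>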
        public MulticlassPredictionTransformer<OneVersusAllModelParameters> Fit(IDataView trainData, IDataView validationData)
            => TrainTransformer(trainData, validationData);
    }
 
    /// <summary>
    /// A component to train a LightGBM model.
    /// </summary>
    internal static partial class LightGbm
    {
        [TlcModule.EntryPoint(
            Name = "Trainers.LightGbmClassifier",
            Desc = "Train a LightGBM multi class model.",
            UserName = LightGbmMulticlassTrainer.Summary,
            ShortName = LightGbmMulticlassTrainer.ShortName)]
        public static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, LightGbmMulticlassTrainer.Options input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("TrainLightGBM");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
 
            return TrainerEntryPointsUtils.Train<LightGbmMulticlassTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
                () => new LightGbmMulticlassTrainer(host, input),
                getLabel: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
                getWeight: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
        }
    }
}