// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
namespace Microsoft.ML.Trainers.LightGbm
{
[BestFriend]
internal static class Defaults
{
public const int NumberOfIterations = 100;
}
/// <summary>
/// Locks shared across LightGBM trainers.
/// </summary>
internal static class LightGbmShared
{
// Lock for operations that are multi-threaded inside the LightGBM DLL.
public static readonly object LockForMultiThreadingInside = new object();
// Lock for the sampling stage; serializing sampling reduces peak memory usage.
public static readonly object SampleLock = new object();
}
/// <summary>
/// Base class for all training with LightGBM.
/// </summary>
public abstract class LightGbmTrainerBase<TOptions, TOutput, TTransformer, TModel> : TrainerEstimatorBaseWithGroupId<TTransformer, TModel>
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TModel : class
where TOptions : LightGbmTrainerBase<TOptions, TOutput, TTransformer, TModel>.OptionsBase, new()
{
public class OptionsBase : TrainerInputBaseWithGroupId
{
// Static override name map that maps friendly names to LightGBM arguments.
// If an argument is not listed here, its name is identical to the LightGBM argument
// and does not require a mapping, for example, Subsample.
private protected static Dictionary<string, string> NameMapping = new Dictionary<string, string>()
{
{nameof(MinimumExampleCountPerLeaf), "min_data_per_leaf"},
{nameof(NumberOfLeaves), "num_leaves"},
{nameof(MaximumBinCountPerFeature), "max_bin" },
{nameof(MinimumExampleCountPerGroup), "min_data_per_group" },
{nameof(MaximumCategoricalSplitPointCount), "max_cat_threshold" },
{nameof(CategoricalSmoothing), "cat_smooth" },
{nameof(L2CategoricalRegularization), "cat_l2" },
{nameof(HandleMissingValue), "use_missing" },
{nameof(UseZeroAsMissingValue), "zero_as_missing" }
};
internal string GetOptionName(string name)
{
if (NameMapping.TryGetValue(name, out string lightGbmName))
return lightGbmName;
return LightGbmInterfaceUtils.GetOptionName(name);
}
private protected OptionsBase() { }
/// <summary>
/// The number of boosting iterations. A new tree is created in each iteration, so this is equivalent to the number of trees.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations.", SortOrder = 1, ShortName = "iter")]
[TGUI(Label = "Number of boosting iterations", SuggestedSweeps = "10,20,50,100,150,200")]
[TlcModule.SweepableDiscreteParam("NumBoostRound", new object[] { 10, 20, 50, 100, 150, 200 })]
public int NumberOfIterations = Defaults.NumberOfIterations;
/// <summary>
/// The shrinkage rate for trees, used to prevent over-fitting.
/// </summary>
/// <value>
/// Valid range is (0,1].
/// </value>
[Argument(ArgumentType.AtMostOnce,
HelpText = "Shrinkage rate for trees, used to prevent over-fitting. Range: (0,1].",
SortOrder = 2, ShortName = "lr", NullName = "<Auto>")]
[TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")]
[TlcModule.SweepableFloatParamAttribute("LearningRate", 0.025f, 0.4f, isLogScale: true)]
public double? LearningRate;
/// <summary>
/// The maximum number of leaves in one tree.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum leaves for trees.",
SortOrder = 2, ShortName = "nl", NullName = "<Auto>")]
[TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")]
[TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)]
public int? NumberOfLeaves;
/// <summary>
/// The minimum number of data points required to form a new tree leaf.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Minimum number of instances needed in a child.",
SortOrder = 2, ShortName = "mil", NullName = "<Auto>")]
[TGUI(Label = "Min Documents In Leaves", SuggestedSweeps = "1,10,20,50 ")]
[TlcModule.SweepableDiscreteParamAttribute("MinDataPerLeaf", new object[] { 1, 10, 20, 50 })]
public int? MinimumExampleCountPerLeaf;
/// <summary>
/// The maximum number of bins that feature values will be bucketed into.
/// </summary>
/// <remarks>
/// A small number of bins may reduce training accuracy but may improve generalization (that is, reduce over-fitting).
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of bucket bin for features.", ShortName = "mb")]
public int MaximumBinCountPerFeature = 255;
/// <summary>
/// Determines which booster to use.
/// </summary>
/// <value>
/// Available boosters are <see cref="DartBooster"/>, <see cref="GossBooster"/>, and <see cref="GradientBooster"/>.
/// </value>
[Argument(ArgumentType.Multiple,
HelpText = "Which booster to use, can be gbtree, gblinear or dart. gbtree and dart use tree based model while gblinear uses linear function.",
Name = "Booster",
SortOrder = 3)]
internal IBoosterParameterFactory BoosterFactory = new GradientBooster.Options();
/// <summary>
/// Determines whether to output progress status during training and evaluation.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Verbose", ShortName = "v")]
public bool Verbose = false;
/// <summary>
/// Controls the logging level in LightGBM.
/// </summary>
/// <value>
/// <see langword="true"/> means only output Fatal errors. <see langword="false"/> means output Fatal, Warning, and Info level messages.
/// </value>
[Argument(ArgumentType.AtMostOnce, HelpText = "Printing running messages.")]
public bool Silent = true;
/// <summary>
/// Determines the number of threads used to run LightGBM.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of parallel threads used to run LightGBM.", ShortName = "nt")]
public int? NumberOfThreads;
/// <summary>
/// Determines the number of rounds after which training will stop if the validation metric doesn't improve.
/// </summary>
/// <value>
/// 0 means disable early stopping.
/// </value>
[Argument(ArgumentType.AtMostOnce, HelpText = "Rounds of early stopping, 0 will disable it.",
ShortName = "es")]
public int EarlyStoppingRound = 0;
/// <summary>
/// The number of data points per batch when loading data.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of entries in a batch when loading data.", Hide = true)]
public int BatchSize = 1 << 20;
/// <summary>
/// Whether to enable categorical split or not.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Enable categorical split or not.", ShortName = "cat")]
[TlcModule.SweepableDiscreteParam("UseCat", new object[] { true, false })]
public bool? UseCategoricalSplit;
/// <summary>
/// Whether to enable special handling of missing value or not.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Enable special handling of missing value or not.", ShortName = "hmv")]
[TlcModule.SweepableDiscreteParam("UseMissing", new object[] { true, false })]
public bool HandleMissingValue = true;
/// <summary>
/// Whether to enable the usage of zero (0) as missing value.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Enable usage of zero (0) as missing value.", ShortName = "uzam")]
[TlcModule.SweepableDiscreteParam("UseZeroAsMissing", new object[] { true, false })]
public bool UseZeroAsMissingValue = false;
/// <summary>
/// The minimum number of data points per categorical group.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Minimum number of instances per categorical group.", ShortName = "mdpg")]
[TlcModule.Range(Inf = 0, Max = int.MaxValue)]
[TlcModule.SweepableDiscreteParam("MinDataPerGroup", new object[] { 10, 50, 100, 200 })]
public int MinimumExampleCountPerGroup = 100;
/// <summary>
/// Maximum categorical split points to consider when splitting on a categorical feature.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Max number of categorical thresholds.", ShortName = "maxcat")]
[TlcModule.Range(Inf = 0, Max = int.MaxValue)]
[TlcModule.SweepableDiscreteParam("MaxCatThreshold", new object[] { 8, 16, 32, 64 })]
public int MaximumCategoricalSplitPointCount = 32;
/// <summary>
/// Laplace smooth term in categorical feature split.
/// This can reduce the effect of noise in categorical features, especially for categories with little data.
/// </summary>
/// <value>
/// Constraints: <see cref="CategoricalSmoothing"/> >= 0.0
/// </value>
[Argument(ArgumentType.AtMostOnce, HelpText = "Lapalace smooth term in categorical feature spilt. Avoid the bias of small categories.")]
[TlcModule.Range(Min = 0.0)]
[TlcModule.SweepableDiscreteParam("CatSmooth", new object[] { 1, 10, 20 })]
public double CategoricalSmoothing = 10;
/// <summary>
/// L2 regularization for categorical split.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization for categorical split.")]
[TlcModule.Range(Min = 0.0)]
[TlcModule.SweepableDiscreteParam("CatL2", new object[] { 0.1, 0.5, 1, 5, 10 })]
public double L2CategoricalRegularization = 10;
/// <summary>
/// The random seed for LightGBM to use.
/// </summary>
/// <value>
/// If not specified, <see cref="MLContext"/> will generate a random seed to be used.
/// </value>
[Argument(ArgumentType.AtMostOnce, HelpText = "Sets the random seed for LightGBM to use.")]
public int? Seed;
[Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")]
internal ISupportParallel ParallelTrainer = new SingleTrainerFactory();
private BoosterParameterBase.OptionsBase _boosterParameter;
internal Stream LightGbmModel = null;
/// <summary>
/// The booster parameter to use.
/// </summary>
public BoosterParameterBase.OptionsBase Booster
{
get => _boosterParameter;
set
{
_boosterParameter = value;
BoosterFactory = _boosterParameter;
}
}
internal virtual Dictionary<string, object> ToDictionary(IHost host)
{
Contracts.CheckValue(host, nameof(host));
Dictionary<string, object> res = new Dictionary<string, object>();
var boosterParams = BoosterFactory.CreateComponent(host);
boosterParams.UpdateParameters(res);
res["boosting_type"] = boosterParams.BoosterName;
res["verbose"] = Silent ? "-1" : "1";
if (NumberOfThreads.HasValue)
res["nthread"] = NumberOfThreads.Value;
res["seed"] = (Seed.HasValue) ? Seed : host.Rand.Next();
res[GetOptionName(nameof(MaximumBinCountPerFeature))] = MaximumBinCountPerFeature;
res[GetOptionName(nameof(HandleMissingValue))] = HandleMissingValue;
res[GetOptionName(nameof(UseZeroAsMissingValue))] = UseZeroAsMissingValue;
res[GetOptionName(nameof(MinimumExampleCountPerGroup))] = MinimumExampleCountPerGroup;
res[GetOptionName(nameof(MaximumCategoricalSplitPointCount))] = MaximumCategoricalSplitPointCount;
res[GetOptionName(nameof(CategoricalSmoothing))] = CategoricalSmoothing;
res[GetOptionName(nameof(L2CategoricalRegularization))] = L2CategoricalRegularization;
return res;
}
}
private sealed class CategoricalMetaData
{
public int NumCol;
public int TotalCats;
public int[] CategoricalBoundaries;
public int[] OnehotIndices;
public int[] OnehotBias;
public bool[] IsCategoricalFeature;
public int[] CatIndices;
}
// Contains the options passed in when the API is called.
private protected readonly TOptions LightGbmTrainerOptions;
/// <summary>
/// Stores arguments as objects so they can be converted to invariant strings at the end,
/// keeping the code culture agnostic. When retrieving a value from this dictionary as a string,
/// convert it with string.Format(CultureInfo.InvariantCulture, "{0}", GbmOptions[key]).
/// </summary>
private protected readonly Dictionary<string, object> GbmOptions;
private protected readonly IParallel ParallelTraining;
// Store FeatureCount and TrainedEnsemble to construct the predictor.
private protected int FeatureCount;
private protected InternalTreeEnsemble TrainedEnsemble;
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true);
public override TrainerInfo Info => _info;
private protected LightGbmTrainerBase(IHostEnvironment env,
string name,
SchemaShape.Column labelColumn,
string featureColumnName,
string exampleWeightColumnName,
string rowGroupColumnName,
int? numberOfLeaves,
int? minimumExampleCountPerLeaf,
double? learningRate,
int numberOfIterations)
: this(env, name, new TOptions()
{
NumberOfLeaves = numberOfLeaves,
MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf,
LearningRate = learningRate,
NumberOfIterations = numberOfIterations,
LabelColumnName = labelColumn.Name,
FeatureColumnName = featureColumnName,
ExampleWeightColumnName = exampleWeightColumnName,
RowGroupColumnName = rowGroupColumnName
},
labelColumn)
{
}
private protected LightGbmTrainerBase(IHostEnvironment env, string name, TOptions options, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumnName), label,
TrainerUtils.MakeR4ScalarWeightColumn(options.ExampleWeightColumnName), TrainerUtils.MakeU4ScalarColumn(options.RowGroupColumnName))
{
Host.CheckValue(options, nameof(options));
Contracts.CheckUserArg(options.NumberOfIterations >= 0, nameof(options.NumberOfIterations), "must be >= 0.");
Contracts.CheckUserArg(options.MaximumBinCountPerFeature > 0, nameof(options.MaximumBinCountPerFeature), "must be > 0.");
Contracts.CheckUserArg(options.MinimumExampleCountPerGroup > 0, nameof(options.MinimumExampleCountPerGroup), "must be > 0.");
Contracts.CheckUserArg(options.MaximumCategoricalSplitPointCount > 0, nameof(options.MaximumCategoricalSplitPointCount), "must be > 0.");
Contracts.CheckUserArg(options.CategoricalSmoothing >= 0, nameof(options.CategoricalSmoothing), "must be >= 0.");
Contracts.CheckUserArg(options.L2CategoricalRegularization >= 0.0, nameof(options.L2CategoricalRegularization), "must be >= 0.");
LightGbmTrainerOptions = options;
ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer();
GbmOptions = LightGbmTrainerOptions.ToDictionary(Host);
InitParallelTraining();
}
private protected override TModel TrainModelCore(TrainContext context)
{
InitializeBeforeTraining();
Host.CheckValue(context, nameof(context));
Dataset dtrain = null;
Dataset dvalid = null;
try
{
if (LightGbmTrainerOptions.LightGbmModel != null)
{
LightGbmTrainerOptions.LightGbmModel.Position = 0;
using (var ch = Host.Start("Loading LightGBM model file"))
{
StreamReader reader = new StreamReader(LightGbmTrainerOptions.LightGbmModel);
string modelText = reader.ReadToEnd();
AdditionalLoadPreTrainedModel(modelText);
// Load objective into options
string[] lines = modelText.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
// Jump to the "objective" value in the file. It's at the beginning.
int i = 0;
while (!lines[i].StartsWith("objective"))
i++;
// Format in the file is objective=multiclass num_class:4
var split = lines[i].Split(' ');
GbmOptions["objective"] = split[0].Split('=')[1];
var modelParameters = Booster.GetParameters(modelText);
// Set the parameters via reflection so that we don't have to set them manually on the options object.
Type optionsType = LightGbmTrainerOptions.GetType();
var optionsFields = optionsType.GetFields(System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.FlattenHierarchy | System.Reflection.BindingFlags.Instance);
foreach (var field in optionsFields)
{
var lightGbmName = LightGbmTrainerOptions.GetOptionName(field.Name);
if (modelParameters.ContainsKey(lightGbmName))
{
if (field.FieldType.Name.StartsWith("Nullable"))
{
if (field.FieldType.GenericTypeArguments[0] == typeof(double))
{
field.SetValue(LightGbmTrainerOptions, double.Parse(modelParameters[lightGbmName], CultureInfo.InvariantCulture));
}
else if (field.FieldType.GenericTypeArguments[0] == typeof(int))
{
field.SetValue(LightGbmTrainerOptions, int.Parse(modelParameters[lightGbmName], CultureInfo.InvariantCulture));
}
else if (field.FieldType.GenericTypeArguments[0] == typeof(float))
{
field.SetValue(LightGbmTrainerOptions, float.Parse(modelParameters[lightGbmName], CultureInfo.InvariantCulture));
}
// TODO: throw for unknown type
}
else if (field.FieldType.Name.StartsWith("Boolean"))
{
if (modelParameters[lightGbmName] == "1")
field.SetValue(LightGbmTrainerOptions, true);
else
field.SetValue(LightGbmTrainerOptions, false);
}
else
field.SetValue(LightGbmTrainerOptions, Convert.ChangeType(modelParameters[lightGbmName], field.FieldType));
}
}
var catBoundaries = modelParameters.TryGetValue("categorical_feature", out var catFeatures) && !string.IsNullOrEmpty(catFeatures)
? catFeatures.Split(',').Select(x => int.Parse(x, CultureInfo.InvariantCulture)).ToArray()
: null;
TrainedEnsemble = Booster.GetModel(catBoundaries, modelText);
FeatureCount = Booster.GetNumFeatures(modelText);
}
}
else
{
CategoricalMetaData catMetaData;
using (var ch = Host.Start("Loading data for LightGBM"))
{
using (var pch = Host.StartProgressChannel("Loading data for LightGBM"))
{
dtrain = LoadTrainingData(ch, context.TrainingSet, out catMetaData);
if (context.ValidationSet != null)
dvalid = LoadValidationData(ch, dtrain, context.ValidationSet, catMetaData);
}
}
using (var ch = Host.Start("Training with LightGBM"))
{
using (var pch = Host.StartProgressChannel("Training with LightGBM"))
TrainCore(ch, pch, dtrain, catMetaData, dvalid);
}
}
}
finally
{
dtrain?.Dispose();
dvalid?.Dispose();
DisposeParallelTraining();
}
return CreatePredictor();
}
private protected virtual void InitializeBeforeTraining() { }
// For loading additional info when we are loading a pre-trained model.
private protected virtual void AdditionalLoadPreTrainedModel(string modelText) { }
private void InitParallelTraining()
{
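// For distributed training, configure the tree learner and register the
// reduce-scatter and allgather communication callbacks with the native LightGBM network.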
if (ParallelTraining.ParallelType() != "serial" && ParallelTraining.NumMachines() > 1)
{
GbmOptions["tree_learner"] = ParallelTraining.ParallelType();
var otherParams = ParallelTraining.AdditionalParams();
if (otherParams != null)
{
foreach (var pair in otherParams)
GbmOptions[pair.Key] = pair.Value;
}
Contracts.CheckValue(ParallelTraining.GetReduceScatterFunction(), nameof(ParallelTraining.GetReduceScatterFunction));
Contracts.CheckValue(ParallelTraining.GetAllgatherFunction(), nameof(ParallelTraining.GetAllgatherFunction));
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.NetworkInitWithFunctions(
ParallelTraining.NumMachines(),
ParallelTraining.Rank(),
ParallelTraining.GetReduceScatterFunction(),
ParallelTraining.GetAllgatherFunction()
));
}
}
private void DisposeParallelTraining()
{
if (ParallelTraining.NumMachines() > 1)
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.NetworkFree());
}
private protected virtual void CheckDataValid(IChannel ch, RoleMappedData data)
{
data.CheckFeatureFloatVector();
// If we are loading a pre-trained model we don't need a label column
if (LightGbmTrainerOptions.LightGbmModel == null)
ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "Need a label column");
}
private protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool hiddenMsg = false)
{
double learningRate = LightGbmTrainerOptions.LearningRate ?? DefaultLearningRate(numRow, hasCategorical, totalCats);
int numberOfLeaves = LightGbmTrainerOptions.NumberOfLeaves ?? DefaultNumLeaves(numRow, hasCategorical, totalCats);
int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, 1);
GbmOptions["learning_rate"] = learningRate;
GbmOptions["num_leaves"] = numberOfLeaves;
GbmOptions["min_data_per_leaf"] = minimumExampleCountPerLeaf;
if (!hiddenMsg)
{
if (!LightGbmTrainerOptions.LearningRate.HasValue)
ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.LearningRate) + " = " + learningRate);
if (!LightGbmTrainerOptions.NumberOfLeaves.HasValue)
ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.NumberOfLeaves) + " = " + numberOfLeaves);
if (!LightGbmTrainerOptions.MinimumExampleCountPerLeaf.HasValue)
ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.MinimumExampleCountPerLeaf) + " = " + minimumExampleCountPerLeaf);
}
}
[BestFriend]
internal Dictionary<string, object> GetGbmParameters() => GbmOptions;
private FloatLabelCursor.Factory CreateCursorFactory(RoleMappedData data)
{
var loadFlags = CursOpt.AllLabels | CursOpt.AllFeatures;
if (PredictionKind == PredictionKind.Ranking)
loadFlags |= CursOpt.Group;
if (data.Schema.Weight.HasValue)
loadFlags |= CursOpt.AllWeights;
var factory = new FloatLabelCursor.Factory(data, loadFlags);
return factory;
}
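// categoricalFeatures stores the slot ranges of categorical features as consecutive
// [start, end] pairs (hence the j += 2 below). For example, with rawNumCol = 6 and
// pairs { 1, 3 }, the resulting boundaries are { 0, 1, 4, 5, 6 }: slot 0 is its own
// column, slots 1-3 form one categorical column, and slots 4 and 5 are their own columns.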
private static List<int> GetCategoricalBoundaries(int[] categoricalFeatures, int rawNumCol)
{
List<int> catBoundaries = new List<int> { 0 };
int curFidx = 0;
int j = 0;
while (curFidx < rawNumCol)
{
if (j < categoricalFeatures.Length && curFidx == categoricalFeatures[j])
{
if (curFidx > catBoundaries[catBoundaries.Count - 1])
catBoundaries.Add(curFidx);
if (categoricalFeatures[j + 1] - categoricalFeatures[j] >= 0)
{
curFidx = categoricalFeatures[j + 1] + 1;
catBoundaries.Add(curFidx);
}
else
{
for (int i = curFidx + 1; i <= categoricalFeatures[j + 1] + 1; ++i)
catBoundaries.Add(i);
curFidx = categoricalFeatures[j + 1] + 1;
}
j += 2;
}
else
{
catBoundaries.Add(curFidx + 1);
++curFidx;
}
}
return catBoundaries;
}
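// Builds the mapping from raw (one-hot expanded) slots to LightGBM columns:
// OnehotIndices[slot] is the column a raw slot maps to, and OnehotBias[slot] is the
// slot's offset within that column's one-hot block (0 for non-categorical columns).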
private static List<string> ConstructCategoricalFeatureMetaData(int[] categoricalFeatures, int rawNumCol, ref CategoricalMetaData catMetaData)
{
List<int> catBoundaries = GetCategoricalBoundaries(categoricalFeatures, rawNumCol);
catMetaData.NumCol = catBoundaries.Count - 1;
catMetaData.CategoricalBoundaries = catBoundaries.ToArray();
catMetaData.IsCategoricalFeature = new bool[catMetaData.NumCol];
catMetaData.OnehotIndices = new int[rawNumCol];
catMetaData.OnehotBias = new int[rawNumCol];
List<string> catIndices = new List<string>();
int j = 0;
for (int i = 0; i < catMetaData.NumCol; ++i)
{
var numCat = catMetaData.CategoricalBoundaries[i + 1] - catMetaData.CategoricalBoundaries[i];
if (numCat > 1)
{
catMetaData.TotalCats += numCat;
catMetaData.IsCategoricalFeature[i] = true;
catIndices.Add(i.ToString());
for (int k = catMetaData.CategoricalBoundaries[i]; k < catMetaData.CategoricalBoundaries[i + 1]; ++k)
{
catMetaData.OnehotIndices[j] = i;
catMetaData.OnehotBias[j] = k - catMetaData.CategoricalBoundaries[i];
++j;
}
}
else
{
catMetaData.IsCategoricalFeature[i] = false;
catMetaData.OnehotIndices[j] = i;
catMetaData.OnehotBias[j] = 0;
++j;
}
}
catMetaData.CatIndices = catIndices.Select(int.Parse).ToArray();
return catIndices;
}
private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData trainData, int numRow)
{
CategoricalMetaData catMetaData = new CategoricalMetaData();
int[] categoricalFeatures = null;
const int useCatThreshold = 50000;
// Disable categorical splits when the dataset is too small, to reduce overfitting.
bool useCat = LightGbmTrainerOptions.UseCategoricalSplit ?? numRow > useCatThreshold;
if (!LightGbmTrainerOptions.UseCategoricalSplit.HasValue)
ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.UseCategoricalSplit) + " = " + useCat);
if (useCat)
{
var featureCol = trainData.Schema.Feature.Value;
AnnotationUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, featureCol.Index, out categoricalFeatures);
}
var colType = trainData.Schema.Feature.Value.Type;
int rawNumCol = colType.GetVectorSize();
FeatureCount = rawNumCol;
catMetaData.TotalCats = 0;
if (categoricalFeatures == null)
{
catMetaData.CategoricalBoundaries = null;
catMetaData.NumCol = rawNumCol;
}
else
{
var catIndices = ConstructCategoricalFeatureMetaData(categoricalFeatures, rawNumCol, ref catMetaData);
// Set categorical features
GbmOptions["categorical_feature"] = string.Join(",", catIndices);
}
return catMetaData;
}
private Dataset LoadTrainingData(IChannel ch, RoleMappedData trainData, out CategoricalMetaData catMetaData)
{
// Verifications.
Host.AssertValue(ch);
ch.CheckValue(trainData, nameof(trainData));
CheckDataValid(ch, trainData);
// Load metadata first.
var factory = CreateCursorFactory(trainData);
GetMetainfo(ch, factory, out int numRow, out float[] labels, out float[] weights, out int[] groups);
catMetaData = GetCategoricalMetaData(ch, trainData, numRow);
GetDefaultParameters(ch, numRow, catMetaData.CategoricalBoundaries != null, catMetaData.TotalCats);
CheckAndUpdateParametersBeforeTraining(ch, trainData, labels, groups);
string param = LightGbmInterfaceUtils.JoinParameters(GbmOptions);
Dataset dtrain;
// To reduce peak memory usage, only enable one sampling task at any given time.
lock (LightGbmShared.SampleLock)
{
CreateDatasetFromSamplingData(ch, factory, numRow,
param, labels, weights, groups, catMetaData, out dtrain);
}
// Push rows into dataset.
LoadDataset(ch, factory, dtrain, numRow, LightGbmTrainerOptions.BatchSize, catMetaData);
return dtrain;
}
private Dataset LoadValidationData(IChannel ch, Dataset dtrain, RoleMappedData validData, CategoricalMetaData catMetaData)
{
// Verifications.
Host.AssertValue(ch);
ch.CheckValue(validData, nameof(validData));
CheckDataValid(ch, validData);
// Load meta info first.
var factory = CreateCursorFactory(validData);
GetMetainfo(ch, factory, out int numRow, out float[] labels, out float[] weights, out int[] groups);
// Construct validation dataset.
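// The training dataset is passed as the reference dataset, so the validation data
// reuses the feature bin boundaries computed from the training data.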
Dataset dvalid = new Dataset(dtrain, numRow, labels, weights, groups);
// Push rows into dataset.
LoadDataset(ch, factory, dvalid, numRow, LightGbmTrainerOptions.BatchSize, catMetaData);
return dvalid;
}
private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, CategoricalMetaData catMetaData, Dataset dvalid = null)
{
Host.AssertValue(ch);
Host.AssertValue(pch);
Host.AssertValue(dtrain);
Host.AssertValueOrNull(dvalid);
Host.CheckAlive();
// For multiclass classification, the number of classes is required.
ch.Assert(((ITrainer)this).PredictionKind != PredictionKind.MulticlassClassification || GbmOptions.ContainsKey("num_class"),
"LightGBM requires the number of classes to be specified in the parameters.");
// Only allow one trainer to run at a time.
lock (LightGbmShared.LockForMultiThreadingInside)
{
ch.Info("LightGBM objective={0}", GbmOptions["objective"]);
using (Booster bst = WrappedLightGbmTraining.Train(Host, ch, pch, GbmOptions, dtrain,
dvalid: dvalid, numIteration: LightGbmTrainerOptions.NumberOfIterations,
verboseEval: LightGbmTrainerOptions.Verbose, earlyStoppingRound: LightGbmTrainerOptions.EarlyStoppingRound))
{
TrainedEnsemble = Booster.GetModel(catMetaData.CategoricalBoundaries, bst.GetModelString());
}
}
}
/// <summary>
/// Calculate the density of the data, using only the first 1000 rows.
/// </summary>
private static double DetectDensity(FloatLabelCursor.Factory factory, int numRows = 1000)
{
int nonZeroCount = 0;
int totalCount = 0;
using (var cursor = factory.Create())
{
while (cursor.MoveNext() && numRows > 0)
{
nonZeroCount += cursor.Features.GetValues().Length;
totalCount += cursor.Features.Length;
--numRows;
}
}
return (double)nonZeroCount / totalCount;
}
/// <summary>
/// Compute the row count, labels, weights, and group counts of the dataset.
/// </summary>
private void GetMetainfo(IChannel ch, FloatLabelCursor.Factory factory,
out int numRow, out float[] labels, out float[] weights, out int[] groups)
{
ch.Check(factory.Data.Schema.Label != null, "The data should have a label column.");
List<float> labelList = new List<float>();
bool hasWeights = factory.Data.Schema.Weight != null;
bool hasGroup = false;
if (PredictionKind == PredictionKind.Ranking)
{
ch.Check(factory.Data.Schema.Group != null, "The data for a ranking task should have a group column.");
hasGroup = true;
}
List<float> weightList = hasWeights ? new List<float>() : null;
List<ulong> cursorGroups = hasGroup ? new List<ulong>() : null;
using (var cursor = factory.Create())
{
while (cursor.MoveNext())
{
if (labelList.Count == Utils.ArrayMaxSize)
throw ch.Except($"Dataset row count exceeded the maximum count of {Utils.ArrayMaxSize}");
labelList.Add(cursor.Label);
if (hasWeights)
{
// Default weight = 1.
if (float.IsNaN(cursor.Weight))
weightList.Add(1);
else
weightList.Add(cursor.Weight);
}
if (hasGroup)
cursorGroups.Add(cursor.Group);
}
}
labels = labelList.ToArray();
ConvertNaNLabels(ch, factory.Data, labels);
numRow = labels.Length;
ch.Check(numRow > 0, "Cannot use empty dataset.");
weights = hasWeights ? weightList.ToArray() : null;
groups = null;
if (hasGroup)
{
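// LightGBM expects group sizes rather than per-row group ids, so count the lengths
// of consecutive runs of rows sharing the same group id.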
List<int> groupList = new List<int>();
int lastGroup = -1;
for (int i = 0; i < numRow; ++i)
{
if (i == 0 || cursorGroups[i] != cursorGroups[i - 1])
{
groupList.Add(1);
++lastGroup;
}
else
++groupList[lastGroup];
}
groups = groupList.ToArray();
}
}
/// <summary>
/// Convert NaN labels. The default behavior is to convert them to zero.
/// </summary>
private protected virtual void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)
{
for (int i = 0; i < labels.Length; ++i)
{
if (float.IsNaN(labels[i]))
labels[i] = 0;
}
}
private static bool MoveMany(FloatLabelCursor cursor, long count)
{
for (long i = 0; i < count; ++i)
{
if (!cursor.MoveNext())
return false;
}
return true;
}
private void GetFeatureValueDense(IChannel ch, FloatLabelCursor cursor, CategoricalMetaData catMetaData, Random rand, out ReadOnlySpan<float> featureValues)
{
var cursorFeaturesValues = cursor.Features.GetValues();
if (catMetaData.CategoricalBoundaries != null)
{
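// Collapse each one-hot block into a single categorical value, namely the offset of
// the hot slot within its block; if several slots are set, one is chosen uniformly at random.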
float[] featureValuesTemp = new float[catMetaData.NumCol];
for (int i = 0; i < catMetaData.NumCol; ++i)
{
float fv = cursorFeaturesValues[catMetaData.CategoricalBoundaries[i]];
if (catMetaData.IsCategoricalFeature[i])
{
int hotIdx = catMetaData.CategoricalBoundaries[i] - 1;
int nhot = 0;
for (int j = catMetaData.CategoricalBoundaries[i]; j < catMetaData.CategoricalBoundaries[i + 1]; ++j)
{
if (cursorFeaturesValues[j] > 0)
{
// Reservoir Sampling.
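// Each new hot slot replaces the current choice with probability 1/nhot, so the
// selected slot ends up uniformly distributed over all hot slots.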
nhot++;
var prob = rand.NextSingle();
if (prob < 1.0f / nhot)
hotIdx = j;
}
}
// All-Zero is category 0.
fv = hotIdx - catMetaData.CategoricalBoundaries[i];
}
featureValuesTemp[i] = fv;
}
featureValues = featureValuesTemp;
}
else
{
featureValues = cursorFeaturesValues;
}
}
private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
CategoricalMetaData catMetaData, Random rand, out ReadOnlySpan<int> indices,
out ReadOnlySpan<float> featureValues, out int cnt)
{
var cursorFeaturesValues = cursor.Features.GetValues();
var cursorFeaturesIndices = cursor.Features.GetIndices();
if (catMetaData.CategoricalBoundaries != null)
{
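// Sparse analogue of the dense conversion above: categorical columns absent from the
// row keep the default value of -1; multi-hot blocks again pick one slot uniformly.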
Dictionary<int, float> ivPair = new Dictionary<int, float>();
foreach (var idx in catMetaData.CatIndices)
ivPair[idx] = -1;
int lastIdx = -1;
int nhot = 0;
for (int i = 0; i < cursorFeaturesValues.Length; ++i)
{
float fv = cursorFeaturesValues[i];
int colIdx = cursorFeaturesIndices[i];
int newColIdx = catMetaData.OnehotIndices[colIdx];
if (catMetaData.IsCategoricalFeature[newColIdx])
fv = catMetaData.OnehotBias[colIdx];
if (newColIdx != lastIdx)
{
ivPair[newColIdx] = fv;
nhot = 1;
}
else
{
// Multi-hot.
++nhot;
var prob = rand.NextSingle();
if (prob < 1.0f / nhot)
ivPair[newColIdx] = fv;
}
lastIdx = newColIdx;
}
var sortedIVPair = new SortedDictionary<int, float>(ivPair);
indices = sortedIVPair.Keys.ToArray();
featureValues = sortedIVPair.Values.ToArray();
cnt = ivPair.Count;
}
else
{
indices = cursorFeaturesIndices;
featureValues = cursorFeaturesValues;
cnt = cursorFeaturesValues.Length;
}
}
/// <summary>
/// Create a dataset from the sampling data.
/// </summary>
private void CreateDatasetFromSamplingData(IChannel ch, FloatLabelCursor.Factory factory,
int numRow, string param, float[] labels, float[] weights, int[] groups, CategoricalMetaData catMetaData,
out Dataset dataset)
{
Host.AssertValue(ch);
int numSampleRow = GetNumSampleRow(numRow, FeatureCount);
var rand = Host.Rand;
double averageStep = (double)numRow / numSampleRow;
int totalIdx = 0;
int sampleIdx = 0;
double density = DetectDensity(factory);
double[][] sampleValuePerColumn = new double[catMetaData.NumCol][];
int[][] sampleIndicesPerColumn = new int[catMetaData.NumCol][];
int[] nonZeroCntPerColumn = new int[catMetaData.NumCol];
int estimateNonZeroCnt = (int)(numSampleRow * density);
estimateNonZeroCnt = Math.Max(1, estimateNonZeroCnt);
for (int i = 0; i < catMetaData.NumCol; i++)
{
nonZeroCntPerColumn[i] = 0;
sampleValuePerColumn[i] = new double[estimateNonZeroCnt];
sampleIndicesPerColumn[i] = new int[estimateNonZeroCnt];
}
using (var cursor = factory.Create())
{
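// Subsample rows by skipping ahead a random number of rows drawn from [1, 2 * averageStep - 1],
// whose mean is approximately averageStep, so about numSampleRow rows are kept overall.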
int step = 1;
if (averageStep > 1)
step = rand.Next((int)(2 * averageStep - 1)) + 1;
while (MoveMany(cursor, step))
{
if (cursor.Features.IsDense)
{
GetFeatureValueDense(ch, cursor, catMetaData, rand, out ReadOnlySpan<float> featureValues);
for (int i = 0; i < catMetaData.NumCol; ++i)
{
float fv = featureValues[i];
if (fv == 0)
continue;
int curNonZeroCnt = nonZeroCntPerColumn[i];
Utils.EnsureSize(ref sampleValuePerColumn[i], curNonZeroCnt + 1);
Utils.EnsureSize(ref sampleIndicesPerColumn[i], curNonZeroCnt + 1);
// sampleValuePerColumn[i][j] is the j-th non-zero value of the i-th feature encountered
// while scanning the sampled training data example-by-example.
sampleValuePerColumn[i][curNonZeroCnt] = fv;
// sampleIndicesPerColumn[i][j] is the sampled-row index at which that j-th non-zero value
// of the i-th feature occurs, so together the two arrays form a sparse column.
sampleIndicesPerColumn[i][curNonZeroCnt] = sampleIdx;
// The number of non-zero values at the i-th feature is nonZeroCntPerColumn[i].
nonZeroCntPerColumn[i] = curNonZeroCnt + 1;
}
}
else
{
GetFeatureValueSparse(ch, cursor, catMetaData, rand, out ReadOnlySpan<int> featureIndices, out ReadOnlySpan<float> featureValues, out int cnt);
for (int i = 0; i < cnt; ++i)
{
int colIdx = featureIndices[i];
float fv = featureValues[i];
if (fv == 0)
continue;
int curNonZeroCnt = nonZeroCntPerColumn[colIdx];
Utils.EnsureSize(ref sampleValuePerColumn[colIdx], curNonZeroCnt + 1);
Utils.EnsureSize(ref sampleIndicesPerColumn[colIdx], curNonZeroCnt + 1);
sampleValuePerColumn[colIdx][curNonZeroCnt] = fv;
sampleIndicesPerColumn[colIdx][curNonZeroCnt] = sampleIdx;
nonZeroCntPerColumn[colIdx] = curNonZeroCnt + 1;
}
}
// Actual row index sampled from the original data set.
totalIdx += step;
// Row index in the sub-sampled data created in this loop.
++sampleIdx;
if (numSampleRow == sampleIdx || numRow == totalIdx)
break;
averageStep = (double)(numRow - totalIdx) / (numSampleRow - sampleIdx);
step = 1;
if (averageStep > 1)
step = rand.Next((int)(2 * averageStep - 1)) + 1;
}
}
dataset = new Dataset(sampleValuePerColumn, sampleIndicesPerColumn, catMetaData.NumCol, nonZeroCntPerColumn, sampleIdx, numRow, param, labels, weights, groups);
}
/// <summary>
/// Load dataset. Use row batch way to reduce peak memory cost.
/// </summary>
private void LoadDataset(IChannel ch, FloatLabelCursor.Factory factory, Dataset dataset, int numRow, int batchSize, CategoricalMetaData catMetaData)
{
Host.AssertValue(ch);
ch.AssertValue(factory);
ch.AssertValue(dataset);
ch.Assert(dataset.GetNumRows() == numRow);
ch.Assert(dataset.GetNumCols() == catMetaData.NumCol);
var rand = Host.Rand;
// To avoid array resizing, the batch size should be at least the size of one row.
batchSize = Math.Max(batchSize, catMetaData.NumCol);
double density = DetectDensity(factory);
int numElem = 0;
int totalRowCount = 0;
int curRowCount = 0;
if (density >= 0.5)
{
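// Dense path: pack whole rows into a row-major float buffer and push them to the
// native dataset in batches of batchRow rows.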
int batchRow = batchSize / catMetaData.NumCol;
batchRow = Math.Max(1, batchRow);
if (batchRow > numRow)
batchRow = numRow;
// This can only happen if the size of ONE example (row) exceeds the maximum array size, which is very unlikely.
if ((long)catMetaData.NumCol * batchRow > Utils.ArrayMaxSize)
throw ch.Except("Size of array exceeded the " + nameof(Utils.ArrayMaxSize));
float[] features = new float[catMetaData.NumCol * batchRow];
using (var cursor = factory.Create())
{
while (cursor.MoveNext())
{
ch.Assert(totalRowCount < numRow);
CopyToArray(ch, cursor, features, catMetaData, rand, ref numElem);
++totalRowCount;
++curRowCount;
if (batchRow == curRowCount)
{
ch.Assert(numElem == curRowCount * catMetaData.NumCol);
// PushRows is run by multi-threading inside, so lock here.
lock (LightGbmShared.LockForMultiThreadingInside)
dataset.PushRows(features, curRowCount, catMetaData.NumCol, totalRowCount - curRowCount);
curRowCount = 0;
numElem = 0;
}
}
ch.Assert(totalRowCount == numRow);
if (curRowCount > 0)
{
ch.Assert(numElem == curRowCount * catMetaData.NumCol);
// PushRows is run by multi-threading inside, so lock here.
lock (LightGbmShared.LockForMultiThreadingInside)
dataset.PushRows(features, curRowCount, catMetaData.NumCol, totalRowCount - curRowCount);
}
}
}
else
{
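// Sparse path: accumulate rows in CSR form, where indptr marks each row's start offset
// into the parallel indices (column) and features (value) arrays.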
int estimatedBatchRow = (int)(batchSize / (catMetaData.NumCol * density));
estimatedBatchRow = Math.Max(1, estimatedBatchRow);
float[] features = new float[batchSize];
int[] indices = new int[batchSize];
int[] indptr = new int[estimatedBatchRow + 1];
using (var cursor = factory.Create())
{
while (cursor.MoveNext())
{
ch.Assert(totalRowCount < numRow);
// If this row would overflow the buffer, push the accumulated rows to LightGBM first.
if (numElem + cursor.Features.GetValues().Length > features.Length)
{
// The mini-batch size is at least the size of one row,
// so at least one full row has been buffered.
ch.Assert(curRowCount > 0);
Utils.EnsureSize(ref indptr, curRowCount + 1);
indptr[curRowCount] = numElem;
// PushRows is run by multi-threading inside, so lock here.
lock (LightGbmShared.LockForMultiThreadingInside)
{
dataset.PushRows(indptr, indices, features,
curRowCount + 1, numElem, catMetaData.NumCol, totalRowCount - curRowCount);
}
curRowCount = 0;
numElem = 0;
}
Utils.EnsureSize(ref indptr, curRowCount + 1);
indptr[curRowCount] = numElem;
CopyToCsr(ch, cursor, indices, features, catMetaData, rand, ref numElem);
++totalRowCount;
++curRowCount;
}
ch.Assert(totalRowCount == numRow);
if (curRowCount > 0)
{
Utils.EnsureSize(ref indptr, curRowCount + 1);
indptr[curRowCount] = numElem;
// PushRows is run by multi-threading inside, so lock here.
lock (LightGbmShared.LockForMultiThreadingInside)
{
dataset.PushRows(indptr, indices, features, curRowCount + 1,
numElem, catMetaData.NumCol, totalRowCount - curRowCount);
}
}
}
}
}
private void CopyToArray(IChannel ch, FloatLabelCursor cursor, float[] features, CategoricalMetaData catMetaData, Random rand, ref int numElem)
{
ch.Assert(features.Length >= numElem + catMetaData.NumCol);
if (catMetaData.CategoricalBoundaries != null)
{
if (cursor.Features.IsDense)
{
GetFeatureValueDense(ch, cursor, catMetaData, rand, out ReadOnlySpan<float> featureValues);
for (int i = 0; i < catMetaData.NumCol; ++i)
features[numElem + i] = featureValues[i];
numElem += catMetaData.NumCol;
}
else
{
GetFeatureValueSparse(ch, cursor, catMetaData, rand, out ReadOnlySpan<int> indices, out ReadOnlySpan<float> featureValues, out int cnt);
int lastIdx = 0;
for (int i = 0; i < cnt; i++)
{
int slot = indices[i];
float fv = featureValues[i];
Contracts.Assert(slot >= lastIdx);
while (lastIdx < slot)
features[numElem + lastIdx++] = 0.0f;
Contracts.Assert(lastIdx == slot);
features[numElem + lastIdx++] = fv;
}
while (lastIdx < catMetaData.NumCol)
features[numElem + lastIdx++] = 0.0f;
numElem += catMetaData.NumCol;
}
}
else
{
cursor.Features.CopyTo(features, numElem, 0.0f);
numElem += catMetaData.NumCol;
}
}
private void CopyToCsr(IChannel ch, FloatLabelCursor cursor,
int[] indices, float[] features, CategoricalMetaData catMetaData, Random rand, ref int numElem)
{
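// Append this row's non-zero values and their column indices to the CSR buffers;
// the caller maintains the row-offset array (indptr).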
int numValue = cursor.Features.GetValues().Length;
if (numValue > 0)
{
ch.Assert(indices.Length >= numElem + numValue);
ch.Assert(features.Length >= numElem + numValue);
if (cursor.Features.IsDense)
{
GetFeatureValueDense(ch, cursor, catMetaData, rand, out ReadOnlySpan<float> featureValues);
for (int i = 0; i < catMetaData.NumCol; ++i)
{
float fv = featureValues[i];
if (fv == 0)
continue;
features[numElem] = fv;
indices[numElem] = i;
++numElem;
}
}
else
{
GetFeatureValueSparse(ch, cursor, catMetaData, rand, out ReadOnlySpan<int> featureIndices, out ReadOnlySpan<float> featureValues, out int cnt);
for (int i = 0; i < cnt; ++i)
{
int colIdx = featureIndices[i];
float fv = featureValues[i];
if (fv == 0)
continue;
features[numElem] = fv;
indices[numElem] = colIdx;
++numElem;
}
}
}
}
private static double DefaultLearningRate(int numRow, bool useCat, int totalCats)
{
if (useCat)
{
if (totalCats < 1e6)
return 0.1;
else
return 0.15;
}
else if (numRow <= 100000)
return 0.2;
else
return 0.25;
}
private static int DefaultNumLeaves(int numRow, bool useCat, int totalCats)
{
if (useCat && totalCats > 100)
{
if (totalCats < 1e6)
return 20;
else
return 30;
}
else if (numRow <= 100000)
return 20;
else
return 30;
}
private protected static int DefaultMinDataPerLeaf(int numRow, int numberOfLeaves, int numClass)
{
if (numClass > 1)
{
int ret = numRow / numberOfLeaves / numClass / 10;
ret = Math.Max(ret, 5);
ret = Math.Min(ret, 50);
return ret;
}
else
{
return 20;
}
}
private static int GetNumSampleRow(int numRow, int numCol)
{
// Default is 65536.
int ret = 1 << 16;
// If there are many features, use more sampled rows.
if (numCol >= 100000)
ret *= 4;
ret = Math.Min(ret, numRow);
return ret;
}
private protected abstract TModel CreatePredictor();
/// <summary>
/// This function will be called before training. It will check the label/group and add parameters for specific applications.
/// </summary>
private protected abstract void CheckAndUpdateParametersBeforeTraining(IChannel ch,
RoleMappedData data, float[] labels, int[] groups);
}
}