|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.OneDal;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
// Declares FastForestBinaryTrainer (with its Options type) as a loadable class for the listed
// trainer signatures, using the user name, load names and short names given below.
[assembly: LoadableClass(FastForestBinaryTrainer.Summary, typeof(FastForestBinaryTrainer), typeof(FastForestBinaryTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
FastForestBinaryTrainer.UserNameValue,
FastForestBinaryTrainer.LoadNameValue,
"FastForest",
FastForestBinaryTrainer.ShortName,
"ffc")]
// Declares FastForestBinaryModelParameters as the loadable class behind the
// FastForestBinaryModelParameters.LoaderSignature model signature (SignatureLoadModel).
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(FastForestBinaryModelParameters), null, typeof(SignatureLoadModel),
"FastForest Binary Executor",
FastForestBinaryModelParameters.LoaderSignature)]
// Declares the FastForest static class as an entry-point module.
[assembly: LoadableClass(typeof(void), typeof(FastForest), null, typeof(SignatureEntryPointModule), "FastForest")]
namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// Base class for the options of the fast forest trainers.
/// </summary>
public abstract class FastForestOptionsBase : TreeOptions
{
    /// <summary>
    /// How many data points are sampled from each leaf in order to estimate the distribution of labels.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Number of labels to be sampled from each leaf to make the distribution", ShortName = "qsc")]
    public int NumberOfQuantileSamples = 100;

    /// <summary>
    /// Assigns the fast forest values for the bagging size and the two feature-fraction settings,
    /// which are not declared in this class.
    /// </summary>
    internal FastForestOptionsBase()
    {
        BaggingSize = 1;
        FeatureFraction = 0.7;
        FeatureFractionPerSplit = 0.7;
    }
}
/// <summary>
/// Model parameters produced by <see cref="FastForestBinaryTrainer"/>.
/// </summary>
public sealed class FastForestBinaryModelParameters :
    TreeEnsembleModelParametersBasedOnQuantileRegressionTree
{
    internal const string LoaderSignature = "FastForestBinaryExec";
    internal const string RegistrationName = "FastForestClassificationPredictor";

    // History of the written format version:
    //   0x00010001: Initial
    //   0x00010002: InstanceWeights are part of QuantileRegression Tree to support weighted instances
    //   0x00010003: _numFeatures serialized
    //   0x00010004: Ini content out of predictor
    //   0x00010005: Add _defaultValueForMissing
    //   0x00010006: Categorical splits.
    private static VersionInfo GetVersionInfo()
        => new VersionInfo(
            modelSignature: "FFORE BC",
            verWrittenCur: 0x00010006,
            verReadableCur: 0x00010005,
            verWeCanReadBack: 0x00010001,
            loaderSignature: LoaderSignature,
            loaderAssemblyName: typeof(FastForestBinaryModelParameters).Assembly.FullName);

    // Format versions at which the corresponding piece of state is part of the serialized
    // model; they line up with the version history listed above GetVersionInfo.
    private protected override uint VerNumFeaturesSerialized => 0x00010003;

    private protected override uint VerDefaultValueSerialized => 0x00010005;

    private protected override uint VerCategoricalSplitSerialized => 0x00010006;

    /// <summary>
    /// The kind of prediction made by this model: binary classification.
    /// </summary>
    private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

    /// <summary>
    /// Wraps a freshly trained ensemble.
    /// </summary>
    internal FastForestBinaryModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
        : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
    {
    }

    /// <summary>
    /// Loading constructor; reached through <see cref="Create"/>.
    /// </summary>
    private FastForestBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx)
        : base(env, RegistrationName, ctx, GetVersionInfo())
    {
    }

    /// <summary>
    /// Lets the base class save its state, then stamps the context with this class's version info.
    /// </summary>
    private protected override void SaveCore(ModelSaveContext ctx)
    {
        base.SaveCore(ctx);
        ctx.SetVersionInfo(GetVersionInfo());
    }

    /// <summary>
    /// Loads the model from <paramref name="ctx"/>. If a "Calibrator" sub-model can be loaded as well,
    /// the result is the model wrapped together with that calibrator; otherwise it is the bare model.
    /// </summary>
    internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel(GetVersionInfo());

        var model = new FastForestBinaryModelParameters(env, ctx);
        ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out ICalibrator calibrator, @"Calibrator");
        if (calibrator != null)
            return new SchemaBindableCalibratedModelParameters<FastForestBinaryModelParameters, ICalibrator>(env, model, calibrator);

        return model;
    }
}
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree binary classification model using Fast Forest.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [FastForest](xref:Microsoft.ML.TreeExtensions.FastForest(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,System.String,System.String,System.String,System.Int32,System.Int32,System.Int32))
/// or [FastForest(Options)](xref:Microsoft.ML.TreeExtensions.FastForest(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,Microsoft.ML.Trainers.FastTree.FastForestBinaryTrainer.Options)).
///
/// [!include[io](~/../docs/samples/docs/api-reference/io-columns-binary-classification.md)]
///
/// ### Trainer Characteristics
/// | | |
/// | -- | -- |
/// | Machine learning task | Binary classification |
/// | Is normalization required? | No |
/// | Is caching required? | No |
/// | Required NuGet in addition to Microsoft.ML | Microsoft.ML.FastTree |
/// | Exportable to ONNX | Yes |
///
/// [!include[algorithm](~/../docs/samples/docs/api-reference/algo-details-fastforest.md)]
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="TreeExtensions.FastForest(BinaryClassificationCatalog.BinaryClassificationTrainers, string, string, string, int, int, int)"/>
/// <seealso cref="TreeExtensions.FastForest(BinaryClassificationCatalog.BinaryClassificationTrainers, FastForestBinaryTrainer.Options)"/>
/// <seealso cref="Options"/>
public sealed partial class FastForestBinaryTrainer :
RandomForestTrainerBase<FastForestBinaryTrainer.Options, BinaryPredictionTransformer<FastForestBinaryModelParameters>, FastForestBinaryModelParameters>
{
/// <summary>
/// Options for the <see cref="FastForestBinaryTrainer"/> as used in
/// [FastForest(Options)](xref:Microsoft.ML.TreeExtensions.FastForest(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,Microsoft.ML.Trainers.FastTree.FastForestBinaryTrainer.Options)).
/// </summary>
public sealed class Options : FastForestOptionsBase
{
    /// <summary>
    /// Upper bound on the absolute value that a single tree may output.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Upper bound on absolute value of single tree output", ShortName = "mo")]
    public double MaximumOutputMagnitudePerTree = 100;

    // Calibrator applied to the predictor; null means no calibration. Visible to entry points only.
    [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
    internal ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory();

    // Cap on the number of examples used when training the calibrator. Visible to entry points only.
    [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
    internal int MaxCalibrationExamples = 1000000;
}
// Names and description under which this trainer is exposed; they are referenced by the
// assembly-level LoadableClass attribute and by the FastForest entry point in this file.
internal const string LoadNameValue = "FastForestClassification";
internal const string UserNameValue = "Fast Forest Classification";
internal const string Summary = "Uses a random forest learner to perform binary classification.";
internal const string ShortName = "ff";
// Boolean label per training document; filled in by PrepareLabels and consumed by the
// objective function and by the training-data test.
private bool[] _trainSetLabels;
// This trainer produces binary classification predictions.
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
// This trainer always reports that it needs calibration.
private protected override bool NeedCalibration => true;
/// <summary>
/// Creates a <see cref="FastForestBinaryTrainer"/> from individual column names and tree settings.
/// </summary>
/// <param name="env">The <see cref="IHostEnvironment"/> to use.</param>
/// <param name="labelColumnName">Name of the label column; must not be empty.</param>
/// <param name="featureColumnName">Name of the feature column; must not be empty.</param>
/// <param name="exampleWeightColumnName">Name of the column holding the example weight, or null for none.</param>
/// <param name="numberOfLeaves">Maximum number of leaves in each regression tree.</param>
/// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="minimumExampleCountPerLeaf">Minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
internal FastForestBinaryTrainer(
    IHostEnvironment env,
    string labelColumnName = DefaultColumnNames.Label,
    string featureColumnName = DefaultColumnNames.Features,
    string exampleWeightColumnName = null,
    int numberOfLeaves = Defaults.NumberOfLeaves,
    int numberOfTrees = Defaults.NumberOfTrees,
    int minimumExampleCountPerLeaf = Defaults.MinimumExampleCountPerLeaf)
    : base(
        env,
        TrainerUtils.MakeBoolScalarLabel(labelColumnName),
        featureColumnName,
        exampleWeightColumnName,
        null,
        numberOfLeaves,
        numberOfTrees,
        minimumExampleCountPerLeaf)
{
    // These checks run only after the base constructor has already received the names.
    Host.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
    Host.CheckNonEmpty(featureColumnName, nameof(featureColumnName));
}
/// <summary>
/// Initializes a new instance of <see cref="FastForestBinaryTrainer"/> by using the <see cref="Options"/> class.
/// </summary>
/// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="options">Algorithm advanced settings. It is dereferenced in the base-constructor call
/// (to read its label column name), so it must not be null.</param>
internal FastForestBinaryTrainer(IHostEnvironment env, Options options)
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
{
}
/// <summary>
/// Validates the training data, trains the ensemble (through oneDAL when the data has no weight
/// column and oneDAL dispatching is enabled, otherwise through the regular training path) and
/// wraps the trained ensemble in <see cref="FastForestBinaryModelParameters"/>.
/// </summary>
private protected override FastForestBinaryModelParameters TrainModelCore(TrainContext context)
{
    Host.CheckValue(context, nameof(context));

    var trainData = context.TrainingSet;
    ValidData = context.ValidationSet;
    TestData = context.TestSet;

    using (var ch = Host.Start("Training"))
    {
        ch.CheckValue(trainData, nameof(trainData));
        trainData.CheckBinaryLabel();
        trainData.CheckFeatureFloatVector();
        trainData.CheckOptFloatWeight();

        FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount();
        ConvertData(trainData);

        // The oneDAL path is taken only for unweighted data, and only when dispatching is enabled.
        bool useOneDal = !trainData.Schema.Weight.HasValue && MLContext.OneDalDispatchingEnabled;
        if (!useOneDal)
        {
            TrainCore(ch);
        }
        else
        {
            if (FastTreeTrainerOptions.FeatureFraction != 1.0)
            {
                ch.Warning($"oneDAL decision forest doesn't support 'FeatureFraction'[per tree] != 1.0, changing it from {FastTreeTrainerOptions.FeatureFraction} to 1.0");
                FastTreeTrainerOptions.FeatureFraction = 1.0;
            }

            var cursorFactory = new FloatLabelCursor.Factory(trainData, CursOpt.Label | CursOpt.Features);
            TrainCoreOneDal(ch, cursorFactory, FeatureCount);

            if (FeatureMap != null)
                TrainedEnsemble.RemapFeatures(FeatureMap);
        }
    }

    // LogitBoost is naturally calibrated to
    // output probabilities when transformed using
    // the logistic function, so if we have trained no
    // calibrator, transform the scores using that.
    // REVIEW: Need a way to signal the outside world that we prefer simple sigmoid?
    return new FastForestBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions);
}
// P/Invoke declarations for the "OneDalNative" native library.
internal static class OneDal
{
// Library name handed to DllImport.
private const string OneDalLibPath = "OneDalNative";
// Binds to the native "decisionForestClassificationCompute" export.
// featuresPtr/labelsPtr point at the training data (nRows rows, nColumns features).
// lteChildPtr, gtChildPtr, splitFeaturePtr, featureThresholdPtr and leafValuesPtr point at
// caller-allocated buffers that the caller reads back after the call, so they are presumably
// filled by the native side (verify against the native implementation). The same presumably
// holds for modelPtr; the meaning of the int return value is not visible from this declaration.
[DllImport(OneDalLibPath, EntryPoint = "decisionForestClassificationCompute")]
public static extern unsafe int DecisionForestClassificationCompute(
void* featuresPtr, void* labelsPtr, long nRows, int nColumns, int nClasses, int numberOfThreads,
float featureFractionPerSplit, int numberOfTrees, int numberOfLeaves, int minimumExampleCountPerLeaf, int maxBins,
void* lteChildPtr, void* gtChildPtr, void* splitFeaturePtr, void* featureThresholdPtr, void* leafValuesPtr, void* modelPtr);
}
/// <summary>
/// Trains the forest through the native oneDAL decision forest routine and rebuilds
/// <see cref="TrainedEnsemble"/> from the flat per-node and per-leaf arrays the native call fills in.
/// </summary>
/// <param name="ch">Channel used for option checking, initialization and data loading.</param>
/// <param name="cursorFactory">Factory of the cursor from which labels and features are read.</param>
/// <param name="featureCount">Number of feature columns passed to the native trainer.</param>
[BestFriend]
private void TrainCoreOneDal(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
{
    CheckOptions(ch);
    Initialize(ch);

    List<float> featuresList = new List<float>();
    List<float> labelsList = new List<float>();
    int nClasses = 2;
    int numberOfLeaves = FastTreeTrainerOptions.NumberOfLeaves;
    int numberOfTrees = FastTreeTrainerOptions.NumberOfTrees;
    // Each tree is stored with (numberOfLeaves - 1) internal nodes and numberOfLeaves leaves.
    int numberOfNodes = numberOfLeaves - 1;
    // 0 is what gets passed when no thread count was configured.
    int numberOfThreads = FastTreeTrainerOptions.NumberOfThreads ?? 0;

    long n = OneDalUtils.GetTrainData(ch, cursorFactory, ref featuresList, ref labelsList, featureCount);
    float[] featuresArray = featuresList.ToArray();
    float[] labelsArray = labelsList.ToArray();

    // Flat output buffers: all trees back to back, one slot per internal node (or per leaf).
    int[] lteChildArray = new int[numberOfNodes * numberOfTrees];
    int[] gtChildArray = new int[numberOfNodes * numberOfTrees];
    int[] splitFeatureArray = new int[numberOfNodes * numberOfTrees];
    float[] featureThresholdArray = new float[numberOfNodes * numberOfTrees];
    float[] leafValuesArray = new float[numberOfLeaves * numberOfTrees];

    int projectedOneDalModelSize = 96 * nClasses * numberOfLeaves * numberOfTrees + 4096 * 16;
    byte[] oneDalModel = new byte[projectedOneDalModelSize];

    unsafe
    {
#pragma warning disable MSML_SingleVariableDeclaration // Have only a single variable present per declaration
        fixed (void* featuresPtr = &featuresArray[0], labelsPtr = &labelsArray[0],
            lteChildPtr = &lteChildArray[0], gtChildPtr = &gtChildArray[0], splitFeaturePtr = &splitFeatureArray[0],
            featureThresholdPtr = &featureThresholdArray[0], leafValuesPtr = &leafValuesArray[0], oneDalModelPtr = &oneDalModel[0])
#pragma warning restore MSML_SingleVariableDeclaration // Have only a single variable present per declaration
        {
            // The int result was only ever stored in a local that nothing read, so it is not kept.
            OneDal.DecisionForestClassificationCompute(featuresPtr, labelsPtr, n, featureCount, nClasses,
                numberOfThreads, (float)FastTreeTrainerOptions.FeatureFractionPerSplit, numberOfTrees,
                numberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.MaximumBinCountPerFeature,
                lteChildPtr, gtChildPtr, splitFeaturePtr, featureThresholdPtr, leafValuesPtr, oneDalModelPtr);
        }
    }

    TrainedEnsemble = new InternalTreeEnsemble();
    for (int i = 0; i < numberOfTrees; ++i)
    {
        int nodeOffset = numberOfNodes * i;
        int leafOffset = numberOfLeaves * i;

        // Slice this tree's internal-node data out of the flat buffers.
        int[] lteChildArrayPerTree = new int[numberOfNodes];
        int[] gtChildArrayPerTree = new int[numberOfNodes];
        int[] splitFeatureArrayPerTree = new int[numberOfNodes];
        float[] featureThresholdArrayPerTree = new float[numberOfNodes];
        Array.Copy(lteChildArray, nodeOffset, lteChildArrayPerTree, 0, numberOfNodes);
        Array.Copy(gtChildArray, nodeOffset, gtChildArrayPerTree, 0, numberOfNodes);
        Array.Copy(splitFeatureArray, nodeOffset, splitFeatureArrayPerTree, 0, numberOfNodes);
        Array.Copy(featureThresholdArray, nodeOffset, featureThresholdArrayPerTree, 0, numberOfNodes);

        // Leaf values are widened from float to double.
        double[] leafValuesArrayPerTree = new double[numberOfLeaves];
        for (int j = 0; j < numberOfLeaves; ++j)
            leafValuesArrayPerTree[j] = leafValuesArray[leafOffset + j];

        // No categorical splits, gains or missing-value defaults are filled in on this path; freshly
        // allocated arrays already hold the null / false / 0 values that are required.
        int[][] categoricalSplitFeaturesPerTree = new int[numberOfNodes][];
        bool[] categoricalSplitPerTree = new bool[numberOfNodes];
        double[] splitGainPerTree = new double[numberOfNodes];
        float[] defaultValueForMissingPerTree = new float[numberOfNodes];

        InternalQuantileRegressionTree newTree = new InternalQuantileRegressionTree(splitFeatureArrayPerTree, splitGainPerTree, null,
            featureThresholdArrayPerTree, defaultValueForMissingPerTree, lteChildArrayPerTree, gtChildArrayPerTree, leafValuesArrayPerTree,
            categoricalSplitFeaturesPerTree, categoricalSplitPerTree);
        newTree.PopulateThresholds(TrainSet);
        TrainedEnsemble.AddTree(newTree);
    }
}
/// <summary>
/// Builds the objective function from the training set, the prepared boolean labels and the trainer options.
/// </summary>
private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
    => new ObjectiveFunctionImpl(TrainSet, _trainSetLabels, FastTreeTrainerOptions);
/// <summary>
/// Converts the training set ratings into boolean labels: a rating of at least 1 counts as positive.
/// </summary>
private protected override void PrepareLabels(IChannel ch)
{
    // REVIEW: Historically FastTree has this test as >= 1. TLC however
    // generally uses > 0. Consider changing FastTree to be consistent.
    var isPositive = TrainSet.Ratings.Select(rating => rating >= 1);
    _trainSetLabels = isPositive.ToArray(TrainSet.NumDocs);
}
/// <summary>
/// Builds the binary classification test that is evaluated on the training data.
/// </summary>
private protected override Test ConstructTestForTrainingData()
    => new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, 1);
/// <summary>
/// Wraps <paramref name="model"/> in a binary prediction transformer bound to the trainer's feature column.
/// </summary>
private protected override BinaryPredictionTransformer<FastForestBinaryModelParameters> MakeTransformer(FastForestBinaryModelParameters model, DataViewSchema trainSchema)
{
    return new BinaryPredictionTransformer<FastForestBinaryModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
}
/// <summary>
/// Trains a <see cref="FastForestBinaryTrainer"/> using both training and validation data, returns
/// a <see cref="BinaryPredictionTransformer{TModel}"/> over <see cref="FastForestBinaryModelParameters"/>.
/// </summary>
/// <param name="trainData">The data to train on.</param>
/// <param name="validationData">The data passed along as validation data.</param>
public BinaryPredictionTransformer<FastForestBinaryModelParameters> Fit(IDataView trainData, IDataView validationData)
{
    return TrainTransformer(trainData, validationData);
}
/// <summary>
/// Describes the two output columns: a scalar <see cref="float"/> score and a scalar boolean
/// predicted label, each carrying the trainer output annotations.
/// </summary>
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    var scoreColumn = new SchemaShape.Column(
        DefaultColumnNames.Score,
        SchemaShape.Column.VectorKind.Scalar,
        NumberDataViewType.Single,
        false,
        new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()));

    var predictedLabelColumn = new SchemaShape.Column(
        DefaultColumnNames.PredictedLabel,
        SchemaShape.Column.VectorKind.Scalar,
        BooleanDataViewType.Instance,
        false,
        new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()));

    return new[] { scoreColumn, predictedLabelColumn };
}
/// <summary>
/// Objective function whose gradient is simply the signed label: +1 for a positive document, -1 otherwise.
/// </summary>
private sealed class ObjectiveFunctionImpl : RandomForestObjectiveFunction
{
    // One boolean label per training document, indexed like Gradient.
    private readonly bool[] _labels;

    public ObjectiveFunctionImpl(Dataset trainSet, bool[] trainSetLabels, Options options)
        : base(trainSet, options, options.MaximumOutputMagnitudePerTree)
    {
        _labels = trainSetLabels;
    }

    /// <summary>
    /// Writes the gradient for every document between the boundaries of <paramref name="query"/>.
    /// </summary>
    protected override void GetGradientInOneQuery(int query, int threadIndex)
    {
        int doc = Dataset.Boundaries[query];
        int limit = Dataset.Boundaries[query + 1];
        while (doc < limit)
        {
            Gradient[doc] = _labels[doc] ? 1 : -1;
            ++doc;
        }
    }
}
}
/// <summary>
/// Entry points for the fast forest trainers.
/// </summary>
internal static partial class FastForest
{
    /// <summary>
    /// Entry point that trains a <see cref="FastForestBinaryTrainer"/> from <paramref name="input"/>,
    /// passing along the calibrator settings carried by the options.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Trainers.FastForestBinaryClassifier",
        Desc = FastForestBinaryTrainer.Summary,
        UserName = FastForestBinaryTrainer.UserNameValue,
        ShortName = FastForestBinaryTrainer.ShortName)]
    public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastForestBinaryTrainer.Options input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("TrainFastForest");
        host.CheckValue(input, nameof(input));
        EntryPointUtils.CheckInputArgs(host, input);

        var output = TrainerEntryPointsUtils.Train<FastForestBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(
            host,
            input,
            () => new FastForestBinaryTrainer(host, input),
            () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
            () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName),
            () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.RowGroupColumnName),
            calibrator: input.Calibrator,
            maxCalibrationExamples: input.MaxCalibrationExamples);
        return output;
    }
}
}
|