// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.OneDal;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
[assembly: LoadableClass(FastForestRegressionTrainer.Summary, typeof(FastForestRegressionTrainer), typeof(FastForestRegressionTrainer.Options),
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
FastForestRegressionTrainer.UserNameValue,
FastForestRegressionTrainer.LoadNameValue,
FastForestRegressionTrainer.ShortName)]
[assembly: LoadableClass(typeof(FastForestRegressionModelParameters), null, typeof(SignatureLoadModel),
"FastForest Regression Executor",
FastForestRegressionModelParameters.LoaderSignature)]
namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// Model parameters for <see cref="FastForestRegressionTrainer"/>.
/// </summary>
public sealed class FastForestRegressionModelParameters :
    TreeEnsembleModelParametersBasedOnQuantileRegressionTree,
    IQuantileValueMapper,
    IQuantileRegressionPredictor,
    ISingleCanSaveOnnx
{
    // Computes empirical quantiles over a (possibly weighted) sample of values.
    // Used by the IQuantileValueMapper implementation below to turn the per-leaf
    // sample distribution returned by the ensemble into quantile predictions.
    private sealed class QuantileStatistics
    {
        // Sample values, kept sorted ascending (sorted in the constructor if needed).
        private readonly float[] _data;
        // Optional per-sample weights, parallel to _data; null means unweighted.
        private readonly float[] _weights;

        //This holds the cumulative sum of _weights to search the rank easily by binary search.
        private float[] _weightedSums;
        // Lazily populated running statistics over _data; built on first access
        // of the SummaryStatistics property.
        private SummaryStatistics _summaryStatistics;

        /// <summary>
        /// data array will be modified because of sorting if it is not already sorted yet and this class owns the data.
        /// Modifying the data outside will lead to erroneous output by this class
        /// </summary>
        public QuantileStatistics(float[] data, float[] weights = null, bool isSorted = false)
        {
            Contracts.CheckValue(data, nameof(data));
            Contracts.Check(weights == null || weights.Length == data.Length, "weights");

            _data = data;
            _weights = weights;

            // The quantile math below requires sorted data; either sort here or
            // trust (and assert in debug) the caller's claim that it is sorted.
            if (!isSorted)
                Array.Sort(_data);
            else
                Contracts.Assert(Utils.IsMonotonicallyIncreasing(_data));
        }

        /// <summary>
        /// There are many ways to estimate quantile. This implementations is based on R-8, SciPy-(1/3,1/3)
        /// https://en.wikipedia.org/wiki/Quantile#Estimating_the_quantiles_of_a_population
        /// </summary>
        public float GetQuantile(float p)
        {
            Contracts.CheckParam(0 <= p && p <= 1, nameof(p), "Probability argument for Quantile function should be between 0 to 1 inclusive");

            // Degenerate cases: no data, a single point, or the extreme probabilities
            // are answered directly without rank interpolation.
            if (_data.Length == 0)
                return float.NaN;

            if (p == 0 || _data.Length == 1)
                return _data[0];
            if (p == 1)
                return _data[_data.Length - 1];

            float h = GetRank(p);

            // Clamp ranks that fall outside the sample to the boundary values.
            if (h <= 1)
                return _data[0];

            if (h >= _data.Length)
                return _data[_data.Length - 1];

            // Linear interpolation between the two samples bracketing rank h
            // (h uses 1-based rank convention, hence the -1 indexing).
            var hf = (int)h;
            return (float)(_data[hf - 1] + (h - hf) * (_data[hf] - _data[hf - 1]));
        }

        // Returns the (possibly fractional) 1-based rank corresponding to probability p.
        private float GetRank(float p)
        {
            const float oneThird = (float)1 / 3;

            // holds length of the _data array if the weights is null or holds the sum of weights
            float weightedLength = _data.Length;

            if (_weights != null)
            {
                // Build the cumulative weight sums once, on first use, so rank
                // lookup can be a binary search over a sorted array.
                if (_weightedSums == null)
                {
                    _weightedSums = new float[_weights.Length];
                    _weightedSums[0] = _weights[0];
                    for (int i = 1; i < _weights.Length; i++)
                        _weightedSums[i] = _weights[i] + _weightedSums[i - 1];
                }

                weightedLength = _weightedSums[_weightedSums.Length - 1];
            }

            // This implementations is based on R-8, SciPy-(1/3,1/3)
            // https://en.wikipedia.org/wiki/Quantile#Estimating_the_quantiles_of_a_population
            var h = (_weights == null) ? (weightedLength + oneThird) * p + oneThird : weightedLength * p;

            if (_weights == null)
                return h;

            // Weighted case: the rank is the position of the target cumulative
            // weight h within the cumulative sums.
            return _weightedSums.FindIndexSorted(h);
        }

        // Lazily computed summary statistics (weighted when _weights is present).
        private SummaryStatistics SummaryStatistics
        {
            get
            {
                if (_summaryStatistics == null)
                {
                    _summaryStatistics = new SummaryStatistics();
                    if (_weights != null)
                    {
                        for (int i = 0; i < _data.Length; i++)
                            _summaryStatistics.Add(_data[i], _weights[i]);
                    }
                    else
                    {
                        for (int i = 0; i < _data.Length; i++)
                            _summaryStatistics.Add(_data[i]);
                    }
                }

                return _summaryStatistics;
            }
        }
    }

    // Number of samples drawn per tree leaf when building the quantile distribution.
    private readonly int _quantileSampleCount;

    internal const string LoaderSignature = "FastForestRegressionExec";
    internal const string RegistrationName = "FastForestRegressionPredictor";

    private static VersionInfo GetVersionInfo()
    {
        return new VersionInfo(
            modelSignature: "FFORE RE",
            // verWrittenCur: 0x00010001, Initial
            // verWrittenCur: 0x00010002, // InstanceWeights are part of QuantileRegression Tree to support weighted instances
            // verWrittenCur: 0x00010003, // _numFeatures serialized
            // verWrittenCur: 0x00010004, // Ini content out of predictor
            // verWrittenCur: 0x00010005, // Add _defaultValueForMissing
            verWrittenCur: 0x00010006, // Categorical splits.
            verReadableCur: 0x00010005,
            verWeCanReadBack: 0x00010001,
            loaderSignature: LoaderSignature,
            loaderAssemblyName: typeof(FastForestRegressionModelParameters).Assembly.FullName);
    }

    // Minimum model versions at which each optional field appears in the binary format.
    private protected override uint VerNumFeaturesSerialized => 0x00010003;

    private protected override uint VerDefaultValueSerialized => 0x00010005;

    private protected override uint VerCategoricalSplitSerialized => 0x00010006;

    /// <summary>
    /// Wraps a freshly trained ensemble as model parameters.
    /// </summary>
    /// <param name="env">The host environment.</param>
    /// <param name="trainedEnsemble">The trained tree ensemble.</param>
    /// <param name="featureCount">Number of features the model was trained on.</param>
    /// <param name="innerArgs">Serialized trainer options string stored with the model.</param>
    /// <param name="samplesCount">Quantile sample count used at prediction time.</param>
    internal FastForestRegressionModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
        : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
    {
        _quantileSampleCount = samplesCount;
    }

    // Deserialization constructor; must read fields in the exact order SaveCore writes them.
    private FastForestRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
        : base(env, RegistrationName, ctx, GetVersionInfo())
    {
        // *** Binary format ***
        // bool: should be always true
        // int: Quantile sample count

        Contracts.Check(ctx.Reader.ReadBoolByte());
        _quantileSampleCount = ctx.Reader.ReadInt32();
    }

    private protected override void SaveCore(ModelSaveContext ctx)
    {
        base.SaveCore(ctx);
        ctx.SetVersionInfo(GetVersionInfo());

        // *** Binary format ***
        // bool: always true
        // int: Quantile sample count

        // Previously we store quantileEnabled parameter here,
        // but this paramater always should be true for regression.
        // If you update model version feel free to delete it.
        ctx.Writer.WriteBoolByte(true);
        ctx.Writer.Write(_quantileSampleCount);
    }

    // Factory invoked via the SignatureLoadModel loadable-class registration.
    internal static FastForestRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel(GetVersionInfo());
        return new FastForestRegressionModelParameters(env, ctx);
    }

    private protected override PredictionKind PredictionKind => PredictionKind.Regression;

    bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
    {
        // Div is available from opset 7, but the base tree-ensemble export requires 9.
        const int minimumOpSetVersion = 9;
        ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

        // Mapping score to prediction: export the raw ensemble sum, then divide
        // by the tree count to mirror the averaging done in Map below.
        var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
        var numTrees = ctx.AddInitializer((float)TrainedEnsemble.NumTrees, "NumTrees");
        base.SaveAsOnnx(ctx, new[] { fastTreeOutput }, featureColumn);
        var opType = "Div";
        ctx.CreateNode(opType, new[] { fastTreeOutput, numTrees }, outputNames, ctx.GetNodeName(opType), "");
        return true;
    }

    private protected override void Map(in VBuffer<float> src, ref float dst)
    {
        // Validate the input vector length against the declared input type when
        // known; otherwise just ensure every referenced split feature is in range.
        int inputVectorSize = InputType.GetVectorSize();
        if (inputVectorSize > 0)
            Host.Check(src.Length == inputVectorSize);
        else
            Host.Check(src.Length > MaxSplitFeatIdx);

        // Random forest prediction is the average of the per-tree outputs.
        dst = (float)TrainedEnsemble.GetOutput(in src) / TrainedEnsemble.NumTrees;
    }

    ValueMapper<VBuffer<float>, VBuffer<float>> IQuantileValueMapper.GetMapper(float[] quantiles)
    {
        return
            (in VBuffer<float> src, ref VBuffer<float> dst) =>
            {
                // REVIEW: Should make this more efficient - it repeatedly allocates too much stuff.
                float[] weights = null;
                var distribution = TrainedEnsemble.GetDistribution(in src, _quantileSampleCount, out weights);
                QuantileStatistics qdist = new QuantileStatistics(distribution, weights);

                var editor = VBufferEditor.Create(ref dst, quantiles.Length);
                for (int i = 0; i < quantiles.Length; i++)
                    editor.Values[i] = qdist.GetQuantile((float)quantiles[i]);
                dst = editor.Commit();
            };
    }

    ISchemaBindableMapper IQuantileRegressionPredictor.CreateMapper(Double[] quantiles)
    {
        Host.CheckNonEmpty(quantiles, nameof(quantiles));
        return new SchemaBindableQuantileRegressionPredictor(this, quantiles);
    }
}
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree regression model using Fast Forest.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [FastForest](xref:Microsoft.ML.TreeExtensions.FastForest(Microsoft.ML.RegressionCatalog.RegressionTrainers,System.String,System.String,System.String,System.Int32,System.Int32,System.Int32))
/// or [FastForest(Options)](xref:Microsoft.ML.TreeExtensions.FastForest(Microsoft.ML.RegressionCatalog.RegressionTrainers,Microsoft.ML.Trainers.FastTree.FastForestRegressionTrainer.Options)).
///
/// [!include[io](~/../docs/samples/docs/api-reference/io-columns-regression.md)]
///
/// ### Trainer Characteristics
/// | | |
/// | -- | -- |
/// | Machine learning task | Regression |
/// | Is normalization required? | No |
/// | Is caching required? | No |
/// | Required NuGet in addition to Microsoft.ML | Microsoft.ML.FastTree |
/// | Exportable to ONNX | Yes |
///
/// [!include[algorithm](~/../docs/samples/docs/api-reference/algo-details-fastforest.md)]
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="TreeExtensions.FastForest(RegressionCatalog.RegressionTrainers, string, string, string, int, int, int)"/>
/// <seealso cref="TreeExtensions.FastForest(RegressionCatalog.RegressionTrainers, FastForestRegressionTrainer.Options)"/>
/// <seealso cref="Options"/>
public sealed partial class FastForestRegressionTrainer
    : RandomForestTrainerBase<FastForestRegressionTrainer.Options, RegressionPredictionTransformer<FastForestRegressionModelParameters>, FastForestRegressionModelParameters>
{
    /// <summary>
    /// Options for the <see cref="FastForestRegressionTrainer"/> as used in
    /// [FastForest(Options)](xref:Microsoft.ML.TreeExtensions.FastForest(Microsoft.ML.RegressionCatalog.RegressionTrainers,Microsoft.ML.Trainers.FastTree.FastForestRegressionTrainer.Options)).
    /// </summary>
    public sealed class Options : FastForestOptionsBase
    {
        /// <summary>
        /// Whether to shuffle the labels on every iteration.
        /// </summary>
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Shuffle the labels on every iteration. " +
            "Useful probably only if using this tree as a tree leaf featurizer for multiclass.")]
        public bool ShuffleLabels;
    }

    private protected override PredictionKind PredictionKind => PredictionKind.Regression;

    internal const string Summary = "Trains a random forest to fit target values using least-squares.";
    internal const string LoadNameValue = "FastForestRegression";
    internal const string UserNameValue = "Fast Forest Regression";
    internal const string ShortName = "ffr";

    /// <summary>
    /// Initializes a new instance of <see cref="FastForestRegressionTrainer"/>
    /// </summary>
    /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
    /// <param name="labelColumnName">The name of the label column.</param>
    /// <param name="featureColumnName">The name of the feature column.</param>
    /// <param name="exampleWeightColumnName">The optional name for the column containing the example weight.</param>
    /// <param name="numberOfLeaves">The max number of leaves in each regression tree.</param>
    /// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
    /// <param name="minimumExampleCountPerLeaf">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
    internal FastForestRegressionTrainer(IHostEnvironment env,
        string labelColumnName = DefaultColumnNames.Label,
        string featureColumnName = DefaultColumnNames.Features,
        string exampleWeightColumnName = null,
        int numberOfLeaves = Defaults.NumberOfLeaves,
        int numberOfTrees = Defaults.NumberOfTrees,
        int minimumExampleCountPerLeaf = Defaults.MinimumExampleCountPerLeaf)
        : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumnName), featureColumnName, exampleWeightColumnName, null, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf)
    {
        Host.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
        Host.CheckNonEmpty(featureColumnName, nameof(featureColumnName));
    }

    /// <summary>
    /// Initializes a new instance of <see cref="FastForestRegressionTrainer"/> by using the <see cref="Options"/> class.
    /// </summary>
    /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
    /// <param name="options">Algorithm advanced settings.</param>
    internal FastForestRegressionTrainer(IHostEnvironment env, Options options)
        : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName), true)
    {
    }

    private protected override FastForestRegressionModelParameters TrainModelCore(TrainContext context)
    {
        Host.CheckValue(context, nameof(context));
        var trainData = context.TrainingSet;
        ValidData = context.ValidationSet;
        TestData = context.TestSet;

        using (var ch = Host.Start("Training"))
        {
            ch.CheckValue(trainData, nameof(trainData));
            trainData.CheckRegressionLabel();
            trainData.CheckFeatureFloatVector();
            trainData.CheckOptFloatWeight();
            FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount();
            ConvertData(trainData);

            // Dispatch to the oneDAL implementation only when enabled and the data
            // has no example weights (oneDAL path does not consume weights).
            if (!trainData.Schema.Weight.HasValue && MLContext.OneDalDispatchingEnabled)
            {
                // oneDAL only supports per-split feature sampling; silently training
                // with a different per-tree fraction would change semantics, so warn
                // and reset it.
                if (FastTreeTrainerOptions.FeatureFraction != 1.0)
                {
                    ch.Warning($"oneDAL decision forest doesn't support 'FeatureFraction'[per tree] != 1.0, changing it from {FastTreeTrainerOptions.FeatureFraction} to 1.0");
                    FastTreeTrainerOptions.FeatureFraction = 1.0;
                }
                CursOpt cursorOpt = CursOpt.Label | CursOpt.Features;
                var cursorFactory = new FloatLabelCursor.Factory(trainData, cursorOpt);
                TrainCoreOneDal(ch, cursorFactory, FeatureCount);
                if (FeatureMap != null)
                    TrainedEnsemble.RemapFeatures(FeatureMap);
            }
            else
            {
                TrainCore(ch);
            }
        }
        return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions, FastTreeTrainerOptions.NumberOfQuantileSamples);
    }

    // P/Invoke surface for the native oneDAL decision forest implementation.
    internal static class OneDal
    {
        private const string OneDalLibPath = "OneDalNative";

        // Trains a decision forest natively; fills the caller-allocated tree-structure
        // buffers and returns the size in bytes of the serialized oneDAL model written
        // into modelPtr.
        [DllImport(OneDalLibPath, EntryPoint = "decisionForestRegressionCompute")]
        public static extern unsafe int DecisionForestRegressionCompute(
            void* featuresPtr, void* labelsPtr, long nRows, int nColumns, int numberOfThreads,
            float featureFractionPerSplit, int numberOfTrees, int numberOfLeaves, int minimumExampleCountPerLeaf, int maxBins,
            void* lteChildPtr, void* gtChildPtr, void* splitFeaturePtr, void* featureThresholdPtr, void* leafValuesPtr, void* modelPtr);
    }

    /// <summary>
    /// Trains the forest through oneDAL and converts the flat native output buffers
    /// into an <see cref="InternalTreeEnsemble"/> stored in <c>TrainedEnsemble</c>.
    /// </summary>
    /// <param name="ch">Channel for logging.</param>
    /// <param name="cursorFactory">Factory producing label/feature cursors over the training data.</param>
    /// <param name="featureCount">Number of feature columns.</param>
    [BestFriend]
    private void TrainCoreOneDal(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
    {
        CheckOptions(ch);
        Initialize(ch);

        List<float> featuresList = new List<float>();
        List<float> labelsList = new List<float>();
        int numberOfLeaves = FastTreeTrainerOptions.NumberOfLeaves;
        int numberOfTrees = FastTreeTrainerOptions.NumberOfTrees;
        int numberOfThreads = 0;
        if (FastTreeTrainerOptions.NumberOfThreads.HasValue)
            numberOfThreads = FastTreeTrainerOptions.NumberOfThreads.Value;

        long n = OneDalUtils.GetTrainData(ch, cursorFactory, ref featuresList, ref labelsList, featureCount);

        float[] featuresArray = featuresList.ToArray();
        float[] labelsArray = labelsList.ToArray();

        // Flat output buffers: each tree contributes (numberOfLeaves - 1) internal
        // nodes and numberOfLeaves leaf values, laid out tree after tree.
        int[] lteChildArray = new int[(numberOfLeaves - 1) * numberOfTrees];
        int[] gtChildArray = new int[(numberOfLeaves - 1) * numberOfTrees];
        int[] splitFeatureArray = new int[(numberOfLeaves - 1) * numberOfTrees];
        float[] featureThresholdArray = new float[(numberOfLeaves - 1) * numberOfTrees];
        float[] leafValuesArray = new float[numberOfLeaves * numberOfTrees];

        int oneDalModelSize = -1;
        // Heuristic upper bound for the serialized oneDAL model buffer.
        int projectedOneDalModelSize = 96 * 1 * numberOfLeaves * numberOfTrees + 4096 * 16;
        byte[] oneDalModel = new byte[projectedOneDalModelSize];

        unsafe
        {
#pragma warning disable MSML_SingleVariableDeclaration // Have only a single variable present per declaration
            // BUGFIX: the address-of operators below had been mangled into stray
            // '<'/'>' tokens (HTML-entity corruption of "&lt...", "&gt..."), which
            // does not compile; restored '&lteChildArray[0]' and '&gtChildArray[0]'.
            fixed (void* featuresPtr = &featuresArray[0], labelsPtr = &labelsArray[0],
                lteChildPtr = &lteChildArray[0], gtChildPtr = &gtChildArray[0], splitFeaturePtr = &splitFeatureArray[0],
                featureThresholdPtr = &featureThresholdArray[0], leafValuesPtr = &leafValuesArray[0], oneDalModelPtr = &oneDalModel[0])
#pragma warning restore MSML_SingleVariableDeclaration // Have only a single variable present per declaration
            {
                oneDalModelSize = OneDal.DecisionForestRegressionCompute(featuresPtr, labelsPtr, n, featureCount,
                    numberOfThreads, (float)FastTreeTrainerOptions.FeatureFractionPerSplit, numberOfTrees,
                    numberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.MaximumBinCountPerFeature,
                    lteChildPtr, gtChildPtr, splitFeaturePtr, featureThresholdPtr, leafValuesPtr, oneDalModelPtr
                );
            }
        }

        // Slice the flat native buffers into per-tree arrays and rebuild the ensemble.
        TrainedEnsemble = new InternalTreeEnsemble();
        for (int i = 0; i < numberOfTrees; ++i)
        {
            int[] lteChildArrayPerTree = new int[numberOfLeaves - 1];
            int[] gtChildArrayPerTree = new int[numberOfLeaves - 1];
            int[] splitFeatureArrayPerTree = new int[numberOfLeaves - 1];
            float[] featureThresholdArrayPerTree = new float[numberOfLeaves - 1];
            double[] leafValuesArrayPerTree = new double[numberOfLeaves];

            // oneDAL produces no categorical splits, gains, or missing-value
            // defaults, so those per-node fields are filled with neutral values.
            int[][] categoricalSplitFeaturesPerTree = new int[numberOfLeaves - 1][];
            bool[] categoricalSplitPerTree = new bool[numberOfLeaves - 1];
            double[] splitGainPerTree = new double[numberOfLeaves - 1];
            float[] defaultValueForMissingPerTree = new float[numberOfLeaves - 1];
            for (int j = 0; j < numberOfLeaves - 1; ++j)
            {
                lteChildArrayPerTree[j] = lteChildArray[(numberOfLeaves - 1) * i + j];
                gtChildArrayPerTree[j] = gtChildArray[(numberOfLeaves - 1) * i + j];
                splitFeatureArrayPerTree[j] = splitFeatureArray[(numberOfLeaves - 1) * i + j];
                featureThresholdArrayPerTree[j] = featureThresholdArray[(numberOfLeaves - 1) * i + j];
                leafValuesArrayPerTree[j] = leafValuesArray[numberOfLeaves * i + j];

                categoricalSplitFeaturesPerTree[j] = null;
                categoricalSplitPerTree[j] = false;
                splitGainPerTree[j] = 0.0;
                defaultValueForMissingPerTree[j] = 0.0f;
            }
            // The leaf-values slice has one more entry than the node slices;
            // copy the last leaf explicitly.
            leafValuesArrayPerTree[numberOfLeaves - 1] = leafValuesArray[numberOfLeaves * i + numberOfLeaves - 1];

            InternalQuantileRegressionTree newTree = new InternalQuantileRegressionTree(splitFeatureArrayPerTree, splitGainPerTree, null,
                featureThresholdArrayPerTree, defaultValueForMissingPerTree, lteChildArrayPerTree, gtChildArrayPerTree, leafValuesArrayPerTree,
                categoricalSplitFeaturesPerTree, categoricalSplitPerTree);
            // Convert raw feature thresholds into the bin-based thresholds the
            // ensemble expects for this training set.
            newTree.PopulateThresholds(TrainSet);
            TrainedEnsemble.AddTree(newTree);
        }
    }

    // Random forest regression fits raw labels directly; no label transformation needed.
    private protected override void PrepareLabels(IChannel ch)
    {
    }

    private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
    {
        return ObjectiveFunctionImplBase.Create(TrainSet, FastTreeTrainerOptions);
    }

    private protected override Test ConstructTestForTrainingData()
    {
        return new RegressionTest(ConstructScoreTracker(TrainSet));
    }

    private protected override RegressionPredictionTransformer<FastForestRegressionModelParameters> MakeTransformer(FastForestRegressionModelParameters model, DataViewSchema trainSchema)
        => new RegressionPredictionTransformer<FastForestRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name);

    /// <summary>
    /// Trains a <see cref="FastForestRegressionTrainer"/> using both training and validation data, returns
    /// a <see cref="RegressionPredictionTransformer{FastForestRegressionModelParameters}"/>.
    /// </summary>
    public RegressionPredictionTransformer<FastForestRegressionModelParameters> Fit(IDataView trainData, IDataView validationData)
        => TrainTransformer(trainData, validationData);

    private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
    {
        return new[]
        {
            new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()))
        };
    }

    // Objective function for random forest regression: gradients are the labels
    // themselves (optionally shuffled per tree).
    private abstract class ObjectiveFunctionImplBase : RandomForestObjectiveFunction
    {
        private readonly float[] _labels;

        public static ObjectiveFunctionImplBase Create(Dataset trainData, Options options)
        {
            if (options.ShuffleLabels)
                return new ShuffleImpl(trainData, options);
            return new BasicImpl(trainData, options);
        }

        private ObjectiveFunctionImplBase(Dataset trainData, Options options)
            : base(trainData, options, double.MaxValue) // No notion of maximum step size.
        {
            _labels = FastTreeRegressionTrainer.GetDatasetRegressionLabels(trainData);
            Contracts.Assert(_labels.Length == trainData.NumDocs);
        }

        protected override void GetGradientInOneQuery(int query, int threadIndex)
        {
            int begin = Dataset.Boundaries[query];
            int end = Dataset.Boundaries[query + 1];
            // The "gradient" each tree fits is simply the target value.
            for (int i = begin; i < end; ++i)
                Gradient[i] = _labels[i];
        }

        private sealed class ShuffleImpl : ObjectiveFunctionImplBase
        {
            private readonly Random _rgen;
            private readonly int _labelLim;

            public ShuffleImpl(Dataset trainData, Options options)
                : base(trainData, options)
            {
                Contracts.AssertValue(options);
                Contracts.Assert(options.ShuffleLabels);

                _rgen = new Random(0); // Ideally we'd get this from the host.

                // Shuffling permutes label *values*, so every label must be a
                // non-negative integer index into the permutation array.
                for (int i = 0; i < _labels.Length; ++i)
                {
                    var lab = _labels[i];
                    if (!(0 <= lab && lab < Utils.ArrayMaxSize))
                    {
                        throw Contracts.ExceptUserArg(nameof(options.ShuffleLabels),
                            "Label {0} for example {1} outside of allowed range" +
                            "[0,{2}) when doing shuffled labels", lab, i, Utils.ArrayMaxSize);
                    }
                    int lim = (int)lab + 1;
                    Contracts.Assert(1 <= lim && lim <= Utils.ArrayMaxSize);
                    if (lim > _labelLim)
                        _labelLim = lim;
                }
            }

            public override double[] GetGradient(IChannel ch, double[] scores)
            {
                // Each time we get the gradient in random forest regression, it means
                // we are building a new tree. Shuffle the targets!!
                int[] map = Utils.GetRandomPermutation(_rgen, _labelLim);
                for (int i = 0; i < _labels.Length; ++i)
                    _labels[i] = map[(int)_labels[i]];
                return base.GetGradient(ch, scores);
            }
        }

        private sealed class BasicImpl : ObjectiveFunctionImplBase
        {
            public BasicImpl(Dataset trainData, Options options)
                : base(trainData, options)
            {
            }
        }
    }
}
internal static partial class FastForest
{
    // Entry point exposing FastForest regression training to the command-line /
    // GraphRunner infrastructure.
    [TlcModule.EntryPoint(Name = "Trainers.FastForestRegressor",
        Desc = FastForestRegressionTrainer.Summary,
        UserName = FastForestRegressionTrainer.LoadNameValue,
        ShortName = FastForestRegressionTrainer.ShortName)]
    public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, FastForestRegressionTrainer.Options input)
    {
        // Validate the environment and options before handing off to the shared
        // trainer entry-point plumbing.
        Contracts.CheckValue(env, nameof(env));
        var trainerHost = env.Register("TrainFastForest");
        trainerHost.CheckValue(input, nameof(input));
        EntryPointUtils.CheckInputArgs(trainerHost, input);

        // The column-name resolvers are deferred so they run against the training
        // data schema at train time.
        return TrainerEntryPointsUtils.Train<FastForestRegressionTrainer.Options, CommonOutputs.RegressionOutput>(
            trainerHost,
            input,
            () => new FastForestRegressionTrainer(trainerHost, input),
            () => TrainerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumnName),
            () => TrainerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.ExampleWeightColumnName),
            () => TrainerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.RowGroupColumnName));
    }
}
}