// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
[assembly: LoadableClass(typeof(void), typeof(Gam), null, typeof(SignatureEntryPointModule), "GAM")]
namespace Microsoft.ML.Trainers.FastTree
{
using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo;
/// <summary>
/// Base class for GAM trainers.
/// </summary>
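/// <remarks>
/// A GAM predicts an intercept plus one learned shape function per feature, with no
/// interactions between features; training boosts one depth-1 stump per feature per iteration.
/// </remarks>
/// <example>
/// A minimal usage sketch, assuming the FastTree package's MLContext extensions
/// (the training-data variable is hypothetical):
/// <code>
/// var mlContext = new MLContext(seed: 0);
/// var trainer = mlContext.BinaryClassification.Trainers.Gam(
///     labelColumnName: "Label", featureColumnName: "Features",
///     numberOfIterations: 9500, learningRate: 0.002);
/// var model = trainer.Fit(trainingData);
/// </code>
/// </example>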
public abstract partial class GamTrainerBase<TOptions, TTransformer, TPredictor> : TrainerEstimatorBase<TTransformer, TPredictor>
where TTransformer : ISingleFeaturePredictionTransformer<TPredictor>
where TOptions : GamTrainerBase<TOptions, TTransformer, TPredictor>.OptionsBase, new()
where TPredictor : class
{
/// <summary>
/// Base class for GAM-based trainer options.
/// </summary>
public abstract class OptionsBase : TrainerInputBaseWithWeight
{
/// <summary>
/// The entropy (regularization) coefficient between 0 and 1.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "The entropy (regularization) coefficient between 0 and 1", ShortName = "e")]
public double EntropyCoefficient;
/// <summary>
/// Tree fitting gain confidence requirement. Only consider a gain if its likelihood versus a random choice gain is above this value.
/// </summary>
/// <value>
/// A value of 0.95 means restricting to gains that have less than a 0.05 chance of arising from a randomly chosen split.
/// Valid range is [0,1).
/// </value>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Tree fitting gain confidence requirement (should be in the range [0,1) ).", ShortName = "gainconf")]
public double GainConfidenceLevel;
/// <summary>
/// Total number of passes over the training data.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Total number of iterations over all features", ShortName = "iter", SortOrder = 1)]
[TGUI(SuggestedSweeps = "200,1500,9500")]
[TlcModule.SweepableDiscreteParamAttribute("NumIterations", new object[] { 200, 1500, 9500 })]
public int NumberOfIterations = GamDefaults.NumberOfIterations;
/// <summary>
/// The number of threads to use.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "The number of threads to use", ShortName = "t", NullName = "<Auto>")]
public int? NumberOfThreads = null;
/// <summary>
/// The learning rate.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "The learning rate", ShortName = "lr", SortOrder = 4)]
[TGUI(SuggestedSweeps = "0.001,0.1;log")]
[TlcModule.SweepableFloatParamAttribute("LearningRates", 0.001f, 0.1f, isLogScale: true)]
public double LearningRate = GamDefaults.LearningRate;
/// <summary>
/// Whether to utilize the disk or the data's native transposition facilities (where applicable) when performing the transpose.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether to utilize the disk or the data's native transposition facilities (where applicable) when performing the transpose", ShortName = "dt")]
public bool? DiskTranspose;
/// <summary>
/// The maximum number of distinct values (bins) per feature.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Maximum number of distinct values (bins) per feature", ShortName = "mb")]
public int MaximumBinCountPerFeature = GamDefaults.MaximumBinCountPerFeature;
/// <summary>
/// The upper bound on the absolute value of a single tree output.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Upper bound on absolute value of single output", ShortName = "mo")]
public double MaximumTreeOutput = double.PositiveInfinity;
/// <summary>
/// Sample each query 1 in k times in the GetDerivatives function, where k is this parameter's value.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Sample each query 1 in k times in the GetDerivatives function", ShortName = "sr")]
public int GetDerivativesSampleRate = 1;
/// <summary>
/// The seed of the random number generator.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "The seed of the random number generator", ShortName = "r1")]
public int Seed = 123;
/// <summary>
/// The minimal number of data points required to form a new tree leaf.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Minimum number of training instances required to form a partition", ShortName = "mi", SortOrder = 3)]
[TGUI(SuggestedSweeps = "1,10,50")]
[TlcModule.SweepableDiscreteParamAttribute("MinDocuments", new object[] { 1, 10, 50 })]
public int MinimumExampleCountPerLeaf = 10;
/// <summary>
/// Whether to collectivize features during dataset preparation to speed up training.
/// </summary>
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether to collectivize features during dataset preparation to speed up training", ShortName = "flocks", Hide = true)]
public bool FeatureFlocks = true;
/// <summary>
/// Enable post-training tree pruning to avoid overfitting. It requires a validation set.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Enable post-training pruning to avoid overfitting. (a validation set is required)", ShortName = "pruning")]
public bool EnablePruning = true;
}
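// Example (a sketch): concrete trainers expose derived option types, e.g.
// GamRegressionTrainer.Options (see the entry points below):
//   var options = new GamRegressionTrainer.Options
//   {
//       NumberOfIterations = 1500,
//       LearningRate = 0.002,
//       MaximumBinCountPerFeature = 255,
//   };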
internal const string Summary = "Trains a gradient boosted stump per feature, on all features simultaneously, " +
"to fit target values using least-squares. It maintains " +
"no interactions between features.";
private const string RegisterName = "GamTraining";
// Parameters of training
private protected readonly TOptions GamTrainerOptions;
private readonly double _gainConfidenceInSquaredStandardDeviations;
private readonly double _entropyCoefficient;
// Dataset information
private protected Dataset TrainSet;
private protected Dataset ValidSet;
/// <summary>
/// Whether a validation set was passed in.
/// </summary>
private protected bool HasValidSet => ValidSet != null;
private protected ScoreTracker TrainSetScore;
private protected ScoreTracker ValidSetScore;
private protected TestHistory PruningTest;
private protected int PruningLossIndex;
private protected int InputLength;
private LeastSquaresRegressionTreeLearner.LeafSplitCandidates _leafSplitCandidates;
private SufficientStatsBase[] _histogram;
private ILeafSplitStatisticsCalculator _leafSplitHelper;
private ObjectiveFunctionBase _objectiveFunction;
private bool HasWeights => TrainSet?.SampleWeights != null;
// Training data structures
private SubGraph _subGraph;
// Results of training
private protected double MeanEffect;
private protected double[][] BinEffects;
private protected double[][] BinUpperBounds;
private protected int[] FeatureMap;
public override TrainerInfo Info { get; }
private protected virtual bool NeedCalibration => false;
private protected IParallelTraining ParallelTraining;
private protected GamTrainerBase(IHostEnvironment env,
string name,
SchemaShape.Column label,
string featureColumnName,
string weightColumnName,
int numberOfIterations,
double learningRate,
int maximumBinCountPerFeature)
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumnName), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumnName))
{
GamTrainerOptions = new TOptions();
GamTrainerOptions.NumberOfIterations = numberOfIterations;
GamTrainerOptions.LearningRate = learningRate;
GamTrainerOptions.MaximumBinCountPerFeature = maximumBinCountPerFeature;
GamTrainerOptions.LabelColumnName = label.Name;
GamTrainerOptions.FeatureColumnName = featureColumnName;
if (weightColumnName != null)
GamTrainerOptions.ExampleWeightColumnName = weightColumnName;
Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true);
_gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - GamTrainerOptions.GainConfidenceLevel) * 0.5), 2);
_entropyCoefficient = GamTrainerOptions.EntropyCoefficient * 1e-6;
InitializeThreads();
}
private protected GamTrainerBase(IHostEnvironment env, TOptions options, string name, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumnName),
label, TrainerUtils.MakeR4ScalarWeightColumn(options.ExampleWeightColumnName))
{
Contracts.CheckValue(env, nameof(env));
Host.CheckValue(options, nameof(options));
Host.CheckParam(options.LearningRate > 0, nameof(options.LearningRate), "Must be positive.");
Host.CheckParam(options.NumberOfThreads == null || options.NumberOfThreads > 0, nameof(options.NumberOfThreads), "Must be positive.");
Host.CheckParam(0 <= options.EntropyCoefficient && options.EntropyCoefficient <= 1, nameof(options.EntropyCoefficient), "Must be in [0, 1].");
Host.CheckParam(0 <= options.GainConfidenceLevel && options.GainConfidenceLevel < 1, nameof(options.GainConfidenceLevel), "Must be in [0, 1).");
Host.CheckParam(0 < options.MaximumBinCountPerFeature, nameof(options.MaximumBinCountPerFeature), "Must be positive.");
Host.CheckParam(0 < options.NumberOfIterations, nameof(options.NumberOfIterations), "Must be positive.");
Host.CheckParam(0 < options.MinimumExampleCountPerLeaf, nameof(options.MinimumExampleCountPerLeaf), "Must be positive.");
GamTrainerOptions = options;
Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true);
_gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - GamTrainerOptions.GainConfidenceLevel) * 0.5), 2);
_entropyCoefficient = GamTrainerOptions.EntropyCoefficient * 1e-6;
InitializeThreads();
}
private protected void TrainBase(TrainContext context)
{
using (var ch = Host.Start("Training"))
{
ch.CheckValue(context, nameof(context));
// Create the datasets
ConvertData(context.TrainingSet, context.ValidationSet);
// Define scoring and testing
DefineScoreTrackers();
if (HasValidSet)
DefinePruningTest();
InputLength = context.TrainingSet.Schema.Feature.Value.Type.GetValueCount();
TrainCore(ch);
}
}
private void DefineScoreTrackers()
{
TrainSetScore = new ScoreTracker("train", TrainSet, null);
if (HasValidSet)
ValidSetScore = new ScoreTracker("valid", ValidSet, null);
}
private protected abstract void DefinePruningTest();
private protected abstract void CheckLabel(RoleMappedData data);
private void ConvertData(RoleMappedData trainData, RoleMappedData validationData)
{
trainData.CheckFeatureFloatVector();
trainData.CheckOptFloatWeight();
CheckLabel(trainData);
var useTranspose = UseTranspose(GamTrainerOptions.DiskTranspose, trainData);
var instanceConverter = new ExamplesToFastTreeBins(Host, GamTrainerOptions.MaximumBinCountPerFeature, useTranspose, !GamTrainerOptions.FeatureFlocks, GamTrainerOptions.MinimumExampleCountPerLeaf, float.PositiveInfinity);
ParallelTraining.InitEnvironment();
TrainSet = instanceConverter.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, null, false);
FeatureMap = instanceConverter.FeatureMap;
if (validationData != null)
ValidSet = instanceConverter.GetCompatibleDataset(validationData, PredictionKind, null, false);
Host.Assert(FeatureMap == null || FeatureMap.Length == TrainSet.NumFeatures);
}
private bool UseTranspose(bool? useTranspose, RoleMappedData data)
{
Host.AssertValue(data);
Host.Assert(data.Schema.Feature.HasValue);
if (useTranspose.HasValue)
return useTranspose.Value;
return (data.Data as ITransposeDataView)?.GetSlotType(data.Schema.Feature.Value.Index) != null;
}
private void TrainCore(IChannel ch)
{
Contracts.CheckValue(ch, nameof(ch));
// REVIEW: Get rid of this lock once we completely remove all static classes from Gam, such as BlockingThreadPool.
lock (FastTreeShared.TrainLock)
{
using (Timer.Time(TimerEvent.TotalInitialization))
Initialize(ch);
using (Timer.Time(TimerEvent.TotalTrain))
TrainMainEffectsModel(ch);
}
}
/// <summary>
/// Training algorithm for the single-feature functions f(x)
/// </summary>
/// <param name="ch">The channel to write to</param>
private void TrainMainEffectsModel(IChannel ch)
{
Contracts.AssertValue(ch);
int iterations = GamTrainerOptions.NumberOfIterations;
ch.Info("Starting to train ...");
using (var pch = Host.StartProgressChannel("GAM training"))
{
_objectiveFunction = CreateObjectiveFunction();
var sumWeights = HasWeights ? TrainSet.SampleWeights.Sum() : 0;
int iteration = 0;
pch.SetHeader(new ProgressHeader("iterations"), e => e.SetProgress(0, iteration, iterations));
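// Cyclic gradient boosting: each iteration computes the gradient of the loss at the current
// scores, accumulates per-bin gradient sums across feature flocks, fits one depth-1 stump
// per feature, and folds the stump outputs back into the scores.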
for (; iteration < iterations; iteration++)
{
using (Timer.Time(TimerEvent.Iteration))
{
var gradient = _objectiveFunction.GetGradient(ch, TrainSetScore.Scores);
var sumTargets = gradient.Sum();
SumUpsAcrossFlocks(gradient, sumTargets, sumWeights);
TrainOnEachFeature(gradient, TrainSetScore.Scores, sumTargets, sumWeights, iteration);
UpdateScores(iteration);
}
}
}
CombineGraphs(ch);
}
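/// <summary>
/// Accumulates, in parallel across feature flocks, the per-bin sums of gradients and weights
/// needed to evaluate candidate splits.
/// </summary>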
private void SumUpsAcrossFlocks(double[] gradient, double sumTargets, double sumWeights)
{
var sumupTask = ThreadTaskManager.MakeTask(
(flockIndex) =>
{
_histogram[flockIndex].Sumup(
TrainSet.FlockToFirstFeature(flockIndex),
null,
TrainSet.NumDocs,
sumTargets,
sumWeights,
gradient,
TrainSet.SampleWeights,
null);
}, TrainSet.NumFlocks);
sumupTask.RunTask();
}
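/// <summary>
/// Fits one depth-1 stump per feature against the current gradient, in parallel across features.
/// </summary>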
private void TrainOnEachFeature(double[] gradient, double[] scores, double sumTargets, double sumWeights, int iteration)
{
var trainTask = ThreadTaskManager.MakeTask(
(feature) =>
{
TrainingIteration(feature, gradient, scores, sumTargets, sumWeights, iteration);
}, TrainSet.NumFeatures);
trainTask.RunTask();
}
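/// <summary>
/// Finds the best split of one feature against the current gradient; if the split has positive
/// gain, the resulting stump is recorded in the sub-graph for this iteration.
/// </summary>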
private void TrainingIteration(int globalFeatureIndex, double[] gradient, double[] scores,
double sumTargets, double sumWeights, int iteration)
{
int flockIndex;
int subFeatureIndex;
TrainSet.MapFeatureToFlockAndSubFeature(globalFeatureIndex, out flockIndex, out subFeatureIndex);
// Compute the split for the feature
_histogram[flockIndex].FindBestSplitForFeature(_leafSplitHelper, _leafSplitCandidates,
_leafSplitCandidates.Targets.Length, sumTargets, sumWeights,
globalFeatureIndex, flockIndex, subFeatureIndex, GamTrainerOptions.MinimumExampleCountPerLeaf, HasWeights,
_gainConfidenceInSquaredStandardDeviations, _entropyCoefficient,
TrainSet.Flocks[flockIndex].Trust(subFeatureIndex), 0);
// Adjust the model
if (_leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Gain > 0)
ConvertTreeToGraph(globalFeatureIndex, iteration);
}
/// <summary>
/// Update scores for all tracked datasets
/// </summary>
private void UpdateScores(int iteration)
{
// Pass scores by reference to be updated and manually trigger the update callbacks
UpdateScoresForSet(TrainSet, TrainSetScore.Scores, iteration);
TrainSetScore.SendScoresUpdatedMessage();
if (HasValidSet)
{
UpdateScoresForSet(ValidSet, ValidSetScore.Scores, iteration);
ValidSetScore.SendScoresUpdatedMessage();
}
}
/// <summary>
/// Updates the scores for a dataset.
/// </summary>
/// <param name="dataset">The dataset to use.</param>
/// <param name="scores">The current scores for this dataset</param>
/// <param name="iteration">The iteration of the algorithm.
/// Used to look up the sub-graph to use to update the score.</param>
private void UpdateScoresForSet(Dataset dataset, double[] scores, int iteration)
{
DefineDocumentThreadBlocks(dataset.NumDocs, BlockingThreadPool.NumThreads, out int[] threadBlocks);
var updateTask = ThreadTaskManager.MakeTask(
(threadIndex) =>
{
int startIndexInclusive = threadBlocks[threadIndex];
int endIndexExclusive = threadBlocks[threadIndex + 1];
for (int featureIndex = 0; featureIndex < _subGraph.Splits.Length; featureIndex++)
{
var featureIndexer = dataset.GetIndexer(featureIndex);
for (int doc = startIndexInclusive; doc < endIndexExclusive; doc++)
{
if (featureIndexer[doc] <= _subGraph.Splits[featureIndex][iteration].SplitPoint)
scores[doc] += _subGraph.Splits[featureIndex][iteration].LteValue;
else
scores[doc] += _subGraph.Splits[featureIndex][iteration].GtValue;
}
}
}, BlockingThreadPool.NumThreads);
updateTask.RunTask();
}
/// <summary>
/// Combine the single-feature single-tree graphs to a single-feature model
/// </summary>
private void CombineGraphs(IChannel ch)
{
// Prune backwards to the best iteration
int bestIteration = GamTrainerOptions.NumberOfIterations;
if (GamTrainerOptions.EnablePruning && PruningTest != null)
{
ch.Info("Pruning");
var finalResult = PruningTest.ComputeTests().ToArray()[PruningLossIndex];
string lossFunctionName = finalResult.LossFunctionName;
// PruningTest is non-null here, so take the best iteration and loss directly.
bestIteration = PruningTest.BestIteration;
double bestLoss = PruningTest.BestResult.FinalValue;
if (bestIteration != GamTrainerOptions.NumberOfIterations)
ch.Info($"Best Iteration ({lossFunctionName}): {bestIteration} @ {bestLoss:G6} (vs {GamTrainerOptions.NumberOfIterations} @ {finalResult.FinalValue:G6}).");
else
ch.Info("No pruning necessary. More iterations may be necessary.");
}
// Combine the graphs to compute the per-feature (binned) Effects
BinEffects = new double[TrainSet.NumFeatures][];
for (int featureIndex = 0; featureIndex < TrainSet.NumFeatures; featureIndex++)
{
TrainSet.MapFeatureToFlockAndSubFeature(featureIndex, out int flockIndex, out int subFeatureIndex);
int numOfBins = TrainSet.Flocks[flockIndex].BinCount(subFeatureIndex);
BinEffects[featureIndex] = new double[numOfBins];
for (int iteration = 0; iteration < bestIteration; iteration++)
{
var splitPoint = _subGraph.Splits[featureIndex][iteration].SplitPoint;
for (int bin = 0; bin <= splitPoint; bin++)
BinEffects[featureIndex][bin] += _subGraph.Splits[featureIndex][iteration].LteValue;
for (int bin = (int)splitPoint + 1; bin < numOfBins; bin++)
BinEffects[featureIndex][bin] += _subGraph.Splits[featureIndex][iteration].GtValue;
}
}
// Center the graph around zero
CenterGraph();
// Redefine the bins so that bin boundaries only mark changes in effect value
CreateEfficientBinning();
}
/// <summary>
/// Distribute the documents into blocks to be computed on each thread
/// </summary>
/// <param name="numDocs">The number of documents in the dataset</param>
/// <param name="blocks">An array containing the starting point for each thread;
/// the next position is the exclusive ending point for the thread.</param>
/// <param name="numThreads">The number of threads used.</param>
private void DefineDocumentThreadBlocks(int numDocs, int numThreads, out int[] blocks)
{
int extras = numDocs % numThreads;
int documentsPerThread = numDocs / numThreads;
blocks = new int[numThreads + 1];
blocks[0] = 0;
for (int t = 0; t < extras; t++)
blocks[t + 1] = blocks[t] + documentsPerThread + 1;
for (int t = extras; t < numThreads; t++)
blocks[t + 1] = blocks[t] + documentsPerThread;
}
/// <summary>
/// Center the graph using the mean response per feature on the training set.
/// </summary>
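/// <remarks>
/// Each per-feature graph is shifted so its mean effect over the training documents is zero,
/// and the removed per-feature means are summed into the intercept <see cref="MeanEffect"/>.
/// </remarks>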
private void CenterGraph()
{
// Compute the per-thread document blocks once, for reuse below
DefineDocumentThreadBlocks(TrainSet.NumDocs, BlockingThreadPool.NumThreads, out int[] trainThreadBlocks);
// Compute the mean of each Effect
var meanEffects = new double[BinEffects.Length];
var updateTask = ThreadTaskManager.MakeTask(
(threadIndex) =>
{
int startIndexInclusive = trainThreadBlocks[threadIndex];
int endIndexExclusive = trainThreadBlocks[threadIndex + 1];
for (int featureIndex = 0; featureIndex < BinEffects.Length; featureIndex++)
{
var featureIndexer = TrainSet.GetIndexer(featureIndex);
for (int doc = startIndexInclusive; doc < endIndexExclusive; doc++)
{
var bin = featureIndexer[doc];
double totalEffect;
double newTotalEffect;
// Update the shared per-feature total with a compare-and-swap loop to stay thread-safe
do
{
totalEffect = meanEffects[featureIndex];
newTotalEffect = totalEffect + BinEffects[featureIndex][bin];
} while (totalEffect !=
Interlocked.CompareExchange(ref meanEffects[featureIndex], newTotalEffect, totalEffect));
}
}
}, BlockingThreadPool.NumThreads);
updateTask.RunTask();
// Compute the intercept and center each graph
MeanEffect = 0.0;
for (int featureIndex = 0; featureIndex < BinEffects.Length; featureIndex++)
{
// Compute the mean effect
meanEffects[featureIndex] /= TrainSet.NumDocs;
// Shift the mean from the bins into the intercept
MeanEffect += meanEffects[featureIndex];
for (int bin = 0; bin < BinEffects[featureIndex].Length; ++bin)
BinEffects[featureIndex][bin] -= meanEffects[featureIndex];
}
}
/// <summary>
/// Process bins such that only bin upper bounds and bin effects remain where
/// the effect changes.
/// </summary>
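/// <remarks>
/// For example, bin effects [1, 1, 2, 2, 2, 3] with upper bounds [b0..b5] compress to
/// effects [1, 2, 3] with upper bounds [b1, b4, b5].
/// </remarks>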
private protected void CreateEfficientBinning()
{
BinUpperBounds = new double[TrainSet.NumFeatures][];
var newBinEffects = new List<double>();
var newBinBoundaries = new List<double>();
for (int i = 0; i < TrainSet.NumFeatures; i++)
{
TrainSet.MapFeatureToFlockAndSubFeature(i, out int flockIndex, out int subFeatureIndex);
double[] binUpperBound = TrainSet.Flocks[flockIndex].BinUpperBounds(subFeatureIndex);
double value = BinEffects[i][0];
for (int j = 0; j < BinEffects[i].Length; j++)
{
double element = BinEffects[i][j];
if (element != value)
{
newBinEffects.Add(value);
newBinBoundaries.Add(binUpperBound[j - 1]);
value = element;
}
}
// Catch the last value
newBinBoundaries.Add(binUpperBound[BinEffects[i].Length - 1]);
newBinEffects.Add(BinEffects[i][BinEffects[i].Length - 1]);
// Overwrite the old arrays with the efficient arrays
BinUpperBounds[i] = newBinBoundaries.ToArray();
BinEffects[i] = newBinEffects.ToArray();
newBinEffects.Clear();
newBinBoundaries.Clear();
}
}
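/// <summary>
/// Records the stump chosen for a feature on this iteration: the bin threshold and the
/// learning-rate-scaled outputs for the two sides of the split.
/// </summary>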
private void ConvertTreeToGraph(int globalFeatureIndex, int iteration)
{
SplitInfo splitInfo = _leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex];
_subGraph.Splits[globalFeatureIndex][iteration].SplitPoint = splitInfo.Threshold;
_subGraph.Splits[globalFeatureIndex][iteration].LteValue = GamTrainerOptions.LearningRate * splitInfo.LteOutput;
_subGraph.Splits[globalFeatureIndex][iteration].GtValue = GamTrainerOptions.LearningRate * splitInfo.GTOutput;
}
private void InitializeGamHistograms()
{
_histogram = new SufficientStatsBase[TrainSet.Flocks.Length];
for (int i = 0; i < TrainSet.Flocks.Length; i++)
_histogram[i] = TrainSet.Flocks[i].CreateSufficientStats(HasWeights);
}
private void Initialize(IChannel ch)
{
using (Timer.Time(TimerEvent.InitializeTraining))
{
InitializeGamHistograms();
_subGraph = new SubGraph(TrainSet.NumFeatures, GamTrainerOptions.NumberOfIterations);
_leafSplitCandidates = new LeastSquaresRegressionTreeLearner.LeafSplitCandidates(TrainSet);
_leafSplitHelper = new LeafSplitHelper(HasWeights);
}
}
private void InitializeThreads()
{
ParallelTraining = new SingleTrainer();
ThreadTaskManager.Initialize(GamTrainerOptions.NumberOfThreads ?? Environment.ProcessorCount);
}
private protected abstract ObjectiveFunctionBase CreateObjectiveFunction();
private class LeafSplitHelper : ILeafSplitStatisticsCalculator
{
private readonly bool _hasWeights;
public LeafSplitHelper(bool hasWeights)
{
_hasWeights = hasWeights;
}
/// <summary>
/// Returns the split gain for a particular leaf. Used on two leaves to calculate
/// the squared error gain for a particular leaf.
/// </summary>
/// <param name="count">Number of documents in this leaf</param>
/// <param name="sumTargets">Sum of the target values for this leaf</param>
/// <param name="sumWeights">Sum of the weights for this leaf, not meaningful if
/// <see cref="HasWeights"/> is <c>false</c></param>
/// <returns>The gain in least squared error</returns>
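/// <remarks>
/// In the unweighted case the optimal leaf output is the mean target, giving a squared-error
/// reduction of (sum of targets)^2 / count; the weighted case divides by the weight sum instead.
/// </remarks>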
public double GetLeafSplitGain(int count, double sumTargets, double sumWeights)
{
if (!_hasWeights)
return (sumTargets * sumTargets) / count;
// Weighted case: the optimal output is sumTargets / sumWeights, which reduces the
// squared error by sumTargets^2 / sumWeights.
return (sumTargets * sumTargets) / sumWeights;
}
/// <summary>
/// Calculates the output value for a leaf after splitting.
/// </summary>
/// <param name="count">Number of documents in this leaf</param>
/// <param name="sumTargets">Sum of the target values for this leaf</param>
/// <param name="sumWeights">Sum of the weights for this leaf, not meaningful if
/// <see cref="HasWeights"/> is <c>false</c></param>
/// <returns>The output value for a leaf</returns>
public double CalculateSplittedLeafOutput(int count, double sumTargets, double sumWeights)
{
if (!_hasWeights)
return sumTargets / count;
Contracts.Assert(sumWeights != 0);
return sumTargets / sumWeights;
}
}
private struct SubGraph
{
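// Splits[featureIndex][iteration] is the stump learned for that feature on that boosting pass.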
public Stump[][] Splits;
public SubGraph(int numFeatures, int numIterations)
{
Splits = new Stump[numFeatures][];
for (int i = 0; i < numFeatures; ++i)
{
Splits[i] = new Stump[numIterations];
for (int j = 0; j < numIterations; j++)
Splits[i][j] = new Stump(0, 0, 0);
}
}
public struct Stump
{
public uint SplitPoint;
public double LteValue;
public double GtValue;
public Stump(uint splitPoint, double lteValue, double gtValue)
{
SplitPoint = splitPoint;
LteValue = lteValue;
GtValue = gtValue;
}
}
}
}
internal static class Gam
{
[TlcModule.EntryPoint(Name = "Trainers.GeneralizedAdditiveModelRegressor", Desc = GamRegressionTrainer.Summary, UserName = GamRegressionTrainer.UserNameValue, ShortName = GamRegressionTrainer.ShortName)]
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, GamRegressionTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainGAM");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
return TrainerEntryPointsUtils.Train<GamRegressionTrainer.Options, CommonOutputs.RegressionOutput>(host, input,
() => new GamRegressionTrainer(host, input),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
}
[TlcModule.EntryPoint(Name = "Trainers.GeneralizedAdditiveModelBinaryClassifier", Desc = GamBinaryTrainer.Summary, UserName = GamBinaryTrainer.UserNameValue, ShortName = GamBinaryTrainer.ShortName)]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, GamBinaryTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainGAM");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
return TrainerEntryPointsUtils.Train<GamBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new GamBinaryTrainer(host, input),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
}
}
internal static class GamDefaults
{
internal const int NumberOfIterations = 9500;
internal const int MaximumBinCountPerFeature = 255;
internal const double LearningRate = 0.002; // A small value
}
}