// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Trainers.FastTree
{
// An interface implemented by components (such as a step search) that can supply
// training scores already updated as a by-product of their own computation.
internal interface IFastTrainingScoresUpdate
{
ScoreTracker GetUpdatedTrainingScores();
}
internal abstract class OptimizationAlgorithm
{
//TODO: We should move Partitioning to OptimizationAlgorithm
public TreeLearner TreeLearner;
public ObjectiveFunctionBase ObjectiveFunction;
// This event is raised to signal that we are just about to update all scores.
// It is only used for printing training graph scores that can be computed cheaply for the
// previous iteration, saving the top labels by score from the n+1 gradient computation.
public delegate void PreScoreUpdateHandler(IChannel ch);
public PreScoreUpdateHandler PreScoreUpdateEvent;
public InternalTreeEnsemble Ensemble;
public ScoreTracker TrainingScores;
public List<ScoreTracker> TrackedScores;
public IStepSearch AdjustTreeOutputsOverride; // If set, this overrides the IStepSearch possibly implemented by ObjectiveFunctionBase.
public double Smoothing;
public double DropoutRate;
public Random DropoutRng;
public bool UseFastTrainingScoresUpdate;
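
/// <summary>
/// Initializes the optimizer over an existing ensemble, setting up the training score
/// tracker and enabling the fast training-score update path by default.
/// </summary>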
public OptimizationAlgorithm(InternalTreeEnsemble ensemble, Dataset trainData, double[] initTrainScores)
{
Ensemble = ensemble;
TrainingScores = ConstructScoreTracker("train", trainData, initTrainScores);
TrackedScores = new List<ScoreTracker>();
TrackedScores.Add(TrainingScores);
DropoutRng = new Random();
UseFastTrainingScoresUpdate = true;
}
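
/// <summary>
/// Replaces the training dataset, rebuilding the training score tracker and
/// re-registering it as the first tracked score set.
/// </summary>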
public void SetTrainingData(Dataset trainData, double[] initTrainScores)
{
TrainingScores = ConstructScoreTracker("train", trainData, initTrainScores);
TrackedScores[0] = TrainingScores;
}
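
/// <summary>
/// Runs one iteration of the optimization algorithm, learning a single regression
/// tree using only the active features.
/// </summary>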
internal abstract InternalRegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures);
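
/// <summary>
/// Raises the pre-score-update event, then updates every tracked score set with the
/// outputs of the newly learned tree.
/// </summary>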
internal virtual void UpdateAllScores(IChannel ch, InternalRegressionTree tree)
{
if (PreScoreUpdateEvent != null)
PreScoreUpdateEvent(ch);
using (Timer.Time(TimerEvent.UpdateScores))
{
foreach (ScoreTracker t in TrackedScores)
UpdateScores(t, tree);
}
}
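
/// <summary>
/// Updates a single score tracker. For the training scores, a fast path can reuse scores
/// already computed by the step search; otherwise the tree's outputs are added to the tracker.
/// </summary>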
internal virtual void UpdateScores(ScoreTracker t, InternalRegressionTree tree)
{
if (t == TrainingScores)
{
IFastTrainingScoresUpdate fastUpdate = AdjustTreeOutputsOverride as IFastTrainingScoresUpdate;
ScoreTracker updatedScores = (UseFastTrainingScoresUpdate && fastUpdate != null) ? fastUpdate.GetUpdatedTrainingScores() : null;
if (updatedScores != null)
t.SetScores(updatedScores.Scores);
else
t.AddScores(tree, TreeLearner.Partitioning, 1.0);
}
else
t.AddScores(tree, 1.0);
}
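
/// <summary>
/// Returns the score tracker for the given dataset, constructing and registering a new
/// one if the dataset is not already tracked.
/// </summary>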
public ScoreTracker GetScoreTracker(string name, Dataset set, double[] initScores)
{
// First check for duplicates; we may already be tracking scores for this dataset.
foreach (var st in TrackedScores)
{
if (st.Dataset == set)
return st;
}
ScoreTracker newTracker = ConstructScoreTracker(name, set, initScores);
// Add the newly constructed tracker to the list of scores we need to update.
TrackedScores.Add(newTracker);
return newTracker;
}
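
/// <summary>
/// Creates the score tracker used by the concrete optimization algorithm.
/// </summary>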
protected abstract ScoreTracker ConstructScoreTracker(string name, Dataset set, double[] initScores);
/// <summary>
/// Regularize a regression tree with smoothing parameter alpha.
/// </summary>
protected virtual void SmoothTree(InternalRegressionTree tree, double smoothing)
{
if (smoothing == 0.0)
return;
// Create the recursive structure of the tree, starting from the root node.
var regularizer = new RecursiveRegressionTree(tree, TreeLearner.Partitioning, 0);
// Perform a bottom-up computation of the weighted interior node outputs,
double rootNodeOutput = regularizer.GetWeightedOutput();
// followed by a top-down propagation of each parent's output value.
regularizer.SmoothLeafOutputs(rootNodeOutput, smoothing);
}
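
/// <summary>
/// Truncates the ensemble to the best iteration (when it is not the last one) and
/// invalidates all precomputed scores.
/// </summary>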
public virtual void FinalizeLearning(int bestIteration)
{
if (bestIteration != Ensemble.NumTrees)
{
Ensemble.RemoveAfter(Math.Max(bestIteration, 0));
TrackedScores.Clear(); // Invalidate all precomputed scores, as they are no longer valid; the slow method of score computation will be used instead.
}
}
}
}