File: Training\OptimizationAlgorithms\GradientDescent.cs
Project: src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    internal class GradientDescent : OptimizationAlgorithm
    {
        private readonly IGradientAdjuster _gradientWrapper;
 
        // Number of trees dropped in the current iteration.
        private int _numberOfDroppedTrees;
        // treeScores stores for every tree the predictions it makes on every training example. This is used
        // to eliminate the need for computing the scores when we drop trees. However, it causes a horrifying
        // memory drain.
        private readonly List<double[]> _treeScores;
        private double[] _droppedScores;
        private double[] _scores;
 
        internal GradientDescent(InternalTreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
            : base(ensemble, trainData, initTrainScores)
        {
            _gradientWrapper = gradientWrapper;
            _treeScores = new List<double[]>();
        }
 
        protected override ScoreTracker ConstructScoreTracker(string name, Dataset set, double[] initScores)
        {
            return new ScoreTracker(name, set, initScores);
        }
 
        private protected virtual double[] GetGradient(IChannel ch)
        {
            Contracts.AssertValue(ch);
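            // When DropoutRate is positive, this follows a DART-style scheme (dropouts for
            // additive regression trees): randomly drop a subset of the existing trees,
            // compute the gradient against the remaining ensemble's scores, and rescale the
            // dropped trees and the new tree so the combined prediction keeps its scale.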
            if (DropoutRate > 0)
            {
                if (_droppedScores == null)
                    _droppedScores = new double[TrainingScores.Scores.Length];
                else
                    Array.Clear(_droppedScores, 0, _droppedScores.Length);
                if (_scores == null)
                    _scores = new double[TrainingScores.Scores.Length];
                int numberOfTrees = Ensemble.NumTrees;
                int[] droppedTrees =
                    Enumerable.Range(0, numberOfTrees).Where(t => (DropoutRng.NextDouble() < DropoutRate)).ToArray();
                _numberOfDroppedTrees = droppedTrees.Length;
                if ((_numberOfDroppedTrees == 0) && (numberOfTrees > 0))
                {
                    // Force at least one tree to be dropped.
                    droppedTrees = new int[] { DropoutRng.Next(numberOfTrees) };
                    _numberOfDroppedTrees = droppedTrees.Length;
                }
                ch.Trace("dropout: Dropping {0} trees of {1} for rate {2}",
                    _numberOfDroppedTrees, numberOfTrees, DropoutRate);
                foreach (int i in droppedTrees)
                {
                    double[] s = _treeScores[i];
                    for (int j = 0; j < _droppedScores.Length; j++)
                    {
                        _droppedScores[j] += s[j]; // accumulate the dropped tree's cached scores
                        s[j] *= _numberOfDroppedTrees / (1.0 + _numberOfDroppedTrees); // rescale the cached scores by k / (k + 1)
                    }
                    Ensemble.GetTreeAt(i).ScaleOutputsBy(_numberOfDroppedTrees / (1.0 + _numberOfDroppedTrees));
                }
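                // _scores below holds the ensemble's predictions with the dropped trees'
                // contribution removed entirely; the gradient is then computed against these
                // reduced scores. TrainingScores is shifted so the dropped trees contribute
                // at the reduced weight k / (k + 1), matching the rescaling above.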
                for (int j = 0; j < _scores.Length; j++)
                {
                    _scores[j] = TrainingScores.Scores[j] - _droppedScores[j];
                    TrainingScores.Scores[j] -= _droppedScores[j] / (1.0 + _numberOfDroppedTrees);
                }
                return ObjectiveFunction.GetGradient(ch, _scores);
            }
            else
                return ObjectiveFunction.GetGradient(ch, TrainingScores.Scores);
        }
 
        protected virtual double[] AdjustTargetsAndSetWeights(IChannel ch)
        {
            if (_gradientWrapper == null)
                return GetGradient(ch);

            // The wrapper may rescale the gradient targets and construct per-example weights;
            // the weights are not consumed here.
            return _gradientWrapper.AdjustTargetAndSetWeights(GetGradient(ch), ObjectiveFunction, out _);
        }
 
        internal override InternalRegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
        {
            Contracts.CheckValue(ch, nameof(ch));
            // Fit a regression tree to the gradient using least squares.
            InternalRegressionTree tree = TreeLearner.FitTargets(ch, activeFeatures, AdjustTargetsAndSetWeights(ch));
            if (tree == null)
                return null; // Could not learn a tree. Exit.
 
            // Adjust output values of tree by performing a Newton step.
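            // The step is delegated to AdjustTreeOutputsOverride when one is set, otherwise
            // to the objective function's IStepSearch implementation, which sets each leaf's
            // value by a one-dimensional search on the loss.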
 
            // REVIEW: This should be part of OptimizationAlgorithm.
            using (Timer.Time(TimerEvent.TreeLearnerAdjustTreeOutputs))
            {
                double[] backupScores = null;
                // When doing dropout, temporarily replace TrainingScores with the scores
                // computed without the dropped trees.
                if (DropoutRate > 0)
                {
                    backupScores = TrainingScores.Scores;
                    TrainingScores.Scores = _scores;
                }
 
                if (AdjustTreeOutputsOverride != null)
                    AdjustTreeOutputsOverride.AdjustTreeOutputs(ch, tree, TreeLearner.Partitioning, TrainingScores);
                else if (ObjectiveFunction is IStepSearch stepSearch)
                    stepSearch.AdjustTreeOutputs(ch, tree, TreeLearner.Partitioning, TrainingScores);
                else
                    throw ch.Except("No AdjustTreeOutputs defined. The objective function should implement IStepSearch, or AdjustTreeOutputsOverride should be set.");
                if (DropoutRate > 0)
                {
                    // Restore the original scores.
                    TrainingScores.Scores = backupScores;
                }
            }
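            // Smoothing adjusts the leaf outputs after the step search above, so the fast
            // training-score update path can no longer be used and is disabled below.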
            if (Smoothing != 0.0)
            {
                SmoothTree(tree, Smoothing);
                UseFastTrainingScoresUpdate = false;
            }
            if (DropoutRate > 0)
            {
                // With dropout, shrinkage is not applied; instead the new tree is scaled by
                // 1 / (k + 1), where k is the number of dropped trees.
                double scaling = 1.0 / (1.0 + _numberOfDroppedTrees);
                tree.ScaleOutputsBy(scaling);
                _treeScores.Add(tree.GetOutputs(TrainingScores.Dataset));
            }
            UpdateAllScores(ch, tree);
            Ensemble.AddTree(tree);
            return tree;
        }
    }
 
    /// <summary>
    /// Interface for adjusting gradient target values and optionally constructing target weights.
    /// </summary>
    internal interface IGradientAdjuster
    {
        /// <summary>
        /// Adjusts the gradient target values and sets per-example target weights.
        /// </summary>
        /// <param name="gradient">The gradient target values to adjust.</param>
        /// <param name="objFunction">The objective function, which can be used for constructing weights.</param>
        /// <param name="targetWeights">The constructed target weights, or null if none are set.</param>
        /// <returns>The adjusted gradient target values.</returns>
        double[] AdjustTargetAndSetWeights(double[] gradient, ObjectiveFunctionBase objFunction, out double[] targetWeights);
    }
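
    // A minimal illustrative implementation, assuming a pass-through policy is acceptable:
    // it returns the gradient unchanged and constructs no weights. The name
    // PassThroughGradientAdjuster is hypothetical, not part of the FastTree API.
    internal sealed class PassThroughGradientAdjuster : IGradientAdjuster
    {
        public double[] AdjustTargetAndSetWeights(double[] gradient, ObjectiveFunctionBase objFunction, out double[] targetWeights)
        {
            // No reweighting: targets are exactly the raw gradient (hypothetical example).
            targetWeights = null;
            return gradient;
        }
    }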
}