File: Training\OptimizationAlgorithms\ConjugateGradientDescent.cs
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    // Conjugate gradient descent
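    // Instead of fitting each new tree to the raw gradient, this learner fits it to a conjugate
    // direction d[k] = g[k] + beta * d[k-1], where beta is computed from the current and previous
    // gradients (Polak-Ribiere by default) and clamped at zero so that a negative beta falls back
    // to a plain gradient step.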
    internal class ConjugateGradientDescent : GradientDescent
    {
        private double[] _previousGradient;
        private double[] _currentGradient;
        private double[] _currentDk;
 
        public ConjugateGradientDescent(InternalTreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
            : base(ensemble, trainData, initTrainScores, gradientWrapper)
        {
            _currentDk = new double[trainData.NumDocs];
        }
 
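        // Despite the name inherited from GradientDescent, this returns the conjugate direction
        // d[k] that the next tree is fitted to, not the raw gradient.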
        private protected override double[] GetGradient(IChannel ch)
        {
            Contracts.AssertValue(ch);
            _previousGradient = _currentGradient;
 
            _currentGradient = ObjectiveFunction.GetGradient(ch, TrainingScores.Scores);
            // We need to make a copy of the gradient because the returned reference points to a private structure of ObjectiveFunctionBase and is only valid until the next GetGradient call.
            _currentGradient = (double[])_currentGradient.Clone();
 
            double[] previousDk = _currentDk;
 
            // First iteration
            if (_previousGradient == null)
                _previousGradient = _currentGradient;
#if !POLAK_RIBIERE_STEP
            // Compute Beta[k] = curG[k] * (curG[k] - prevG[k]) / (prevG[k] * prevG[k])  (Polak-Ribiere)
            // TODO: this can be optimized for speed. Keeping it slow but simple for now
            double beta = VectorUtils.GetDotProduct(_currentGradient, VectorUtils.Subtract(_currentGradient, _previousGradient)) / VectorUtils.GetDotProduct(_previousGradient, _previousGradient);
#else // Fletcher-Reeves step
            // Compute Beta[k] = (curG[k] * curG[k]) / (prevG[k] * prevG[k])
            double beta = VectorUtils.GetDotProduct(_currentGradient, _currentGradient) / VectorUtils.GetDotProduct(_previousGradient, _previousGradient);
#endif
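            // Clamping a negative beta to zero (the "PR+" rule) discards the previous direction
            // and restarts with a plain gradient-descent step.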
            if (beta < 0)
                beta = 0;
 
            ch.Info("beta: {0}", beta);
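            // Form the new direction in place: d[k] = g[k] + beta * d[k-1].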
            VectorUtils.MutiplyInPlace(previousDk, beta);
            VectorUtils.AddInPlace(previousDk, _currentGradient);
            _currentDk = previousDk; // previousDk already references the same array as _currentDk, so this assignment is effectively a no-op
 
            // We know that LeastSquaresRegressionTreeLearner does not destroy gradients, so we can return the reference, which we will need in the next iteration.
            if (TreeLearner is LeastSquaresRegressionTreeLearner)
                return _currentDk;
            // Assume that other tree learners destroy the gradient array, so return a copy.
            else
                return (double[])_currentDk.Clone();
        }
    }
}