// File: Optimizer\SgdOptimizer.cs
// Web Access
// Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Numeric
{
    /// <summary>
    /// Delegate for functions that determine whether to terminate search. Called after each update
    /// with the optimizer's current candidate solution (for averaged stochastic gradient descent this
    /// is the running average of the iterates rather than the raw iterate).
    /// </summary>
    /// <param name="x">Current iterate, passed by read-only reference.</param>
    /// <returns>True if search should terminate; false if the optimizer should keep going.</returns>
    internal delegate bool DTerminate(in VBuffer<float> x);
 
    /// <summary>
    /// Stochastic gradient descent with variations (minibatch, momentum, averaging).
    /// </summary>
    /// <remarks>
    /// Every numeric setting is range-checked both when it is assigned through its property and when it
    /// is supplied to the constructor, so an instance can never hold a configuration that
    /// <see cref="Minimize"/> cannot run with (for example a batch size of zero).
    /// </remarks>
    internal sealed class SgdOptimizer
    {
        private int _batchSize;

        /// <summary>
        /// Size of minibatches: the number of stochastic gradients averaged for each update.
        /// Must be strictly positive.
        /// </summary>
        public int BatchSize
        {
            get { return _batchSize; }
            set
            {
                Contracts.Check(value > 0);
                _batchSize = value;
            }
        }

        private float _momentum;

        /// <summary>
        /// Momentum parameter. Must lie in the half-open interval [0, 1); 0 disables momentum.
        /// </summary>
        public float Momentum
        {
            get { return _momentum; }
            set
            {
                Contracts.Check(0 <= value && value < 1);
                _momentum = value;
            }
        }

        private float _t0;

        /// <summary>
        /// Base of step size schedule s_t = 1 / (t0 + f(t)). Must be non-negative.
        /// </summary>
        public float T0
        {
            get { return _t0; }
            set
            {
                Contracts.Check(value >= 0);
                _t0 = value;
            }
        }

        /// <summary>
        /// Termination criterion, queried after every update.
        /// </summary>
        private readonly DTerminate _terminate;

        private bool _averaging;

        /// <summary>
        /// If true, iterates are averaged and the average is what gets tested for termination and returned.
        /// </summary>
        public bool Averaging
        {
            get { return _averaging; }
            set { _averaging = value; }
        }

        private RateScheduleType _rateSchedule;

        /// <summary>
        /// Gets/Sets rate schedule type
        /// </summary>
        public RateScheduleType RateSchedule
        {
            get { return _rateSchedule; }
            set { _rateSchedule = value; }
        }

        private int _maxSteps;

        /// <summary>
        /// Gets/Sets maximum number of steps. Set to 0 for no max. Must be non-negative.
        /// </summary>
        public int MaxSteps
        {
            get { return _maxSteps; }
            set
            {
                Contracts.Check(value >= 0);
                _maxSteps = value;
            }
        }

        /// <summary>
        /// Annealing schedule for learning rate
        /// </summary>
        public enum RateScheduleType
        {
            /// <summary>
            /// r_t = 1 / t0
            /// </summary>
            Constant,

            /// <summary>
            /// r_t = 1 / (t0 + sqrt(t))
            /// </summary>
            Sqrt,

            /// <summary>
            /// r_t = 1 / (t0 + t)
            /// </summary>
            Linear
        }

        /// <summary>
        /// Creates SGDOptimizer and sets optimization parameters.
        /// </summary>
        /// <remarks>
        /// The arguments are subject to the same checks as the corresponding properties; a failed check
        /// is reported through <c>Contracts.Check</c>.
        /// </remarks>
        /// <param name="terminate">Termination criterion; must not be null</param>
        /// <param name="rateSchedule">Annealing schedule type for learning rate</param>
        /// <param name="averaging">If true, all iterates are averaged</param>
        /// <param name="t0">Base for learning rate schedule; must be non-negative</param>
        /// <param name="batchSize">Average this number of stochastic gradients for each update; must be positive</param>
        /// <param name="momentum">Momentum parameter; must be in [0, 1)</param>
        /// <param name="maxSteps">Maximum number of updates (0 for no max); must be non-negative</param>
        public SgdOptimizer(DTerminate terminate, RateScheduleType rateSchedule = RateScheduleType.Sqrt, bool averaging = false, float t0 = 1, int batchSize = 1, float momentum = 0, int maxSteps = 0)
        {
            // Minimize invokes the criterion on its very first iteration, so a null delegate could
            // never work; reject it here with a clear message instead of failing later.
            Contracts.Check(terminate != null, "The termination criterion must not be null.");
            _terminate = terminate;

            // Assign through the properties rather than the backing fields so that constructor
            // arguments get exactly the same range validation as later property assignments.
            RateSchedule = rateSchedule;
            Averaging = averaging;
            T0 = t0;
            BatchSize = batchSize;
            Momentum = momentum;
            MaxSteps = maxSteps;
        }

        /// <summary>
        /// Delegate for functions to query stochastic gradient at a point
        /// </summary>
        /// <param name="x">Point at which to evaluate</param>
        /// <param name="grad">Vector to be filled in with gradient</param>
        public delegate void DStochasticGradient(in VBuffer<float> x, ref VBuffer<float> grad);

        /// <summary>
        /// Minimize the function represented by <paramref name="f"/>.
        /// </summary>
        /// <remarks>
        /// Iterates until the termination delegate returns true, until
        /// <see cref="TerminateTester.ShouldTerminate"/> reports a non-finite or unchanged iterate
        /// (checked from the second update onwards), or until <see cref="MaxSteps"/> updates have been
        /// performed when that limit is non-zero.
        /// </remarks>
        /// <param name="f">Stochastic gradients of function to minimize</param>
        /// <param name="initial">Initial point; must contain only finite values</param>
        /// <param name="result">Approximate minimum of <paramref name="f"/>. With averaging enabled this is
        /// the average of the iterates, otherwise the last iterate.</param>
        public void Minimize(DStochasticGradient f, ref VBuffer<float> initial, ref VBuffer<float> result)
        {
            Contracts.Check(FloatUtils.IsFinite(initial.GetValues()), "The initial vector contains NaNs or infinite values.");
            int dim = initial.Length;

            VBuffer<float> grad = VBufferUtils.CreateEmpty<float>(dim);
            VBuffer<float> step = VBufferUtils.CreateEmpty<float>(dim);
            // Work on a copy so the caller's initial point is left untouched.
            VBuffer<float> x = default(VBuffer<float>);
            initial.CopyTo(ref x);
            VBuffer<float> prev = default(VBuffer<float>);
            VBuffer<float> avg = VBufferUtils.CreateEmpty<float>(dim);

            for (int n = 0; _maxSteps == 0 || n < _maxSteps; ++n)
            {
                // Carry over the previous step weighted by the momentum; with no momentum the previous
                // step is dropped entirely by resizing it to zero stored values.
                if (_momentum == 0)
                    VBufferUtils.Resize(ref step, step.Length, 0);
                else
                    VectorUtils.ScaleBy(ref step, _momentum);

                // Learning rate for update n according to the selected annealing schedule.
                float stepSize;
                switch (_rateSchedule)
                {
                    case RateScheduleType.Constant:
                        stepSize = 1 / _t0;
                        break;
                    case RateScheduleType.Sqrt:
                        stepSize = 1 / (_t0 + MathUtils.Sqrt(n));
                        break;
                    case RateScheduleType.Linear:
                        stepSize = 1 / (_t0 + n);
                        break;
                    default:
                        throw Contracts.Except();
                }

                // Each of the _batchSize stochastic gradients (all evaluated at the same x) contributes
                // with weight (1 - momentum) / batchSize to the step.
                float scale = (1 - _momentum) / _batchSize;
                for (int i = 0; i < _batchSize; ++i)
                {
                    f(in x, ref grad);
                    VectorUtils.AddMult(in grad, scale, ref step);
                }

                if (_averaging)
                {
                    // After the swap, prev holds the average of the first n iterates and avg is the
                    // buffer that receives the new average.
                    Utils.Swap(ref avg, ref prev);
                    VectorUtils.ScaleBy(prev, ref avg, (float)n / (n + 1));
                    // Move the iterate, then fold it into the average with weight 1 / (n + 1).
                    VectorUtils.AddMult(in step, -stepSize, ref x);
                    VectorUtils.AddMult(in x, (float)1 / (n + 1), ref avg);

                    // The no-progress test needs a previous average, so it is skipped on the first update.
                    if ((n > 0 && TerminateTester.ShouldTerminate(in avg, in prev)) || _terminate(in avg))
                    {
                        result = avg;
                        return;
                    }
                }
                else
                {
                    // After the swap, prev holds the old iterate and x is the buffer that receives the
                    // new one, computed from prev and the scaled step.
                    Utils.Swap(ref x, ref prev);
                    VectorUtils.AddMult(in step, -stepSize, ref prev, ref x);
                    // The no-progress test needs a previous iterate, so it is skipped on the first update.
                    if ((n > 0 && TerminateTester.ShouldTerminate(in x, in prev)) || _terminate(in x))
                    {
                        result = x;
                        return;
                    }
                }
            }

            // Step limit reached without the termination criterion firing.
            result = _averaging ? avg : x;
        }
    }
 
    /// <summary>
    /// Deterministic gradient descent with line search
    /// </summary>
    internal class GDOptimizer
    {
        /// <summary>
        /// Line search to use.
        /// </summary>
        public IDiffLineSearch LineSearch { get; set; }

        private int _maxSteps;

        /// <summary>
        /// Gets/Sets maximum number of steps. Set to 0 for no max. Must be non-negative.
        /// </summary>
        public int MaxSteps
        {
            get { return _maxSteps; }
            set
            {
                Contracts.Check(value >= 0);
                _maxSteps = value;
            }
        }

        /// <summary>
        /// Gets/sets termination criterion.
        /// </summary>
        public DTerminate Terminate { get; set; }

        /// <summary>
        /// Gets/sets whether to use nonlinear conjugate gradient.
        /// </summary>
        public bool UseCG { get; set; }

        /// <summary>
        /// Makes a new GDOptimizer with the given optimization parameters
        /// </summary>
        /// <param name="terminate">Termination criterion</param>
        /// <param name="lineSearch">Line search to use. When null, a default is chosen from
        /// <paramref name="useCG"/>: a cubic interpolation line search if it is true, a backtracking
        /// line search otherwise.</param>
        /// <param name="maxSteps">Maximum number of updates (0 for no max); must be non-negative</param>
        /// <param name="useCG">If true, search directions are nonlinear conjugate gradient directions
        /// instead of plain negative gradients; also selects the default line search (see
        /// <paramref name="lineSearch"/>).</param>
        public GDOptimizer(DTerminate terminate, IDiffLineSearch lineSearch = null, bool useCG = false, int maxSteps = 0)
        {
            Terminate = terminate;
            // Test the constructor argument, not the LineSearch property: the property has not been
            // assigned yet at this point, so testing it would always select a default and silently
            // discard the line search supplied by the caller.
            if (lineSearch == null)
            {
                if (useCG)
                    LineSearch = new CubicInterpLineSearch((float)0.01);
                else
                    LineSearch = new BacktrackingLineSearch();
            }
            else
                LineSearch = lineSearch;
            // Go through the property so a negative limit is rejected here; Minimize would otherwise
            // skip its loop entirely and return a point that was never computed.
            MaxSteps = maxSteps;
            UseCG = useCG;
        }

        /// <summary>
        /// Restriction of the objective to the current search line: tracks the current point, its
        /// gradient and the search direction, and evaluates the objective at points along that direction.
        /// </summary>
        private class LineFunc
        {
            private readonly bool _useCG;

            // Current point and the most recently evaluated trial point.
            private VBuffer<float> _point;
            private VBuffer<float> _newPoint;
            // Gradients at _point and at _newPoint respectively.
            private VBuffer<float> _grad;
            private VBuffer<float> _newGrad;
            // Current search direction.
            private VBuffer<float> _dir;

            /// <summary>
            /// The point produced by the most recent call to <see cref="Eval"/>.
            /// </summary>
            public VBuffer<float> NewPoint => _newPoint;

            // Objective values at _point and at _newPoint respectively.
            private float _value;
            private float _newValue;

            /// <summary>
            /// Objective value at the current point.
            /// </summary>
            public float Value => _value;

            private readonly DifferentiableFunction _func;

            /// <summary>
            /// Directional derivative at the current point along the current search direction.
            /// </summary>
            public float Deriv => VectorUtils.DotProduct(in _dir, in _grad);

            /// <summary>
            /// Evaluates <paramref name="function"/> at <paramref name="initial"/> and starts with the
            /// negative gradient as search direction.
            /// </summary>
            /// <param name="function">Function to minimize</param>
            /// <param name="initial">Starting point; it is copied, not referenced</param>
            /// <param name="useCG">If true, <see cref="ChangeDir"/> produces conjugate gradient directions</param>
            public LineFunc(DifferentiableFunction function, in VBuffer<float> initial, bool useCG = false)
            {
                initial.CopyTo(ref _point);
                _func = function;
                // REVIEW: plumb the IProgressChannelProvider through.
                _value = _func(in _point, ref _grad, null);
                VectorUtils.ScaleInto(in _grad, -1, ref _dir);

                _useCG = useCG;
            }

            /// <summary>
            /// Evaluates the objective at the current point moved by <paramref name="step"/> along the
            /// search direction. The trial point, its gradient and its value are remembered for
            /// <see cref="NewPoint"/> and <see cref="ChangeDir"/>.
            /// </summary>
            /// <param name="step">Step length along the search direction</param>
            /// <param name="deriv">Directional derivative at the trial point</param>
            /// <returns>Objective value at the trial point</returns>
            public float Eval(float step, out float deriv)
            {
                VectorUtils.AddMultInto(in _point, step, in _dir, ref _newPoint);
                _newValue = _func(in _newPoint, ref _newGrad, null);
                deriv = VectorUtils.DotProduct(in _dir, in _newGrad);
                return _newValue;
            }

            /// <summary>
            /// Accepts the most recently evaluated trial point as the current point and computes the
            /// next search direction.
            /// </summary>
            public void ChangeDir()
            {
                if (_useCG)
                {
                    // Polak-Ribiere coefficient, clamped at zero so that a negative value falls back
                    // to the plain negative gradient.
                    // NOTE(review): oldByOld is zero when the previous gradient vanishes, which makes
                    // the quotient non-finite - confirm the line search prevents reaching this state.
                    float newByNew = VectorUtils.NormSquared(_newGrad);
                    float newByOld = VectorUtils.DotProduct(in _newGrad, in _grad);
                    float oldByOld = VectorUtils.NormSquared(_grad);
                    float betaPR = (newByNew - newByOld) / oldByOld;
                    float beta = Math.Max(0, betaPR);
                    // New direction: beta times the old direction minus the new gradient.
                    VectorUtils.ScaleBy(ref _dir, beta);
                    VectorUtils.AddMult(in _newGrad, -1, ref _dir);
                }
                else
                    VectorUtils.ScaleInto(in _newGrad, -1, ref _dir);
                _newPoint.CopyTo(ref _point);
                _newGrad.CopyTo(ref _grad);
                _value = _newValue;
            }
        }

        /// <summary>
        /// Finds approximate minimum of the function
        /// </summary>
        /// <remarks>
        /// Stops when the termination criterion returns true, when
        /// <see cref="TerminateTester.ShouldTerminate"/> reports a non-finite or unchanged point
        /// (checked from the second step onwards), or after <see cref="MaxSteps"/> steps when that
        /// limit is non-zero.
        /// </remarks>
        /// <param name="function">Function to minimize</param>
        /// <param name="initial">Initial point; must contain only finite values</param>
        /// <param name="result">Approximate minimum</param>
        public void Minimize(DifferentiableFunction function, in VBuffer<float> initial, ref VBuffer<float> result)
        {
            Contracts.Check(FloatUtils.IsFinite(initial.GetValues()), "The initial vector contains NaNs or infinite values.");
            LineFunc lineFunc = new LineFunc(function, in initial, UseCG);
            VBuffer<float> prev = default(VBuffer<float>);
            initial.CopyTo(ref prev);

            for (int n = 0; _maxSteps == 0 || n < _maxSteps; ++n)
            {
                // The returned step length is not needed: the line search drives lineFunc.Eval, which
                // records the last evaluated trial point in lineFunc.NewPoint.
                // NOTE(review): this relies on the line search evaluating its accepted step last - confirm.
                LineSearch.Minimize(lineFunc.Eval, lineFunc.Value, lineFunc.Deriv);
                var newPoint = lineFunc.NewPoint;
                bool terminateNow = n > 0 && TerminateTester.ShouldTerminate(in newPoint, in prev);
                if (terminateNow || Terminate(in newPoint))
                    break;
                newPoint.CopyTo(ref prev);
                lineFunc.ChangeDir();
            }

            lineFunc.NewPoint.CopyTo(ref result);
        }
    }
 
    /// <summary>
    /// Stops an optimization run when the result has become non-finite or has stopped changing.
    /// </summary>
    internal static class TerminateTester
    {
        /// <summary>
        /// Decides whether the optimization should stop at this iteration. The answer is true when
        /// <paramref name="x"/> holds a non-finite value, or when <paramref name="x"/> and
        /// <paramref name="xprev"/> are element-wise identical. Entries that a sparse vector does not
        /// store explicitly are compared as zero.
        /// </summary>
        /// <param name="x">The current value.</param>
        /// <param name="xprev">The value from the previous iteration; asserted to be finite and of the same length as <paramref name="x"/>.</param>
        /// <returns>True if the optimization routine should terminate at this iteration.</returns>
        internal static bool ShouldTerminate(in VBuffer<float> x, in VBuffer<float> xprev)
        {
            Contracts.Assert(x.Length == xprev.Length, "Vectors must have the same dimensionality.");
            Contracts.Assert(FloatUtils.IsFinite(xprev.GetValues()));

            var curValues = x.GetValues();
            if (!FloatUtils.IsFinite(curValues))
                return true;

            var oldValues = xprev.GetValues();
            bool curDense = x.IsDense;
            bool oldDense = xprev.IsDense;

            if (curDense && oldDense)
            {
                // Dense against dense: plain slot-by-slot comparison.
                int pos = 0;
                while (pos < curValues.Length)
                {
                    if (curValues[pos] != oldValues[pos])
                        return false;
                    pos++;
                }

                return true;
            }

            if (oldDense)
            {
                // Sparse current against dense previous: every previous slot that the current vector
                // does not store must be zero, and stored slots must match.
                var curIndices = x.GetIndices();
                int densePos = 0;
                for (int k = 0; k < curValues.Length; k++)
                {
                    int target = curIndices[k];
                    for (; densePos < target; densePos++)
                    {
                        if (oldValues[densePos] != 0)
                            return false;
                    }

                    Contracts.Assert(target == densePos);
                    if (curValues[k] != oldValues[densePos])
                        return false;
                    densePos++;
                }

                for (; densePos < oldValues.Length; densePos++)
                {
                    if (oldValues[densePos] != 0)
                        return false;
                }

                return true;
            }

            if (curDense)
            {
                // Dense current against sparse previous: mirror image of the case above.
                var oldIndices = xprev.GetIndices();
                int densePos = 0;
                for (int k = 0; k < oldValues.Length; k++)
                {
                    int target = oldIndices[k];
                    for (; densePos < target; densePos++)
                    {
                        if (curValues[densePos] != 0)
                            return false;
                    }

                    Contracts.Assert(target == densePos);
                    if (curValues[densePos] != oldValues[k])
                        return false;
                    densePos++;
                }

                for (; densePos < curValues.Length; densePos++)
                {
                    if (curValues[densePos] != 0)
                        return false;
                }

                return true;
            }

            // Sparse against sparse: walk both index lists in step. A slot stored by only one of the
            // vectors must hold zero there; a slot stored by both must hold equal values.
            var curSlots = x.GetIndices();
            var oldSlots = xprev.GetIndices();
            int c = 0;
            int o = 0;
            while (c < curValues.Length && o < oldValues.Length)
            {
                int curSlot = curSlots[c];
                int oldSlot = oldSlots[o];
                if (curSlot < oldSlot)
                {
                    if (curValues[c] != 0)
                        return false;
                    c++;
                }
                else if (curSlot > oldSlot)
                {
                    if (oldValues[o] != 0)
                        return false;
                    o++;
                }
                else
                {
                    if (curValues[c] != oldValues[o])
                        return false;
                    c++;
                    o++;
                }
            }

            // Whatever remains in either vector has no counterpart in the other and must be zero.
            for (; c < curValues.Length; c++)
            {
                if (curValues[c] != 0)
                    return false;
            }

            for (; o < oldValues.Length; o++)
            {
                if (oldValues[o] != 0)
                    return false;
            }

            return true;
        }
    }
}