// File: Optimizer\OptimizationMonitor.cs
// Web Access
// Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Numeric
{
    /// <summary>
    /// An object which is used to decide whether to stop optimization.
    /// </summary>
    /// <remarks>
    /// Implementations in this file may keep state between calls to <see cref="Terminate"/>
    /// (for example a running average or a queue of past values); <see cref="Reset"/> is the
    /// hook for clearing such state before a new round of optimization.
    /// </remarks>
    internal interface ITerminationCriterion
    {
        /// <summary>
        /// Name appropriate for display to the user.
        /// </summary>
        string FriendlyName { get; }

        /// <summary>
        /// Determines whether to stop optimization.
        /// </summary>
        /// <param name="state">The state of the optimizer.</param>
        /// <param name="message">On return, a message to be printed (or null for no message).</param>
        /// <returns>True iff the criterion is met, i.e. optimization should halt.</returns>
        bool Terminate(Optimizer.OptimizerState state, out string message);

        /// <summary>
        /// Prepares the ITerminationCriterion for a new round of optimization.
        /// </summary>
        void Reset();
    }
 
    /// <summary>
    /// A wrapper for a termination criterion that checks the gradient at a specified interval.
    /// The stop/continue decision is always the wrapped criterion's; this class only appends
    /// the gradient check result to the criterion's message.
    /// </summary>
    internal sealed class GradientCheckingMonitor : ITerminationCriterion
    {
        private const string _checkingMessage = "  Checking gradient...";
        private readonly ITerminationCriterion _termCrit;
        private readonly int _gradCheckInterval;
        // Reusable vectors utilized by the gradient tester.
        private VBuffer<float> _newGrad;
        private VBuffer<float> _newX;

        /// <summary>
        /// Initializes a new instance of the <see cref="GradientCheckingMonitor"/> class.
        /// </summary>
        /// <param name="termCrit">The termination criterion to wrap.</param>
        /// <param name="gradientCheckInterval">
        /// The gradient check interval, in iterations. Must be positive (validated through
        /// <c>Contracts.CheckParam</c>). A value of 1 checks the gradient on every iteration.
        /// </param>
        public GradientCheckingMonitor(ITerminationCriterion termCrit, int gradientCheckInterval)
        {
            Contracts.CheckParam(gradientCheckInterval > 0, nameof(gradientCheckInterval),
                "gradientCheckInterval must be positive.");

            _termCrit = termCrit;
            _gradCheckInterval = gradientCheckInterval;
        }

        /// <summary>
        /// Name appropriate for display to the user; includes the wrapped criterion's name.
        /// </summary>
        public string FriendlyName { get { return "Gradient Checking Monitor wrapping " + _termCrit.FriendlyName; } }

        /// <summary>
        /// Determines whether to stop optimization by delegating to the wrapped criterion, and
        /// appends a gradient check result to the message on check iterations and on termination.
        /// </summary>
        /// <param name="state">the state of the optimizer</param>
        /// <param name="message">
        /// the wrapped criterion's message, with "  GradCheck: ..." appended when a check was run
        /// </param>
        /// <returns>
        /// true iff the wrapped criterion is met, i.e. optimization should halt
        /// </returns>
        public bool Terminate(Optimizer.OptimizerState state, out string message)
        {
            bool terminate = _termCrit.Terminate(state, out message);

            // For interval k the check runs on iterations 1, 1 + k, 1 + 2k, ...
            // Comparing against (1 % k) rather than the literal 1 keeps that schedule for
            // k > 1 and makes k == 1 check on every iteration: Iter % 1 is always 0, so a
            // comparison with 1 could never succeed for an interval of 1.
            bool isCheckIteration = state.Iter % _gradCheckInterval == 1 % _gradCheckInterval;

            if (terminate || isCheckIteration)
                message += string.Format("  GradCheck: {0,0:0.0000e0}", Check(state));

            return terminate;
        }

        /// <summary>
        /// Runs <c>GradientTester.Test</c> at the optimizer's current point along its last
        /// direction, showing a transient progress message on standard error while it runs.
        /// </summary>
        /// <param name="state">the state of the optimizer</param>
        /// <returns>the value returned by the gradient tester</returns>
        private float Check(Optimizer.OptimizerState state)
        {
            Console.Error.Write(_checkingMessage);
            Console.Error.Flush();

            // Locals are required because the tester takes these arguments by reference.
            VBuffer<float> x = state.X;
            var lastDir = state.LastDir;
            float checkResult = GradientTester.Test(state.Function, in x, ref lastDir, true, ref _newGrad, ref _newX);

            // Emit one backspace per character of the progress message; presumably this lets
            // subsequent console output overwrite it.
            for (int i = 0; i < _checkingMessage.Length; i++)
                Console.Error.Write('\b');

            return checkResult;
        }

        /// <summary>
        /// Prepares the wrapped ITerminationCriterion for a new round of optimization
        /// </summary>
        public void Reset()
        {
            _termCrit.Reset();
        }
    }
 
    /// <summary>
    /// Abstract base for termination criteria that do not need resetting:
    /// it supplies a <see cref="Reset"/> that does nothing and leaves the rest to subclasses.
    /// </summary>
    internal abstract class StaticTerminationCriterion : ITerminationCriterion
    {
        /// <summary>
        /// Name appropriate for display to the user.
        /// </summary>
        public abstract string FriendlyName { get; }

        /// <summary>
        /// Decides whether optimization should stop.
        /// </summary>
        /// <param name="state">the state of the optimizer</param>
        /// <param name="message">a message to be printed (or null for no message)</param>
        /// <returns>
        /// true iff criterion is met, i.e. optimization should halt
        /// </returns>
        public abstract bool Terminate(Optimizer.OptimizerState state, out string message);

        /// <summary>
        /// Prepares the ITerminationCriterion for a new round of optimization.
        /// This implementation performs no work.
        /// </summary>
        public void Reset()
        {
            // Intentionally empty.
        }
    }
 
    /// <summary>
    /// Terminates when the geometrically-weighted average improvement falls below the tolerance,
    /// or when the iteration cap is reached.
    /// </summary>
    internal sealed class MeanImprovementCriterion : ITerminationCriterion
    {
        private readonly float _tol;
        private readonly float _lambda;
        private readonly int _maxIterations;
        // Running sum: improvement_t + lambda * improvement_(t-1) + lambda^2 * improvement_(t-2) + ...
        private float _unnormMeanImprovement;

        /// <summary>
        /// Initializes a new instance of the <see cref="MeanImprovementCriterion"/> class.
        /// </summary>
        /// <param name="tol">The tolerance parameter</param>
        /// <param name="lambda">The geometric weighting factor. Higher means more heavily weighted toward older values.</param>
        /// <param name="maxIterations">Maximum number of iterations</param>
        public MeanImprovementCriterion(float tol = (float)1e-4, float lambda = (float)0.5, int maxIterations = int.MaxValue)
        {
            _tol = tol;
            _lambda = lambda;
            _maxIterations = maxIterations;
        }

        /// <summary>
        /// When criterion drops below this value, optimization is terminated
        /// </summary>
        public float Tolerance
        {
            get { return _tol; }
        }

        /// <summary>
        /// Name appropriate for display to the user.
        /// </summary>
        public string FriendlyName { get { return "Mean Improvement"; } }

        /// <summary>
        /// Updates the running weighted improvement with the latest step and determines
        /// whether to stop optimization.
        /// </summary>
        /// <param name="state">the state of the optimizer</param>
        /// <param name="message">the current value of the criterion, formatted for display</param>
        /// <returns>
        /// true iff the criterion is below the tolerance or <c>state.Iter</c> has reached the
        /// maximum number of iterations, i.e. optimization should halt
        /// </returns>
        public bool Terminate(Optimizer.OptimizerState state, out string message)
        {
            _unnormMeanImprovement = (state.LastValue - state.Value) + _lambda * _unnormMeanImprovement;

            // (1 - lambda^Iter) / (1 - lambda) is the sum of the weights 1, lambda, ..., lambda^(Iter-1),
            // so dividing by it turns the running sum into a weighted mean.
            // NOTE(review): assumes state.Iter >= 1 here; at Iter == 0 the denominator would be
            // zero if MathUtils.Pow behaves like a standard power function - confirm with caller.
            float crit = _unnormMeanImprovement * (1 - _lambda) / (1 - MathUtils.Pow(_lambda, state.Iter));
            message = string.Format("{0:0.000e0}", crit);
            return (crit < _tol || state.Iter >= _maxIterations);
        }

        /// <summary>
        /// Prepares the ITerminationCriterion for a new round of optimization by clearing the
        /// running weighted improvement.
        /// </summary>
        public void Reset()
        {
            _unnormMeanImprovement = 0;
        }
    }
 
    /// <summary>
    /// Halts optimization once the mean improvement of the objective across the most recent
    /// n iterations, divided by the absolute function value, becomes small enough.
    /// </summary>
    /// <remarks>
    /// Inappropriate for functions whose optimal value is non-positive, because of normalization.
    /// While fewer than n values have been recorded, <see cref="Terminate"/> returns false
    /// without consulting the iteration cap.
    /// </remarks>
    internal sealed class MeanRelativeImprovementCriterion : ITerminationCriterion
    {
        private readonly int _n;
        private readonly float _tol;
        private readonly int _maxIterations;
        // Sliding window of the most recent objective values, oldest first.
        private readonly Queue<float> _pastValues;

        /// <summary>
        /// Threshold for the criterion; optimization is terminated once it drops below this value
        /// </summary>
        public float Tolerance => _tol;

        /// <summary>
        /// Number of previous iterations to store
        /// </summary>
        public int Iters => _n;

        /// <summary>
        /// Create a MeanRelativeImprovementCriterion
        /// </summary>
        /// <param name="tol">tolerance level</param>
        /// <param name="n">number of past iterations to average over</param>
        /// <param name="maxIterations">Maximum number of iterations</param>
        public MeanRelativeImprovementCriterion(float tol = (float)1e-4, int n = 5, int maxIterations = int.MaxValue)
        {
            _tol = tol;
            _n = n;
            _maxIterations = maxIterations;
            _pastValues = new Queue<float>(n);
        }

        /// <summary>
        /// Name appropriate for display to the user; same text as <see cref="ToString"/>.
        /// </summary>
        public string FriendlyName => ToString();

        /// <summary>
        /// Returns true if the average objective improvement over the last
        /// n iterations, normalized by the function value, is less than the tolerance,
        /// or if the iteration cap has been reached once the window is full.
        /// </summary>
        /// <param name="state">current state of the optimizer</param>
        /// <param name="message">the current value of the criterion, or "..." while the window is filling</param>
        /// <returns>true if optimization should halt</returns>
        public bool Terminate(Optimizer.OptimizerState state, out string message)
        {
            float current = state.Value;

            // Warm-up: keep collecting values until the window holds n of them.
            bool windowFull = _pastValues.Count >= _n;
            if (!windowFull)
            {
                _pastValues.Enqueue(current);
                message = "...";
                return false;
            }

            // Slide the window: drop the oldest value, record the newest.
            float oldest = _pastValues.Dequeue();
            _pastValues.Enqueue(current);

            float meanImprovement = (oldest - current) / _n;
            float relative = meanImprovement / Math.Abs(current);
            message = string.Format("{0,0:0.0000e0}", relative);

            if (relative < _tol)
                return true;

            return state.Iter >= _maxIterations;
        }

        /// <summary>
        /// String summary of criterion
        /// </summary>
        /// <returns>summary of criterion, including window size and tolerance</returns>
        public override string ToString()
        {
            const string template = "Mean rel impr over {0} iter'ns < tol: {1,0:0.000e0}";
            return string.Format(template, _n, _tol);
        }

        /// <summary>
        /// Prepares the ITerminationCriterion for a new round of optimization by discarding
        /// all recorded values.
        /// </summary>
        public void Reset()
        {
            _pastValues.Clear();
        }
    }
 
    /// <summary>
    /// Uses the gradient to determine an upper bound on (relative) distance from the optimum.
    /// </summary>
    /// <remarks>
    /// Works if the objective uses L2 prior (or in general if the hessian H is such
    /// that H > (1 / sigmaSq) * I at all points)
    /// Inappropriate for functions whose optimal value is non-positive, because of normalization.
    /// The best bound seen so far is kept in a field that is only initialized in the
    /// constructor; the inherited <c>Reset</c> does not touch it.
    /// </remarks>
    internal sealed class UpperBoundOnDistanceWithL2 : StaticTerminationCriterion
    {
        private readonly float _sigmaSq;
        private readonly float _tol;
        // Largest lower bound on the minimum observed across all calls to Terminate.
        private float _bestBoundOnMin;

        /// <summary>
        /// Threshold for the criterion; optimization is terminated once it drops below this value
        /// </summary>
        public float Tolerance => _tol;

        /// <summary>
        /// Create termination criterion with supplied value of sigmaSq and tolerance
        /// </summary>
        /// <param name="sigmaSq">value of sigmaSq in L2 regularizer</param>
        /// <param name="tol">tolerance level</param>
        public UpperBoundOnDistanceWithL2(float sigmaSq = 1, float tol = (float)1e-2)
        {
            _sigmaSq = sigmaSq;
            _tol = tol;

            // REVIEW: Why shouldn't this be "Reset"?
            _bestBoundOnMin = float.NegativeInfinity;
        }

        /// <summary>
        /// Name appropriate for display to the user; same text as <see cref="ToString"/>.
        /// </summary>
        public override string FriendlyName => ToString();

        /// <summary>
        /// Returns true if the proved bound on the distance from the optimum,
        /// normalized by the function value, is less than the tolerance
        /// </summary>
        /// <param name="state">current state of the optimizer</param>
        /// <param name="message">value of criterion</param>
        /// <returns>true if criterion is less than tolerance</returns>
        public override bool Terminate(Optimizer.OptimizerState state, out string message)
        {
            var grad = state.Grad;
            float gradNormSq = VectorUtils.NormSquared(grad);
            float current = state.Value;

            // Candidate lower bound on the minimum derived from the current point.
            float candidateBound = current - 0.5f * _sigmaSq * gradNormSq;
            _bestBoundOnMin = Math.Max(_bestBoundOnMin, candidateBound);

            float relativeGap = (current - _bestBoundOnMin) / Math.Abs(current);
            message = string.Format("{0,0:0.0000e0}", relativeGap);
            return relativeGap < _tol;
        }

        /// <summary>
        /// String summary of criterion
        /// </summary>
        /// <returns>summary of criterion, including sigmaSq and tolerance</returns>
        public override string ToString()
        {
            const string template = "UB rel dist from opt, σ² = {0,0:0.00e0}, tol = {1,0:0.00e0}";
            return string.Format(template, _sigmaSq, _tol);
        }
    }
 
    /// <summary>
    /// Criterion that halts once the gradient norm, relative to the absolute function value,
    /// is small enough.
    /// </summary>
    /// <remarks>
    /// Inappropriate for functions whose optimal value is non-positive, because of normalization
    /// </remarks>
    internal sealed class RelativeNormGradient : StaticTerminationCriterion
    {
        private readonly float _tol;

        /// <summary>
        /// Threshold for the criterion; optimization is terminated once it drops below this value
        /// </summary>
        public float Tolerance => _tol;

        /// <summary>
        /// Create a RelativeNormGradient with the supplied tolerance
        /// </summary>
        /// <param name="tol">tolerance level</param>
        public RelativeNormGradient(float tol = (float)1e-4)
        {
            _tol = tol;
        }

        /// <summary>
        /// Name appropriate for display to the user; same text as <see cref="ToString"/>.
        /// </summary>
        public override string FriendlyName => ToString();

        /// <summary>
        /// Returns true if the norm of the gradient, divided by the absolute value of the
        /// function, is less than the tolerance.
        /// </summary>
        /// <param name="state">current state of the optimizer</param>
        /// <param name="message">the current value of the criterion</param>
        /// <returns>true iff criterion is less than the tolerance</returns>
        public override bool Terminate(Optimizer.OptimizerState state, out string message)
        {
            var gradient = state.Grad;
            float gradNorm = VectorUtils.Norm(gradient);
            float relativeNorm = gradNorm / Math.Abs(state.Value);

            message = string.Format("{0,0:0.0000e0}", relativeNorm);
            return relativeNorm < _tol;
        }

        /// <summary>
        /// String summary of criterion
        /// </summary>
        /// <returns>summary of criterion, including the tolerance</returns>
        public override string ToString()
        {
            const string template = "Norm of grad / value < {0,0:0.00e0}";
            return string.Format(template, _tol);
        }
    }
}