File: Standard\Online\AveragedLinear.cs
Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.SearchSpace;
 
// TODO: Check if it works properly if Averaged is set to false
 
namespace Microsoft.ML.Trainers
{
    /// <summary>
    /// Arguments class for averaged linear trainers.
    /// </summary>
    public abstract class AveragedLinearOptions : OnlineLinearOptions
    {
        /// <summary>
        /// The initial learning rate applied to each update. When <see cref="DecreaseLearningRate"/> is
        /// <see langword="true" />, this rate decays as training progresses.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", SortOrder = 50)]
        [TGUI(Label = "Learning rate", SuggestedSweeps = "0.01,0.1,0.5,1.0")]
        [TlcModule.SweepableDiscreteParam("LearningRate", new object[] { 0.01, 0.1, 0.5, 1.0 })]
        [Range(1e-4f, 1f, 1f, true)]
        public float LearningRate = AveragedDefault.LearningRate;
 
        /// <summary>
        /// Determines whether to decrease the <see cref="LearningRate"/> as training progresses.
        /// </summary>
        /// <value>
        /// <see langword="true" /> to decrease the <see cref="LearningRate"/> as iterations progress; otherwise, <see langword="false" />.
        /// Default is <see langword="false" />. When enabled, the learning rate applied to each weight update is divided by the square root of the number of updates performed so far.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Decrease learning rate", ShortName = "decreaselr", SortOrder = 50)]
        [TGUI(Label = "Decrease Learning Rate", Description = "Decrease learning rate as iterations progress")]
        [TlcModule.SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true })]
        [BooleanChoice]
        public bool DecreaseLearningRate = AveragedDefault.DecreaseLearningRate;
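
        // A minimal sketch of the decay schedule applied in ProcessDataInstance below when this flag
        // is set; "numUpdates" is a hypothetical stand-in for the trainer's count of updates so far
        // (including deferred no-update steps):
        //
        //     float rate = LearningRate;
        //     if (DecreaseLearningRate)
        //         rate /= MathUtils.Sqrt(numUpdates + 1); // decays as 1/sqrt(t)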
 
        /// <summary>
        /// Number of examples after which weights will be reset to the current average.
        /// </summary>
        /// <value>
        /// Default is <see langword="null" />, which disables this feature.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Number of examples after which weights will be reset to the current average", ShortName = "numreset")]
        public long? ResetWeightsAfterXExamples = null;
 
        /// <summary>
        /// Determines when to update averaged weights.
        /// </summary>
        /// <value>
        /// <see langword="true" /> to update averaged weights only when loss is nonzero.
        /// <see langword="false" /> to update averaged weights on every example.
        /// Default is <see langword="true" />.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Instead of updating averaged weights on every example, only update when loss is nonzero", ShortName = "lazy,DoLazyUpdates")]
        public bool LazyUpdate = true;
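
        // A rough sketch of the lazy-update bookkeeping done in AveragedTrainStateBase below: zero-loss
        // examples only bump a counter, and the deferred contributions are folded into the running
        // totals in one step before the next real update. Lower-case names are illustrative stand-ins
        // for the corresponding fields:
        //
        //     numNoUpdates++;                          // loss == 0: defer this example's contribution
        //     ...
        //     totalWeights += weights * numNoUpdates;  // fold all deferred contributions in at once
        //     totalBias += bias * numNoUpdates;
        //     numWeightUpdates += numNoUpdates;
        //     numNoUpdates = 0;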
 
        /// <summary>
        /// The L2 weight for <a href='https://en.wikipedia.org/wiki/Regularization_(mathematics)'>regularization</a>.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization Weight", ShortName = "reg,L2RegularizerWeight", SortOrder = 50)]
        [TGUI(Label = "L2 Regularization Weight")]
        [TlcModule.SweepableFloatParam("L2RegularizerWeight", 0.0f, 0.4f)]
        [Range(0f, 32768f, 0f, false)]
        public float L2Regularization = AveragedDefault.L2Regularization;
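
        // Each update shrinks the weight vector by a constant factor (see ProcessDataInstance below),
        // which is why the constructor restricts this value to [0, 0.5): at 0.5 the factor is zero and
        // every weight would collapse to zero. Illustrative form:
        //
        //     weightsScale *= 1 - 2 * L2Regularization;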
 
        /// <summary>
        /// Extra weight given to more recent updates.
        /// </summary>
        /// <value>
        /// Default is 0, i.e. no extra gain.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Extra weight given to more recent updates", ShortName = "rg")]
        public float RecencyGain = 0;
 
        /// <summary>
        /// Determines whether <see cref="RecencyGain"/> is multiplicative or additive.
        /// </summary>
        /// <value>
        /// <see langword="true" /> means <see cref="RecencyGain"/> is multiplicative.
        /// <see langword="false" /> means <see cref="RecencyGain"/> is additive.
        /// Default is <see langword="false" />.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Whether Recency Gain is multiplicative (vs. additive)", ShortName = "rgm,RecencyGainMulti")]
        public bool RecencyGainMultiplicative = false;
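
        // Sketch of the gain recurrence applied in IncrementAverageNonLazy below; the gain starts at 1
        // and weights each example's contribution to the running totals:
        //
        //     gain = RecencyGainMultiplicative
        //         ? gain * RecencyGain   // multiplicative: grows geometrically
        //         : gain + RecencyGain;  // additive: grows linearly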
 
        /// <summary>
        /// Determines whether to compute and use averaged weights.
        /// </summary>
        /// <value>
        /// <see langword="true" /> to do averaging; otherwise, <see langword="false" />.
        /// Default is <see langword="true" />.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Do averaging?", ShortName = "avg")]
        public bool Averaged = true;
 
        /// <summary>
        /// The inexactness tolerance for averaging.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The inexactness tolerance for averaging", ShortName = "avgtol")]
        internal float AveragedTolerance = 1e-2f;
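
        // Pending lazy updates are folded into the running totals once they become large enough
        // relative to what has already been accumulated (see ProcessDataInstance below):
        //
        //     TotalMultipliers * AveragedTolerance <= PendingMultipliers
        //
        // A smaller tolerance therefore folds pending updates in more eagerly, making the average
        // more exact at the cost of more vector additions.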
 
        [BestFriend]
        internal class AveragedDefault : OnlineLinearOptions.OnlineDefault
        {
            public const float LearningRate = 1;
            public const bool DecreaseLearningRate = false;
            public const float L2Regularization = 0;
        }
 
        internal abstract IComponentFactory<IScalarLoss> LossFunctionFactory { get; }
    }
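
    // A minimal usage sketch of these options, assuming the derived AveragedPerceptronTrainer.Options
    // type defined elsewhere in this assembly; the specific values are illustrative only:
    //
    //     var options = new AveragedPerceptronTrainer.Options
    //     {
    //         LearningRate = 0.1f,
    //         DecreaseLearningRate = true,
    //         L2Regularization = 0.01f,
    //     };
    //     var trainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);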
 
    /// <summary>
    /// Base class for averaged linear trainers.
    /// </summary>
    public abstract class AveragedLinearTrainer<TTransformer, TModel> : OnlineLinearTrainer<TTransformer, TModel>
        where TTransformer : ISingleFeaturePredictionTransformer<TModel>
        where TModel : class
    {
        private protected readonly AveragedLinearOptions AveragedLinearTrainerOptions;
        private protected IScalarLoss LossFunction;
 
        private protected abstract class AveragedTrainStateBase : TrainStateBase
        {
            protected float Gain;
 
            protected int NumNoUpdates;
 
            // For computing averaged weights and bias (if needed)
            protected VBuffer<float> TotalWeights;
            protected float TotalBias;
            protected double NumWeightUpdates;
 
            // The accumulated sum of absolute update multipliers (learning rate times the loss
            // derivative) for all updates already folded into the totalled model, versus those
            // pending in the weight vector that have not yet been added to the total model.
            protected double TotalMultipliers;
            protected double PendingMultipliers;
 
            protected readonly bool Averaged;
            private readonly long _resetWeightsAfterXExamples;
            private readonly AveragedLinearOptions _args;
            private readonly IScalarLoss _loss;
 
            private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedLinearTrainer<TTransformer, TModel> parent)
                : base(ch, numFeatures, predictor, parent)
            {
                // Initialize the averaged weights if averaging is enabled, mirroring what would
                // happen if these options had been applied through their setters.
                Averaged = parent.AveragedLinearTrainerOptions.Averaged;
                if (Averaged)
                {
                    if (parent.AveragedLinearTrainerOptions.AveragedTolerance > 0)
                        VBufferUtils.Densify(ref Weights);
                    Weights.CopyTo(ref TotalWeights);
                }
                else
                {
                    // It is definitely advantageous to keep weights dense if we aren't adding them
                    // to another vector with each update.
                    VBufferUtils.Densify(ref Weights);
                }
                _resetWeightsAfterXExamples = parent.AveragedLinearTrainerOptions.ResetWeightsAfterXExamples ?? 0;
                _args = parent.AveragedLinearTrainerOptions;
                _loss = parent.LossFunction;
 
                Gain = 1;
            }
 
            /// <summary>
            /// Return the raw margin from the decision hyperplane, computed with the averaged weights and bias.
            /// </summary>
            public float AveragedMargin(in VBuffer<float> feat)
            {
                Contracts.Assert(Averaged);
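                // The total weights and bias accumulate one (possibly gain-weighted) copy of the model
                // per update, so dividing by NumWeightUpdates yields the margin of the averaged model.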
                return (TotalBias + VectorUtils.DotProduct(in feat, in TotalWeights)) / (float)NumWeightUpdates;
            }
 
            public override float Margin(in VBuffer<float> feat)
                => Averaged ? AveragedMargin(in feat) : CurrentMargin(in feat);
 
            public override void FinishIteration(IChannel ch)
            {
                // Finalize things
                if (Averaged)
                {
                    if (_args.LazyUpdate && NumNoUpdates > 0)
                    {
                        // Update the total weights to include the final loss=0 updates
                        VectorUtils.AddMult(in Weights, NumNoUpdates * WeightsScale, ref TotalWeights);
                        TotalBias += Bias * NumNoUpdates;
                        NumWeightUpdates += NumNoUpdates;
                        NumNoUpdates = 0;
                        TotalMultipliers += PendingMultipliers;
                        PendingMultipliers = 0;
                    }
 
                    // Reset the weights to the average if ResetWeightsAfterXExamples is exactly 0.
                    // Note: the option is nullable, null never compares equal to 0, and validation
                    // requires any explicit value to be positive, so this branch is effectively dead.
                    if (_args.ResetWeightsAfterXExamples == 0)
                    {
                        ch.Info("Resetting weights to average weights");
                        VectorUtils.ScaleInto(in TotalWeights, 1 / (float)NumWeightUpdates, ref Weights);
                        WeightsScale = 1;
                        Bias = TotalBias / (float)NumWeightUpdates;
                    }
                }
 
                base.FinishIteration(ch);
            }
 
            public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, float label, float weight)
            {
                base.ProcessDataInstance(ch, in feat, label, weight);
 
                // Compute the loss and apply an update if needed.
                float output = CurrentMargin(in feat);
                double loss = _loss.Loss(output, label);
 
                // REVIEW: Should this be biasUpdate != 0?
                // This loss does not incorporate L2 if present, but the chance of that addition to the loss
                // exactly cancelling out loss is remote.
                if (loss != 0 || _args.L2Regularization > 0)
                {
                    // If doing lazy updates, we need to fold the pending contributions into TotalWeights and TotalBias before updating the weights and bias.
                    if (_args.LazyUpdate && _args.Averaged && NumNoUpdates > 0 && TotalMultipliers * _args.AveragedTolerance <= PendingMultipliers)
                    {
                        VectorUtils.AddMult(in Weights, NumNoUpdates * WeightsScale, ref TotalWeights);
                        TotalBias += Bias * NumNoUpdates; // Bias is unaffected by WeightsScale; matches FinishIteration above.
                        NumWeightUpdates += NumNoUpdates;
                        NumNoUpdates = 0;
                        TotalMultipliers += PendingMultipliers;
                        PendingMultipliers = 0;
                    }
 
                    // Make final adjustments to update parameters.
                    float rate = _args.LearningRate;
                    if (_args.DecreaseLearningRate)
                        rate /= MathUtils.Sqrt((float)NumWeightUpdates + NumNoUpdates + 1);
                    float biasUpdate = -rate * _loss.Derivative(output, label);
 
                    // Perform the update to weights and bias.
                    VectorUtils.AddMult(in feat, biasUpdate / WeightsScale, ref Weights);
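                    // Shrinking WeightsScale instead of every weight keeps the L2 step O(1) per example;
                    // since the effective weight vector is Weights * WeightsScale, this is equivalent to
                    // scaling each weight by (1 - 2 * L2Regularization).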
                    WeightsScale *= 1 - 2 * _args.L2Regularization; // L2 regularization.
                    ScaleWeightsIfNeeded();
                    Bias += biasUpdate;
                    PendingMultipliers += Math.Abs(biasUpdate);
                }
 
                // Add to averaged weights and increment the count.
                if (Averaged)
                {
                    if (!_args.LazyUpdate)
                        IncrementAverageNonLazy();
                    else
                        NumNoUpdates++;
 
                    // Reset the weights to averages if needed.
                    if (_resetWeightsAfterXExamples > 0 && NumIterExamples % _resetWeightsAfterXExamples == 0)
                    {
                        ch.Info("Resetting weights to average weights");
                        VectorUtils.ScaleInto(in TotalWeights, 1 / (float)NumWeightUpdates, ref Weights);
                        WeightsScale = 1;
                        Bias = TotalBias / (float)NumWeightUpdates;
                    }
                }
            }
 
            /// <summary>
            /// Add current weights and bias to average weights/bias.
            /// </summary>
            private void IncrementAverageNonLazy()
            {
                if (_args.RecencyGain == 0)
                {
                    VectorUtils.AddMult(in Weights, WeightsScale, ref TotalWeights);
                    TotalBias += Bias;
                    NumWeightUpdates++;
                    return;
                }
                VectorUtils.AddMult(in Weights, Gain * WeightsScale, ref TotalWeights);
                TotalBias += Gain * Bias;
                NumWeightUpdates += Gain;
                Gain = (_args.RecencyGainMultiplicative ? Gain * _args.RecencyGain : Gain + _args.RecencyGain);
 
                // If the gain has grown too large, rescale. Scaling Gain, TotalBias, TotalWeights, and
                // NumWeightUpdates by the same factor leaves the averaged weights and margin unchanged.
                if (Gain > 1000)
                {
                    const float scale = (float)1e-6;
                    Gain *= scale;
                    TotalBias *= scale;
                    VectorUtils.ScaleBy(ref TotalWeights, scale);
                    NumWeightUpdates *= scale;
                }
            }
        }
 
        private protected AveragedLinearTrainer(AveragedLinearOptions options, IHostEnvironment env, string name, SchemaShape.Column label)
            : base(options, env, name, label)
        {
            Contracts.CheckUserArg(options.LearningRate > 0, nameof(options.LearningRate), UserErrorPositive);
            Contracts.CheckUserArg(!options.ResetWeightsAfterXExamples.HasValue || options.ResetWeightsAfterXExamples > 0, nameof(options.ResetWeightsAfterXExamples), UserErrorPositive);
 
            // Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
            Contracts.CheckUserArg(0 <= options.L2Regularization && options.L2Regularization < 0.5, nameof(options.L2Regularization), "must be in range [0, 0.5)");
            Contracts.CheckUserArg(options.RecencyGain >= 0, nameof(options.RecencyGain), UserErrorNonNegative);
            Contracts.CheckUserArg(options.AveragedTolerance >= 0, nameof(options.AveragedTolerance), UserErrorNonNegative);
            // Verify user didn't specify parameters that conflict
            Contracts.Check(!options.LazyUpdate || (!options.RecencyGainMultiplicative && options.RecencyGain == 0), "Cannot have both recency gain and lazy updates.");
 
            AveragedLinearTrainerOptions = options;
        }
    }
}