File: Standard\Online\OnlineLinear.cs
Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Globalization;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.SearchSpace;
 
namespace Microsoft.ML.Trainers
{
    /// <summary>
    /// Options class for online linear trainers.
    /// </summary>
    public abstract class OnlineLinearOptions : TrainerInputBaseWithLabel
    {
        /// <summary>
        /// Number of passes through the training dataset.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter,numIterations", SortOrder = 50)]
        [TGUI(Label = "Number of Iterations", Description = "Number of training iterations through data", SuggestedSweeps = "1,10,100")]
        [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize: 10, isLogScale: true)]
        [Range(1, 512, 1, true)]
        public int NumberOfIterations = OnlineDefault.NumberOfIterations;
 
        /// <summary>
        /// Initial weights and bias, comma-separated.
        /// </summary>
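        /// <remarks>
        /// The string must contain one value per feature followed by the bias; for example, for three
        /// features: "0.1,-0.2,0.3,0" (three weights, then the intercept).
        /// </remarks>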
        [Argument(ArgumentType.AtMostOnce, HelpText = "Initial Weights and bias, comma-separated", ShortName = "initweights")]
        [TGUI(NoSweep = true)]
        internal string InitialWeights;
 
        /// <summary>
        /// Initial weights and bias scale.
        /// </summary>
        /// <value>
        /// This property is only used if the provided value is positive and <see cref="InitialWeights"/> is not specified.
        /// The weights and bias are drawn uniformly at random from the interval <see cref="InitialWeightsDiameter"/> * [-0.5,0.5];
        /// for example, a diameter of 0.2 draws each value from [-0.1,0.1].
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Init weights diameter", ShortName = "initwts,initWtsDiameter", SortOrder = 140)]
        [TGUI(Label = "Initial Weights Scale", SuggestedSweeps = "0,0.1,0.5,1")]
        [TlcModule.SweepableFloatParamAttribute("InitWtsDiameter", 0.0f, 1.0f, numSteps: 5)]
        [Range(0f, 1f, 0f, false)]
        public float InitialWeightsDiameter = 0;
 
        /// <summary>
        /// Determines whether to shuffle data for each training iteration.
        /// </summary>
        /// <value>
        /// <see langword="true" /> to shuffle data for each training iteration; otherwise, <see langword="false" />.
        /// Default is <see langword="true" />.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to shuffle for each training iteration", ShortName = "shuf")]
        [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[] { false, true })]
        public bool Shuffle = true;
 
        [BestFriend]
        internal class OnlineDefault
        {
            public const int NumberOfIterations = 1;
        }
    }
 
    /// <summary>
    /// Base class for online linear trainers. Online trainers can be updated incrementally with additional data.
    /// </summary>
    public abstract class OnlineLinearTrainer<TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
        where TTransformer : ISingleFeaturePredictionTransformer<TModel>
        where TModel : class
    {
        private protected readonly OnlineLinearOptions OnlineLinearTrainerOptions;
        private protected readonly string Name;
 
        /// <summary>
        /// An object to hold the mutable state for the online linear trainers. Specific algorithms should subclass
        /// this, and return the instance via <see cref="MakeState(IChannel, int, LinearModelParameters)"/>.
        /// </summary>
        private protected abstract class TrainStateBase
        {
            // Current iteration state.
 
            /// <summary>
            /// The number of iterations. Incremented by <see cref="BeginIteration(IChannel)"/>.
            /// </summary>
            public int Iteration;
 
            /// <summary>
            /// The number of examples in the current iteration. Incremented by <see cref="ProcessDataInstance(IChannel, in VBuffer{float}, float, float)"/>,
            /// and reset by <see cref="BeginIteration(IChannel)"/>.
            /// </summary>
            public long NumIterExamples;
 
            // Current weights and bias. The weights vector is considered to be scaled by
            // weightsScale. Storing this separately allows us to avoid the overhead of
            // an explicit scaling, which many learning algorithms will attempt to do on
            // each update. Bias is not subject to the weights scale.
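            // For example, a multiplicative update of the form w <- (1 - eta * lambda) * w
            // can be applied in O(1) as WeightsScale *= (1 - eta * lambda); only the
            // (typically sparse) additive part of an update has to touch individual
            // entries of Weights.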
 
            /// <summary>
            /// Current weights. The weights vector is considered to be scaled by <see cref="WeightsScale"/>. Storing this separately
            /// allows us to avoid the overhead of an explicit scaling, which some algorithms will attempt to do on each example's update.
            /// </summary>
            public VBuffer<float> Weights;
 
            /// <summary>
            /// The implicit scaling factor for <see cref="Weights"/>. Note that this does not affect <see cref="Bias"/>.
            /// </summary>
            public float WeightsScale;
 
            /// <summary>
            /// The intercept term.
            /// </summary>
            public float Bias;
 
            protected readonly IHost ParentHost;
 
            protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, OnlineLinearTrainer<TTransformer, TModel> parent)
            {
                Contracts.CheckValue(ch, nameof(ch));
                ch.Check(numFeatures > 0, "Cannot train with zero features!");
                ch.AssertValueOrNull(predictor);
                ch.AssertValue(parent);
                ch.Assert(Iteration == 0);
                ch.Assert(Bias == 0);
 
                ParentHost = parent.Host;
 
                ch.Trace("{0} Initializing {1} on {2} features", DateTime.UtcNow, parent.Name, numFeatures);
 
                // We want a dense vector, to prevent memory creation during training
                // unless we have a lot of features.
                if (predictor != null)
                {
                    ((IHaveFeatureWeights)predictor).GetFeatureWeights(ref Weights);
                    VBufferUtils.Densify(ref Weights);
                    Bias = predictor.Bias;
                }
                else if (!string.IsNullOrWhiteSpace(parent.OnlineLinearTrainerOptions.InitialWeights))
                {
                    ch.Info("Initializing weights and bias to " + parent.OnlineLinearTrainerOptions.InitialWeights);
                    string[] weightStr = parent.OnlineLinearTrainerOptions.InitialWeights.Split(',');
                    if (weightStr.Length != numFeatures + 1)
                    {
                        throw ch.Except(
                            "Could not initialize weights from 'initialWeights': expecting {0} values to initialize {1} weights and the intercept",
                            numFeatures + 1, numFeatures);
                    }
 
                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                        weightValues[i] = float.Parse(weightStr[i], CultureInfo.InvariantCulture);
                    Weights = new VBuffer<float>(numFeatures, weightValues);
                    Bias = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture);
                }
                else if (parent.OnlineLinearTrainerOptions.InitialWeightsDiameter > 0)
                {
                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                        weightValues[i] = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - 0.5f);
                    Weights = new VBuffer<float>(numFeatures, weightValues);
                    Bias = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - 0.5f);
                }
                else if (numFeatures <= 1000)
                    Weights = VBufferUtils.CreateDense<float>(numFeatures);
                else
                    Weights = VBufferUtils.CreateEmpty<float>(numFeatures);
                WeightsScale = 1;
            }
 
            /// <summary>
            /// Propagates the <see cref="WeightsScale"/> to the <see cref="Weights"/> vector.
            /// </summary>
            private void ScaleWeights()
            {
                if (WeightsScale != 1)
                {
                    VectorUtils.ScaleBy(ref Weights, WeightsScale);
                    WeightsScale = 1;
                }
            }
 
            /// <summary>
            /// Conditionally propagates the <see cref="WeightsScale"/> to the <see cref="Weights"/> vector
            /// when it reaches a scale where additions to weights would start dropping too much precision.
            /// ("Too much" is mostly empirically defined.)
            /// </summary>
            public void ScaleWeightsIfNeeded()
            {
                float absWeightsScale = Math.Abs(WeightsScale);
                if (absWeightsScale < _minWeightScale || absWeightsScale > _maxWeightScale)
                    ScaleWeights();
            }
 
            /// <summary>
            /// Called by <see cref="TrainCore(IChannel, RoleMappedData, TrainStateBase)"/> at the start of a pass over the dataset.
            /// </summary>
            public virtual void BeginIteration(IChannel ch)
            {
                Iteration++;
                NumIterExamples = 0;
 
                ch.Trace("{0} Starting training iteration {1}", DateTime.UtcNow, Iteration);
            }
 
            /// <summary>
            /// Called by <see cref="TrainCore(IChannel, RoleMappedData, TrainStateBase)"/> after a pass over the dataset.
            /// </summary>
            public virtual void FinishIteration(IChannel ch)
            {
                Contracts.Check(NumIterExamples > 0, NoTrainingInstancesMessage);
 
                ch.Trace("{0} Finished training iteration {1}; iterated over {2} examples.",
                    DateTime.UtcNow, Iteration, NumIterExamples);
 
                ScaleWeights();
            }
 
            /// <summary>
            /// This should be overridden by derived classes. This implementation simply increments <see cref="NumIterExamples"/>.
            /// </summary>
            public virtual void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, float label, float weight)
            {
                ch.Assert(FloatUtils.IsFinite(feat.GetValues()));
                ++NumIterExamples;
            }
 
            /// <summary>
            /// Returns the raw margin from the decision hyperplane.
            /// </summary>
            public float CurrentMargin(in VBuffer<float> feat)
                => Bias + VectorUtils.DotProduct(in feat, in Weights) * WeightsScale;
 
            /// <summary>
            /// The default implementation just calls <see cref="CurrentMargin(in VBuffer{float})"/>.
            /// </summary>
            /// <param name="feat">The example's feature vector.</param>
            /// <returns>The margin of the example.</returns>
            public virtual float Margin(in VBuffer<float> feat)
                => CurrentMargin(in feat);
 
            public abstract TModel CreatePredictor();
        }
 
        // Our tolerance for the error induced by the weight scale may depend on our precision.
        private const float _maxWeightScale = 1 << 10; // Exponent ranges 127 to -128, tolerate 10 being cut off that.
        private const float _minWeightScale = 1 / _maxWeightScale;
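        // With _maxWeightScale = 2^10 = 1024, ScaleWeightsIfNeeded folds the scale back into
        // the weights whenever |WeightsScale| drifts outside [2^-10, 2^10].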
 
        private protected const string UserErrorPositive = "must be positive";
        private protected const string UserErrorNonNegative = "must be non-negative";
 
        public override TrainerInfo Info { get; }
 
        private protected virtual bool NeedCalibration => false;
 
        private protected OnlineLinearTrainer(OnlineLinearOptions options, IHostEnvironment env, string name, SchemaShape.Column label)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumnName), label)
        {
            Contracts.CheckValue(options, nameof(options));
            Contracts.CheckUserArg(options.NumberOfIterations > 0, nameof(options.NumberOfIterations), UserErrorPositive);
            Contracts.CheckUserArg(options.InitialWeightsDiameter >= 0, nameof(options.InitialWeightsDiameter), UserErrorNonNegative);
 
            OnlineLinearTrainerOptions = options;
            Name = name;
            // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
            Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
        }
 
        private protected sealed override TModel TrainModelCore(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
            var initPredictor = context.InitialPredictor;
            var initLinearPred = initPredictor as LinearModelParameters;
            Host.CheckParam(initPredictor == null || initLinearPred != null, nameof(context),
                "Initial predictor was not a linear predictor.");
 
            var data = context.TrainingSet;
            data.CheckFeatureFloatVector(out int numFeatures);
            CheckLabels(data);
 
            using (var ch = Host.Start("Training"))
            {
                var state = MakeState(ch, numFeatures, initLinearPred);
                TrainCore(ch, data, state);
 
                ch.Assert(state.WeightsScale == 1);
                float maxNorm = Math.Max(VectorUtils.MaxNorm(in state.Weights), Math.Abs(state.Bias));
                ch.Check(FloatUtils.IsFinite(maxNorm),
                    "The weights/bias contain invalid values (NaN or Infinite). Potential causes: high learning rates, no normalization, high initial weights, etc.");
                return state.CreatePredictor();
            }
        }
 
        /// <summary>
        /// Continues the training of an <see cref="OnlineLinearTrainer{TTransformer, TModel}"/> using an already trained <paramref name="modelParameters"/> and returns an <see cref="ITransformer"/>.
        /// </summary>
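        /// <example>
        /// A minimal sketch of incremental training, assuming <c>mlContext</c>, <c>firstBatch</c> and
        /// <c>secondBatch</c> already exist; <c>OnlineGradientDescent</c> is used here only as a
        /// representative online linear trainer:
        /// <code>
        /// var trainer = mlContext.Regression.Trainers.OnlineGradientDescent();
        /// // Train on an initial batch of data.
        /// var model = trainer.Fit(firstBatch);
        /// // Continue training from the learned weights on a later batch.
        /// var updatedModel = trainer.Fit(secondBatch, model.Model);
        /// </code>
        /// </example>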
        public TTransformer Fit(IDataView trainData, LinearModelParameters modelParameters)
            => TrainTransformer(trainData, initPredictor: modelParameters);
 
        private protected abstract void CheckLabels(RoleMappedData data);
 
        private void TrainCore(IChannel ch, RoleMappedData data, TrainStateBase state)
        {
            bool shuffle = OnlineLinearTrainerOptions.Shuffle;
            if (shuffle && !data.Data.CanShuffle)
            {
                ch.Warning("Training data does not support shuffling, so ignoring request to shuffle");
                shuffle = false;
            }
 
            var rand = shuffle ? Host.Rand : null;
            CursOpt cursorOpt = CursOpt.Label | CursOpt.Features;
            if (data.Schema.Weight.HasValue)
                cursorOpt |= CursOpt.Weight;
 
            var cursorFactory = new FloatLabelCursor.Factory(data, cursorOpt);
            long numBad = 0;
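            // Each pass below iterates sequentially over the training data; when a random
            // source is provided, the cursor re-shuffles the data on every pass.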
            while (state.Iteration < OnlineLinearTrainerOptions.NumberOfIterations)
            {
                state.BeginIteration(ch);
 
                using (var cursor = cursorFactory.Create(rand))
                {
                    while (cursor.MoveNext())
                        state.ProcessDataInstance(ch, in cursor.Features, cursor.Label, cursor.Weight);
                    numBad += cursor.BadFeaturesRowCount;
                }
 
                state.FinishIteration(ch);
            }
 
            if (numBad > 0)
            {
                ch.Warning(
                    "Skipped {0} instances with missing features during training (over {1} iterations; {2} inst/iter)",
                    numBad, OnlineLinearTrainerOptions.NumberOfIterations, numBad / OnlineLinearTrainerOptions.NumberOfIterations);
            }
        }
 
        private protected abstract TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor);
    }
}