File: Standard\StochasticTrainerBase.cs
Web Access
Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
namespace Microsoft.ML.Trainers
{
    public abstract class StochasticTrainerBase<TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
        where TTransformer : ISingleFeaturePredictionTransformer<TModel>
        where TModel : class
    {
        public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = default)
            : base(host, feature, label, weight)
        {
        }
 
        /// <summary>
        /// Whether data is to be shuffled every epoch.
        /// </summary>
        private protected abstract bool ShuffleData { get; }
 
        private static readonly TrainerInfo _info = new TrainerInfo();
        public override TrainerInfo Info => _info;
 
        private protected override TModel TrainModelCore(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
            using (var ch = Host.Start("Training"))
            {
                var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount);
                var initPred = context.InitialPredictor;
                // Try extract linear model from calibrated predictor.
                var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeaklyTypedSubModel as LinearModelParameters;
                // If the initial predictor is not calibrated, it should be a linear model.
                linInitPred = linInitPred ?? initPred as LinearModelParameters;
                Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context),
                    "Initial predictor was not a linear predictor.");
                return TrainCore(ch, preparedData, linInitPred, weightSetCount);
            }
        }
 
        private protected virtual int ComputeNumThreads(FloatLabelCursor.Factory cursorFactory)
            => Math.Min(8, Math.Max(1, Environment.ProcessorCount / 2));
 
        /// <summary>
        /// This method ensures that the data meets the requirements of this trainer and its
        /// subclasses, injects necessary transforms, and throws if it couldn't meet them.
        /// </summary>
        /// <param name="ch">The channel</param>
        /// <param name="examples">The training examples</param>
        /// <param name="weightSetCount">Gets the length of weights and bias array. For binary classification and regression,
        /// this is 1. For multi-class classification, this equals the number of classes on the label.</param>
        /// <returns>A potentially modified version of <paramref name="examples"/></returns>
        private protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount)
        {
            ch.AssertValue(examples);
            CheckLabel(examples, out weightSetCount);
            examples.CheckFeatureFloatVector();
            var idvToShuffle = examples.Data;
            IDataView idvToFeedTrain;
            if (idvToShuffle.CanShuffle)
                idvToFeedTrain = idvToShuffle;
            else
            {
                var shuffleArgs = new RowShufflingTransformer.Options
                {
                    PoolOnly = false,
                    ForceShuffle = ShuffleData
                };
                idvToFeedTrain = new RowShufflingTransformer(Host, shuffleArgs, idvToShuffle);
            }
 
            ch.Assert(idvToFeedTrain.CanShuffle);
 
            var roles = examples.Schema.GetColumnRoleNames();
            var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles);
 
            ch.Assert(examplesToFeedTrain.Schema.Label.HasValue);
            ch.Assert(examplesToFeedTrain.Schema.Feature.HasValue);
            if (examples.Schema.Weight.HasValue)
                ch.Assert(examplesToFeedTrain.Schema.Weight.HasValue);
 
            ch.Check(examplesToFeedTrain.Schema.Feature.Value.Type is VectorDataViewType vecType && vecType.Size > 0, "Training set has no features, aborting training.");
            return examplesToFeedTrain;
        }
 
        private protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount);
 
        private protected abstract void CheckLabel(RoleMappedData examples, out int weightSetCount);
    }
}