// File: Trainer\EnsembleTrainerBase.cs
// Project: src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj (Microsoft.ML.Ensemble)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.Ensemble.SubsetSelector;
 
namespace Microsoft.ML.Trainers.Ensemble
{
    using Stopwatch = System.Diagnostics.Stopwatch;
 
    /// <summary>
    /// Base class for ensemble trainers. Trains <see cref="NumModels"/> base learners (round-robin over the
    /// configured base-predictor factories), each on a data subset produced by the configured
    /// <see cref="ISubsetSelector"/>, prunes the resulting models through <see cref="SubModelSelector"/>,
    /// and hands the survivors to the derived class via <see cref="CreatePredictor"/> to build the final
    /// combined predictor.
    /// </summary>
    /// <typeparam name="TOutput">The per-model prediction output type (e.g. float for binary/regression).</typeparam>
    /// <typeparam name="TSelector">The sub-model selector type used to prune trained models.</typeparam>
    /// <typeparam name="TCombiner">The output combiner type used to merge sub-model outputs.</typeparam>
    internal abstract class EnsembleTrainerBase<TOutput, TSelector, TCombiner> : ITrainer<IPredictor>
         where TSelector : class, ISubModelSelector<TOutput>
         where TCombiner : class, IOutputCombiner<TOutput>
    {
        /// <summary>
        /// Command-line / API options shared by all ensemble trainers. Derived argument classes supply the
        /// actual base-learner factories through <see cref="GetPredictorFactories"/>.
        /// </summary>
        public abstract class ArgumentsBase : TrainerInputBaseWithLabel
        {
#pragma warning disable CS0649 // These are set via reflection.
            // Null means "use the default": see the constructor, which resolves null to
            // DefaultNumModels (50) when there is a single base predictor, otherwise to the
            // number of base predictors.
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Number of models per batch. If not specified, will default to 50 if there is only one base predictor, " +
                "or the number of base predictors otherwise.", ShortName = "nm", SortOrder = 3)]
            [TGUI(Label = "Number of Models per batch")]
            public int? NumModels;

            // -1 (the default) means all instances are loaded in one batch; see TrainCore, which
            // also rejects positive batch sizes when a stacking combiner is used.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Batch size", ShortName = "bs", SortOrder = 107)]
            [TGUI(Label = "Batch Size",
                Description =
                "Number of instances to be loaded in memory to create an ensemble out of it. All the instances will be loaded if the value is -1.")]
            public int BatchSize = -1;

            // Factory for the subset-selection strategy; defaults to bootstrap (bagging-style) sampling.
            [Argument(ArgumentType.Multiple, HelpText = "Sampling Type", ShortName = "st", SortOrder = 2)]
            [TGUI(Label = "Sampling Type", Description = "Subset Selection Algorithm to induce the base learner.Sub-settings can be used to select the features")]
            public ISupportSubsetSelectorFactory SamplingType = new BootstrapSelector.Arguments();

            // When true, TrainCore runs base learners with unbounded parallelism; when false they run
            // one at a time (MaxDegreeOfParallelism = 1).
            [Argument(ArgumentType.AtMostOnce, HelpText = "All the base learners will run asynchronously if the value is true", ShortName = "tp", SortOrder = 106)]
            [TGUI(Label = "Train parallel", Description = "All the base learners will run asynchronously if the value is true")]
            public bool TrainParallel;

            [Argument(ArgumentType.AtMostOnce,
                HelpText = "True, if metrics for each model need to be evaluated and shown in comparison table. This is done by using validation set if available or the training set",
                ShortName = "sm", SortOrder = 108)]
            [TGUI(Label = "Show Sub-Model Metrics")]
            public bool ShowMetrics;

            // Supplies the factories for the base learners. The constructor cycles through these
            // round-robin to instantiate NumModels trainers.
            internal abstract IComponentFactory<ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>>[] GetPredictorFactories();
#pragma warning restore CS0649
        }

        // Default model count used when NumModels is unspecified and only one base predictor is given.
        private const int DefaultNumModels = 50;
        /// <summary> Command-line arguments </summary>
        private protected readonly ArgumentsBase Args;
        // Resolved model count (never null); see the constructor for the resolution rule.
        private protected readonly int NumModels;
        private protected readonly IHost Host;

        /// <summary> Ensemble members </summary>
        private protected readonly ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>[] Trainers;

        // Produces the instance/feature subsets each base learner trains on.
        private readonly ISubsetSelector _subsetSelector;
        // Prunes trained sub-models; derived classes are expected to initialize these.
        private protected ISubModelSelector<TOutput> SubModelSelector;
        // Combines sub-model outputs into the ensemble prediction; may also be an IStackingTrainer.
        private protected IOutputCombiner<TOutput> Combiner;

        public TrainerInfo Info { get; }

        PredictionKind ITrainer.PredictionKind => PredictionKind;
        private protected abstract PredictionKind PredictionKind { get; }

        /// <summary>
        /// Validates the options, resolves the model count, instantiates the subset selector and the
        /// base trainers (cycling round-robin over the predictor factories), and derives
        /// normalization/calibration needs from the instantiated trainers.
        /// </summary>
        /// <param name="args">Options; must provide at least one base-predictor factory.</param>
        /// <param name="env">Host environment used to register the trainer host.</param>
        /// <param name="name">Registration name for the host.</param>
        private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env, string name)
        {
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register(name);

            Args = args;

            using (var ch = Host.Start("Init"))
            {
                var predictorFactories = Args.GetPredictorFactories();
                ch.CheckUserArg(Utils.Size(predictorFactories) > 0, nameof(EnsembleTrainer.Arguments.BasePredictors), "This should have at-least one value");

                // Null NumModels: default to 50 for a single base predictor, otherwise train one
                // model per base predictor.
                NumModels = Args.NumModels ??
                    (predictorFactories.Length == 1 ? DefaultNumModels : predictorFactories.Length);

                ch.CheckUserArg(NumModels > 0, nameof(Args.NumModels), "Must be positive, or null to indicate numModels is the number of base predictors");

                // More factories than models means the trailing factories never get instantiated below.
                if (Utils.Size(predictorFactories) > NumModels)
                    ch.Warning("The base predictor count is greater than models count. Some of the base predictors will be ignored.");

                _subsetSelector = Args.SamplingType.CreateComponent(Host);

                // Round-robin: trainer i comes from factory (i mod factoryCount), so each factory can
                // back multiple ensemble members.
                Trainers = new ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>[NumModels];
                for (int i = 0; i < Trainers.Length; i++)
                    Trainers[i] = predictorFactories[i % predictorFactories.Length].CreateComponent(Host);
                // We infer normalization and calibration preferences from the trainers. However, even if the internal trainers
                // don't need caching we are performing multiple passes over the data, so it is probably appropriate to always cache.
                // NOTE(review): only normalization/calibration are passed here; caching presumably falls
                // back to TrainerInfo's default — confirm against the TrainerInfo constructor.
                Info = new TrainerInfo(
                    normalization: Trainers.Any(t => t.Info.NeedNormalization),
                    calibration: Trainers.Any(t => t.Info.NeedCalibration));
            }
        }

        /// <summary>
        /// ITrainer entry point: validates the context and delegates to <see cref="TrainCore"/>
        /// using only the training set (validation/test sets in the context are not consumed here).
        /// </summary>
        IPredictor ITrainer<IPredictor>.Train(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));

            using (var ch = Host.Start("Training"))
            {
                return TrainCore(ch, context.TrainingSet);
            }
        }

        IPredictor ITrainer.Train(TrainContext context)
            => ((ITrainer<IPredictor>)this).Train(context);

        /// <summary>
        /// Runs the full ensemble training loop: for each batch from the subset selector, trains all
        /// base learners (optionally in parallel) on their subsets, computes per-model metrics when
        /// needed, prunes via <see cref="SubModelSelector"/>, optionally trains the stacking combiner,
        /// and finally builds the predictor from every surviving model.
        /// </summary>
        /// <param name="ch">Channel for progress/warning output.</param>
        /// <param name="data">Training data with role mappings.</param>
        /// <returns>The combined ensemble predictor.</returns>
        private IPredictor TrainCore(IChannel ch, RoleMappedData data)
        {
            Host.AssertValue(ch);
            ch.AssertValue(data);

            // 1. Subset Selection
            var stackingTrainer = Combiner as IStackingTrainer<TOutput>;

            //REVIEW: Implement stacking for Batch mode.
            ch.CheckUserArg(stackingTrainer == null || Args.BatchSize <= 0, nameof(Args.BatchSize), "Stacking works only with Non-batch mode");

            // Reserve enough validation data for both the sub-model selector and (if present) the
            // stacking trainer by taking the larger of the two proportions.
            var validationDataSetProportion = SubModelSelector.ValidationDatasetProportion;
            if (stackingTrainer != null)
                validationDataSetProportion = Math.Max(validationDataSetProportion, stackingTrainer.ValidationDatasetProportion);

            // Metrics are required either for display or because a weighted-average combiner needs them.
            var needMetrics = Args.ShowMetrics || Combiner is IWeightedAverager;
            var models = new List<FeatureSubsetModel<TOutput>>();

            _subsetSelector.Initialize(data, NumModels, Args.BatchSize, validationDataSetProportion);
            int batchNumber = 1;
            foreach (var batch in _subsetSelector.GetBatches(Host.Rand))
            {
                // 2. Core train
                ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchNumber++);
                var batchModels = new FeatureSubsetModel<TOutput>[Trainers.Length];

                // Each parallel iteration writes only its own slot batchModels[index], so no locking is
                // needed. TrainParallel toggles between unbounded (-1) and sequential (1) parallelism.
                // NOTE(review): this assumes GetSubsets yields exactly Trainers.Length subsets so that
                // `index` stays in range — confirm against the subset selector implementations.
                Parallel.ForEach(_subsetSelector.GetSubsets(batch, Host.Rand),
                    new ParallelOptions() { MaxDegreeOfParallelism = Args.TrainParallel ? -1 : 1 },
                    (subset, state, index) =>
                    {
                        ch.Info("Beginning training model {0} of {1}", index + 1, Trainers.Length);
                        Stopwatch sw = Stopwatch.StartNew();
                        try
                        {
                            // Skip training entirely when the subset selected no features at all.
                            if (EnsureMinimumFeaturesSelected(subset))
                            {
                                // REVIEW: How to pass the role mappings to the trainer?
                                var model = new FeatureSubsetModel<TOutput>(
                                    Trainers[(int)index].Fit(subset.Data.Data).Model,
                                    subset.SelectedFeatures,
                                    null);
                                SubModelSelector.CalculateMetrics(model, _subsetSelector, subset, batch, needMetrics);
                                batchModels[(int)index] = model;
                            }
                        }
                        catch (Exception ex)
                        {
                            // A failed base learner is deliberately skipped (its slot stays null) rather
                            // than aborting the whole ensemble; a warning records the failure.
                            ch.Assert(batchModels[(int)index] == null);
                            ch.Warning(ex.Sensitivity(), "Trainer {0} of {1} was not learned properly due to the exception '{2}' and will not be added to models.",
                                index + 1, Trainers.Length, ex.Message);
                        }
                        ch.Info("Trainer {0} of {1} finished in {2}", index + 1, Trainers.Length, sw.Elapsed);
                    });

                // Drop the slots of failed/skipped trainers before selection.
                var modelsList = batchModels.Where(m => m != null).ToList();
                if (Args.ShowMetrics)
                    PrintMetrics(ch, modelsList);

                // 3. Sub-model selection: prune to the models the selector keeps.
                modelsList = SubModelSelector.Prune(modelsList).ToList();

                if (stackingTrainer != null)
                    stackingTrainer.Train(modelsList, _subsetSelector.GetTestData(null, batch), Host);

                models.AddRange(modelsList);
                int modelSize = Utils.Size(models);
                if (modelSize < Utils.Size(Trainers))
                    ch.Warning("{0} of {1} trainings failed.", Utils.Size(Trainers) - modelSize, Utils.Size(Trainers));
                // If every trainer failed (or pruning removed everything), there is no ensemble to build.
                ch.Check(modelSize > 0, "Ensemble training resulted in no valid models.");
            }
            return CreatePredictor(models);
        }

        /// <summary>
        /// Builds the final ensemble predictor from the surviving sub-models. Implemented by derived
        /// trainers for their specific prediction kind.
        /// </summary>
        private protected abstract IPredictor CreatePredictor(List<FeatureSubsetModel<TOutput>> models);

        /// <summary>
        /// Returns true when the subset is usable for training: either no feature mask was applied
        /// (SelectedFeatures is null) or at least one feature bit is set.
        /// </summary>
        private bool EnsureMinimumFeaturesSelected(Subset subset)
        {
            if (subset.SelectedFeatures == null)
                return true;
            for (int i = 0; i < subset.SelectedFeatures.Count; i++)
            {
                if (subset.SelectedFeatures[i])
                    return true;
            }

            return false;
        }

        /// <summary>
        /// Writes a per-model comparison table to the channel: a header row of metric names followed by
        /// one row of metric values per model, tagged with the predictor's type name. No-op when there
        /// are no models or the first model carries no metrics.
        /// </summary>
        private protected virtual void PrintMetrics(IChannel ch, List<FeatureSubsetModel<TOutput>> models)
        {
            // REVIEW: The formatting of this method is bizarre and seemingly not even self-consistent
            // w.r.t. its usage of |. Is this intentional?
            if (models.Count == 0 || models[0].Metrics == null)
                return;

            ch.Info("{0}| Name of Model |", string.Join("", models[0].Metrics.Select(m => string.Format("| {0} |", m.Key))));

            foreach (var model in models)
                ch.Info("{0}{1}", string.Join("", model.Metrics.Select(m => string.Format("| {0} |", m.Value))), model.Predictor.GetType().Name);
        }

        /// <summary>
        /// Copies the model list into an array, casting each predictor to <typeparamref name="T"/>.
        /// The cast serves as a runtime type check that every predictor has the expected concrete type;
        /// it throws <see cref="InvalidCastException"/> otherwise.
        /// </summary>
        private protected static FeatureSubsetModel<TOutput>[] CreateModels<T>(List<FeatureSubsetModel<TOutput>> models) where T : IPredictorProducing<TOutput>
        {
            var subsetModels = new FeatureSubsetModel<TOutput>[models.Count];
            for (int i = 0; i < models.Count; i++)
            {
                subsetModels[i] = new FeatureSubsetModel<TOutput>(
                    (T)models[i].Predictor,
                    models[i].SelectedFeatures,
                    models[i].Metrics);
            }
            return subsetModels;
        }
    }
}