File: API\BinaryClassificationExperiment.cs
Project: src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.ML.AutoML.Tuner;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers.LightGbm;
using static Microsoft.ML.TrainCatalogBase;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// Settings for AutoML experiments on binary classification datasets.
    /// </summary>
    public sealed class BinaryExperimentSettings : ExperimentSettings
    {
        /// <summary>
        /// Metric that AutoML will try to optimize over the course of the experiment.
        /// </summary>
        /// <value>The default value is <see cref="BinaryClassificationMetric.Accuracy"/>.</value>
        public BinaryClassificationMetric OptimizingMetric { get; set; }
 
        /// <summary>
        /// Collection of trainers the AutoML experiment can leverage.
        /// </summary>
        /// <value>The default value is a collection auto-populated with all possible trainers (all values of <see cref="BinaryClassificationTrainer" />).</value>
        public ICollection<BinaryClassificationTrainer> Trainers { get; }
 
        /// <summary>
        /// Whether to use <see cref="AutoZeroTuner"/> for hyper-parameter optimization.
        /// </summary>
        /// <value>The default value is <see langword="false"/>.</value>
        public bool UseAutoZeroTuner { get; set; }
 
        /// <summary>
        /// Initializes a new instance of <see cref="BinaryExperimentSettings"/>.
        /// </summary>
        public BinaryExperimentSettings()
        {
            OptimizingMetric = BinaryClassificationMetric.Accuracy;
            Trainers = Enum.GetValues(typeof(BinaryClassificationTrainer)).OfType<BinaryClassificationTrainer>().ToList();
            UseAutoZeroTuner = false;
        }
    }
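
    // A minimal usage sketch (illustrative only, not part of this file; assumes an
    // existing MLContext named `mlContext` and otherwise-default settings):
    //
    //   var settings = new BinaryExperimentSettings
    //   {
    //       OptimizingMetric = BinaryClassificationMetric.AreaUnderRocCurve,
    //       MaxExperimentTimeInSeconds = 600,
    //   };
    //   var experiment = mlContext.Auto().CreateBinaryClassificationExperiment(settings);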
 
    /// <summary>
    /// Binary classification metric that AutoML will aim to optimize in its sweeping process during an experiment.
    /// </summary>
    public enum BinaryClassificationMetric
    {
        /// <summary>
        /// See <see cref="BinaryClassificationMetrics.Accuracy"/>.
        /// </summary>
        Accuracy,
 
        /// <summary>
        /// See <see cref="BinaryClassificationMetrics.AreaUnderRocCurve"/>.
        /// </summary>
        AreaUnderRocCurve,
 
        /// <summary>
        /// See <see cref="BinaryClassificationMetrics.AreaUnderPrecisionRecallCurve"/>.
        /// </summary>
        AreaUnderPrecisionRecallCurve,
 
        /// <summary>
        /// See <see cref="BinaryClassificationMetrics.F1Score"/>.
        /// </summary>
        F1Score,
 
        /// <summary>
        /// See <see cref="BinaryClassificationMetrics.PositivePrecision"/>.
        /// </summary>
        PositivePrecision,
 
        /// <summary>
        /// See <see cref="BinaryClassificationMetrics.PositiveRecall"/>.
        /// </summary>
        PositiveRecall,
 
        /// <summary>
        /// See <see cref="BinaryClassificationMetrics.NegativePrecision"/>.
        /// </summary>
        NegativePrecision,
 
        /// <summary>
        /// See <see cref="BinaryClassificationMetrics.NegativeRecall"/>.
        /// </summary>
        NegativeRecall,
    }
 
    /// <summary>
    /// Enumeration of ML.NET binary classification trainers used by AutoML.
    /// </summary>
    public enum BinaryClassificationTrainer
    {
        /// <summary>
        /// See <see cref="FastForestBinaryTrainer"/>.
        /// </summary>
        FastForest,
 
        /// <summary>
        /// See <see cref="FastTreeBinaryTrainer"/>.
        /// </summary>
        FastTree,
 
        /// <summary>
        /// See <see cref="LightGbmBinaryTrainer"/>.
        /// </summary>
        LightGbm,
 
        /// <summary>
        /// See <see cref="LbfgsLogisticRegressionBinaryTrainer"/>.
        /// </summary>
        LbfgsLogisticRegression,
 
        /// <summary>
        /// See <see cref="SdcaLogisticRegressionBinaryTrainer"/>.
        /// </summary>
        SdcaLogisticRegression,
    }
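
    // Illustrative sketch: the trainer list can be narrowed via BinaryExperimentSettings.Trainers
    // before the experiment runs, e.g. to skip LightGbm on platforms where its native
    // binaries are unavailable (assumes a `settings` instance as above):
    //
    //   settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm);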
 
    /// <summary>
    /// AutoML experiment on binary classification datasets.
    /// </summary>
    /// <example>
    /// <format type="text/markdown">
    /// <![CDATA[
    ///  [!code-csharp[BinaryClassificationExperiment](~/../docs/samples/docs/samples/Microsoft.ML.AutoML.Samples/BinaryClassificationExperiment.cs)]
    /// ]]></format>
    /// </example>
    public sealed class BinaryClassificationExperiment : ExperimentBase<BinaryClassificationMetrics, BinaryExperimentSettings>
    {
        private AutoMLExperiment _experiment;
        private const string Features = "__Features__";
        private SweepablePipeline _pipeline;
 
        internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSettings settings)
            : base(context,
                  new BinaryMetricsAgent(context, settings.OptimizingMetric),
                  new OptimizingMetricInfo(settings.OptimizingMetric),
                  settings,
                  TaskKind.BinaryClassification,
                  TrainerExtensionUtil.GetTrainerNames(settings.Trainers))
        {
            _experiment = context.Auto().CreateExperiment();
            if (settings.MaximumMemoryUsageInMegaByte is double d)
            {
                _experiment.SetMaximumMemoryUsageInMegaByte(d);
            }
            _experiment.SetMaxModelToExplore(settings.MaxModels);
            _experiment.SetTrainingTimeInSeconds(settings.MaxExperimentTimeInSeconds);
        }
 
        public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<BinaryClassificationMetrics>> progressHandler = null)
        {
            var label = columnInformation.LabelColumnName;
            _experiment.SetBinaryClassificationMetric(Settings.OptimizingMetric, label);
 
            // Row-count threshold for choosing the validation strategy:
            // below the threshold, use cross-validation; otherwise, use a train-validate split.
            const int crossValRowCountThreshold = 15000;
            var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold);
            // TODO: split cross-validation folds according to the sampling key as well.
            if (rowCount < crossValRowCountThreshold)
            {
                const int numCrossValFolds = 10;
                _experiment.SetDataset(trainData, numCrossValFolds);
            }
            else
            {
                var splitData = Context.Data.TrainTestSplit(trainData);
                _experiment.SetDataset(splitData.TrainSet, splitData.TestSet);
            }
            _pipeline = CreateBinaryClassificationPipeline(trainData, columnInformation, preFeaturizer);
            _experiment.SetPipeline(_pipeline);
 
            // Set up a monitor that relays each completed trial to the progress handler.
            TrialResultMonitor<BinaryClassificationMetrics> monitor = null;
            _experiment.SetMonitor((provider) =>
            {
                var channel = provider.GetService<IChannel>();
                var pipeline = provider.GetService<SweepablePipeline>();
                monitor = new TrialResultMonitor<BinaryClassificationMetrics>(channel, pipeline);
                monitor.OnTrialCompleted += (o, e) =>
                {
                    var detail = BestResultUtil.ToRunDetail(Context, e, _pipeline);
                    progressHandler?.Report(detail);
                };
 
                return monitor;
            });
            _experiment = PostConfigureAutoMLExperiment(_experiment);
            _experiment.Run();
 
            var runDetails = monitor.RunDetails.Select(e => BestResultUtil.ToRunDetail(Context, e, _pipeline));
            var bestRun = BestResultUtil.ToRunDetail(Context, monitor.BestRun, _pipeline);
            var result = new ExperimentResult<BinaryClassificationMetrics>(runDetails, bestRun);
 
            return result;
        }
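
        // Illustrative call sketch (assumes `mlContext`, an IDataView `data` with a "Label"
        // column, and `settings` as above; Progress<T> is the standard BCL type):
        //
        //   var experiment = mlContext.Auto().CreateBinaryClassificationExperiment(settings);
        //   var result = experiment.Execute(
        //       data,
        //       new ColumnInformation { LabelColumnName = "Label" },
        //       progressHandler: new Progress<RunDetail<BinaryClassificationMetrics>>(
        //           r => Console.WriteLine($"{r.TrainerName}: {r.ValidationMetrics?.Accuracy}")));
        //   Console.WriteLine($"Best trainer: {result.BestRun.TrainerName}");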
 
        public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<BinaryClassificationMetrics>> progressHandler = null)
        {
            var label = columnInformation.LabelColumnName;
            _experiment.SetBinaryClassificationMetric(Settings.OptimizingMetric, label);
            _experiment.SetDataset(trainData, validationData);
            _pipeline = CreateBinaryClassificationPipeline(trainData, columnInformation, preFeaturizer);
            _experiment.SetPipeline(_pipeline);
 
            // Set up a monitor that relays each completed trial to the progress handler.
            TrialResultMonitor<BinaryClassificationMetrics> monitor = null;
            _experiment.SetMonitor((provider) =>
            {
                var channel = provider.GetService<IChannel>();
                var pipeline = provider.GetService<SweepablePipeline>();
                monitor = new TrialResultMonitor<BinaryClassificationMetrics>(channel, pipeline);
                monitor.OnTrialCompleted += (o, e) =>
                {
                    var detail = BestResultUtil.ToRunDetail(Context, e, _pipeline);
                    progressHandler?.Report(detail);
                };
 
                return monitor;
            });
            _experiment = PostConfigureAutoMLExperiment(_experiment);
            _experiment.Run();
 
            var runDetails = monitor.RunDetails.Select(e => BestResultUtil.ToRunDetail(Context, e, _pipeline));
            var bestRun = BestResultUtil.ToRunDetail(Context, monitor.BestRun, _pipeline);
            var result = new ExperimentResult<BinaryClassificationMetrics>(runDetails, bestRun);
 
            return result;
        }
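
        // Illustrative sketch for the explicit train/validation overload (assumes `experiment`,
        // `mlContext`, and an IDataView `data`):
        //
        //   var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2);
        //   var result = experiment.Execute(split.TrainSet, split.TestSet,
        //       new ColumnInformation { LabelColumnName = "Label" });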
 
        public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView trainData, IDataView validationData, string labelColumnName = "Label", IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<BinaryClassificationMetrics>> progressHandler = null)
        {
            var columnInformation = new ColumnInformation()
            {
                LabelColumnName = labelColumnName,
            };
 
            return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler);
        }
 
        public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView trainData, string labelColumnName = "Label", string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<BinaryClassificationMetrics>> progressHandler = null)
        {
            var columnInformation = new ColumnInformation()
            {
                LabelColumnName = labelColumnName,
                SamplingKeyColumnName = samplingKeyColumn,
            };
 
            return Execute(trainData, columnInformation, preFeaturizer, progressHandler);
        }
 
        public override CrossValidationExperimentResult<BinaryClassificationMetrics> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<CrossValidationRunDetail<BinaryClassificationMetrics>> progressHandler = null)
        {
            // columnInformation is optional; fall back to the default column names when null.
            columnInformation ??= new ColumnInformation();
            var label = columnInformation.LabelColumnName;
            _experiment.SetBinaryClassificationMetric(Settings.OptimizingMetric, label);
            _experiment.SetDataset(trainData, (int)numberOfCVFolds);
            _pipeline = CreateBinaryClassificationPipeline(trainData, columnInformation, preFeaturizer);
            _experiment.SetPipeline(_pipeline);
 
            // Set up a monitor that relays each completed trial to the progress handler.
            TrialResultMonitor<BinaryClassificationMetrics> monitor = null;
            _experiment.SetMonitor((provider) =>
            {
                var channel = provider.GetService<IChannel>();
                var pipeline = provider.GetService<SweepablePipeline>();
                monitor = new TrialResultMonitor<BinaryClassificationMetrics>(channel, pipeline);
                monitor.OnTrialCompleted += (o, e) =>
                {
                    var detail = BestResultUtil.ToCrossValidationRunDetail(Context, e, _pipeline);
                    progressHandler?.Report(detail);
                };
 
                return monitor;
            });
 
            _experiment = PostConfigureAutoMLExperiment(_experiment);
            _experiment.Run();
 
            var runDetails = monitor.RunDetails.Select(e => BestResultUtil.ToCrossValidationRunDetail(Context, e, _pipeline));
            var bestResult = BestResultUtil.ToCrossValidationRunDetail(Context, monitor.BestRun, _pipeline);
 
            var result = new CrossValidationExperimentResult<BinaryClassificationMetrics>(runDetails, bestResult);
 
            return result;
        }
 
        public override CrossValidationExperimentResult<BinaryClassificationMetrics> Execute(IDataView trainData, uint numberOfCVFolds, string labelColumnName = "Label", string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<CrossValidationRunDetail<BinaryClassificationMetrics>> progressHandler = null)
        {
            var columnInformation = new ColumnInformation()
            {
                LabelColumnName = labelColumnName,
                SamplingKeyColumnName = samplingKeyColumn,
            };
 
            return Execute(trainData, numberOfCVFolds, columnInformation, preFeaturizer, progressHandler);
        }
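
        // Illustrative sketch for cross-validated search (assumes `experiment` and `data` as above):
        //
        //   var cvResult = experiment.Execute(data, numberOfCVFolds: 5, labelColumnName: "Label");
        //   var bestFolds = cvResult.BestRun.Results;  // per-fold results for the best pipeline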
 
        private protected override RunDetail<BinaryClassificationMetrics> GetBestRun(IEnumerable<RunDetail<BinaryClassificationMetrics>> results)
        {
            return BestResultUtil.GetBestRun(results, MetricsAgent, OptimizingMetricInfo.IsMaximizing);
        }
 
        private protected override CrossValidationRunDetail<BinaryClassificationMetrics> GetBestCrossValRun(IEnumerable<CrossValidationRunDetail<BinaryClassificationMetrics>> results)
        {
            return BestResultUtil.GetBestRun(results, MetricsAgent, OptimizingMetricInfo.IsMaximizing);
        }
 
        private SweepablePipeline CreateBinaryClassificationPipeline(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null)
        {
            var useSdca = Settings.Trainers.Contains(BinaryClassificationTrainer.SdcaLogisticRegression);
            var useLbfgs = Settings.Trainers.Contains(BinaryClassificationTrainer.LbfgsLogisticRegression);
            var useLgbm = Settings.Trainers.Contains(BinaryClassificationTrainer.LightGbm);
            var useFastForest = Settings.Trainers.Contains(BinaryClassificationTrainer.FastForest);
            var useFastTree = Settings.Trainers.Contains(BinaryClassificationTrainer.FastTree);

            if (preFeaturizer != null)
            {
                return preFeaturizer.Append(Context.Auto().Featurizer(trainData, columnInformation, Features))
                                    .Append(Context.Auto().BinaryClassification(
                                        labelColumnName: columnInformation.LabelColumnName,
                                        featureColumnName: Features,
                                        useSdcaLogisticRegression: useSdca,
                                        useLbfgsLogisticRegression: useLbfgs,
                                        useLgbm: useLgbm,
                                        useFastForest: useFastForest,
                                        useFastTree: useFastTree));
            }
            else
            {
                return Context.Auto().Featurizer(trainData, columnInformation, Features)
                           .Append(Context.Auto().BinaryClassification(
                               labelColumnName: columnInformation.LabelColumnName,
                               featureColumnName: Features,
                               useSdcaLogisticRegression: useSdca,
                               useLbfgsLogisticRegression: useLbfgs,
                               useLgbm: useLgbm,
                               useFastForest: useFastForest,
                               useFastTree: useFastTree));
            }
        }
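
        // The resulting sweepable pipeline is conceptually:
        //   [optional preFeaturizer] -> AutoML featurizer (transforms input columns and
        //   concatenates them into "__Features__") -> one of the enabled binary trainers,
        //   chosen and tuned per trial by the sweeper.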
 
        private AutoMLExperiment PostConfigureAutoMLExperiment(AutoMLExperiment experiment)
        {
            experiment.SetTrialRunner<BinaryClassificationRunner>();
            if (Settings.UseAutoZeroTuner)
            {
                experiment.SetTuner<AutoZeroTuner>();
            }
 
            return experiment;
        }
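
        // Illustrative sketch: the AutoZero tuner is opted into through settings rather than
        // by calling SetTuner directly (assumes a `settings` instance as above):
        //
        //   settings.UseAutoZeroTuner = true;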
    }
 
    internal class BinaryClassificationRunner : ITrialRunner
    {
        private MLContext _context;
        private readonly IDatasetManager _datasetManager;
        private readonly IMLContextManager _contextManager;
        private readonly IMetricManager _metricManager;
        private readonly SweepablePipeline _pipeline;
        private readonly Random _rnd;
        public BinaryClassificationRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
        {
            _context = contextManager.CreateMLContext();
            _contextManager = contextManager;
            _datasetManager = datasetManager;
            _metricManager = metricManager;
            _pipeline = pipeline;
            _rnd = settings.Seed.HasValue ? new Random(settings.Seed.Value) : new Random();
        }
 
        public void Dispose()
        {
            _context.CancelExecution();
            _context = null;
        }
 
        public TrialResult Run(TrialSettings settings)
        {
            if (_metricManager is BinaryMetricManager metricManager)
            {
                var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
                var pipeline = _pipeline.BuildFromOption(_context, parameter);
                // _context will be cancelled after training, so the returned pipeline needs to be
                // rebuilt on a fresh MLContext.
                var refitContext = _contextManager.CreateMLContext();
                var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter);
                if (_datasetManager is ICrossValidateDatasetManager datasetManager)
                {
                    var stopWatch = new Stopwatch();
                    stopWatch.Start();
                    var fold = datasetManager.Fold;
                    var metrics = _context.BinaryClassification.CrossValidateNonCalibrated(datasetManager.Dataset, pipeline, fold, metricManager.LabelColumn);
 
                    // For now a fold model is picked at random; a better approach would be to offer
                    // an option to pick the model whose score is closest to the average, or the best one.
                    var res = metrics[_rnd.Next(fold)];
                    var model = res.Model;
                    var metric = GetMetric(metricManager.Metric, res.Metrics);
 
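                    // Tuners minimize loss, so metrics that should be maximized are negated.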
                    var loss = metricManager.IsMaximize ? -metric : metric;
                    stopWatch.Stop();

                    return new TrialResult<BinaryClassificationMetrics>()
                    {
                        Loss = loss,
                        Metric = metric,
                        Model = model,
                        TrialSettings = settings,
                        DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
                        Metrics = res.Metrics,
                        CrossValidationMetrics = metrics,
                        Pipeline = refitPipeline,
                    };
                }
 
                if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
                {
                    var stopWatch = new Stopwatch();
                    stopWatch.Start();
                    var model = pipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_context, settings));
                    var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_context, settings));
                    var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
                    var metric = GetMetric(metricManager.Metric, metrics);
                    var loss = metricManager.IsMaximize ? -metric : metric;
 
                    stopWatch.Stop();

                    return new TrialResult<BinaryClassificationMetrics>()
                    {
                        Loss = loss,
                        Metric = metric,
                        Model = model,
                        TrialSettings = settings,
                        DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
                        Metrics = metrics,
                        Pipeline = refitPipeline,
                    };
                }
            }
 
            throw new ArgumentException($"Expected the metric manager to be of type {typeof(BinaryMetricManager)} and the dataset manager to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}, but got {_metricManager.GetType()} and {_datasetManager.GetType()}.");
        }
 
        public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
        {
            try
            {
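                // Cancellation is propagated by cancelling the MLContext, which aborts any
                // in-flight training inside Run().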
                using (var ctRegistration = ct.Register(() =>
                {
                    _context?.CancelExecution();
                }))
                {
                    return Task.FromResult(Run(settings));
                }
            }
            catch (Exception ex) when (ct.IsCancellationRequested)
            {
                throw new OperationCanceledException(ex.Message, ex.InnerException);
            }
        }
 
        private double GetMetric(BinaryClassificationMetric metric, BinaryClassificationMetrics metrics)
        {
            return metric switch
            {
                BinaryClassificationMetric.PositivePrecision => metrics.PositivePrecision,
                BinaryClassificationMetric.Accuracy => metrics.Accuracy,
                BinaryClassificationMetric.AreaUnderRocCurve => metrics.AreaUnderRocCurve,
                BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => metrics.AreaUnderPrecisionRecallCurve,
                BinaryClassificationMetric.PositiveRecall => metrics.PositiveRecall,
                BinaryClassificationMetric.NegativePrecision => metrics.NegativePrecision,
                BinaryClassificationMetric.NegativeRecall => metrics.NegativeRecall,
                BinaryClassificationMetric.F1Score => metrics.F1Score,
                _ => throw new NotImplementedException($"{metric} is not supported!"),
            };
        }
    }
}