File: Experiment\Runners\CrossValSummaryRunner.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.AutoML
{
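    /// <summary>
    /// Runs a suggested pipeline on every cross-validation fold and summarizes the results:
    /// the returned model comes from the best-scoring fold, while the reported metrics are
    /// averaged across all folds.
    /// </summary>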
    internal class CrossValSummaryRunner<TMetrics> : IRunner<RunDetail<TMetrics>>
        where TMetrics : class
    {
        private readonly MLContext _context;
        private readonly IDataView[] _trainDatasets;
        private readonly IDataView[] _validDatasets;
        private readonly IMetricsAgent<TMetrics> _metricsAgent;
        private readonly IEstimator<ITransformer> _preFeaturizer;
        private readonly ITransformer[] _preprocessorTransforms;
        private readonly string _groupIdColumn;
        private readonly string _labelColumn;
        private readonly OptimizingMetricInfo _optimizingMetricInfo;
        private readonly IChannel _logger;
        private readonly DataViewSchema _modelInputSchema;
 
        public CrossValSummaryRunner(MLContext context,
            IDataView[] trainDatasets,
            IDataView[] validDatasets,
            IMetricsAgent<TMetrics> metricsAgent,
            IEstimator<ITransformer> preFeaturizer,
            ITransformer[] preprocessorTransforms,
            string groupIdColumn,
            string labelColumn,
            OptimizingMetricInfo optimizingMetricInfo,
            IChannel logger)
        {
            _context = context;
            _trainDatasets = trainDatasets;
            _validDatasets = validDatasets;
            _metricsAgent = metricsAgent;
            _preFeaturizer = preFeaturizer;
            _preprocessorTransforms = preprocessorTransforms;
            _groupIdColumn = groupIdColumn;
            _labelColumn = labelColumn;
            _optimizingMetricInfo = optimizingMetricInfo;
            _logger = logger;
            _modelInputSchema = trainDatasets[0].Schema;
        }
 
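        /// <summary>
        /// Trains and scores <paramref name="pipeline"/> on each fold, then returns a single
        /// run detail containing the best fold's model and the fold-averaged metrics.
        /// </summary>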
        public (SuggestedPipelineRunDetail suggestedPipelineRunDetail, RunDetail<TMetrics> runDetail)
            Run(SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum)
        {
            var trainResults = new List<(ModelContainer model, TMetrics metrics, Exception exception, double score)>();
 
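            // Train and score the candidate pipeline on each train/validation fold.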
            for (var i = 0; i < _trainDatasets.Length; i++)
            {
                var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
                var trainResult = RunnerUtil.TrainAndScorePipeline(pipeline.GetContext(), pipeline, _trainDatasets[i], _validDatasets[i],
                    _groupIdColumn, _labelColumn, _metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema,
                    _logger);
                trainResults.Add(trainResult);
            }
 
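            // If any fold failed, surface the first exception and report the run as failed.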
            var allRunsSucceeded = trainResults.All(r => r.exception == null);
            if (!allRunsSucceeded)
            {
                var firstException = trainResults.First(r => r.exception != null).exception;
                var errorRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline, double.NaN, false, null, null, firstException);
                return (errorRunDetail, errorRunDetail.ToIterationResult(_preFeaturizer));
            }
 
            // Get the model from the best fold
            var bestFoldIndex = BestResultUtil.GetIndexOfBestScore(trainResults.Select(r => r.score), _optimizingMetricInfo.IsMaximizing);
            // bestFoldIndex will be -1 if the optimization metric for all folds is NaN.
            // In this case, return model from the first fold.
            bestFoldIndex = bestFoldIndex != -1 ? bestFoldIndex : 0;
            var bestModel = trainResults.ElementAt(bestFoldIndex).model;
 
            // Get the average metrics across all folds
            var avgScore = GetAverageOfNonNaNScores(trainResults.Select(x => x.score));
            var indexClosestToAvg = GetIndexClosestToAverage(trainResults.Select(r => r.score), avgScore);
            var metricsClosestToAvg = trainResults[indexClosestToAvg].metrics;
            var avgMetrics = GetAverageMetrics(trainResults.Select(x => x.metrics), metricsClosestToAvg);
 
            // Build result objects
            var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline, avgScore, allRunsSucceeded, avgMetrics, bestModel, null);
            var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer);
            return (suggestedPipelineRunDetail, runDetail);
        }
 
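        /// <summary>
        /// Averages the per-fold metrics for the supported task types. Components that cannot be
        /// meaningfully averaged (e.g. confusion matrices, per-class log-loss) are taken from the
        /// fold whose score is closest to the average score.
        /// </summary>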
        private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetrics metricsClosestToAvg)
        {
            if (typeof(TMetrics) == typeof(BinaryClassificationMetrics))
            {
                var newMetrics = metrics.Select(x => x as BinaryClassificationMetrics);
                Contracts.Assert(newMetrics.All(m => m != null));
 
                var result = new BinaryClassificationMetrics(
                    auc: GetAverageOfNonNaNScores(newMetrics.Select(x => x.AreaUnderRocCurve)),
                    accuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.Accuracy)),
                    positivePrecision: GetAverageOfNonNaNScores(newMetrics.Select(x => x.PositivePrecision)),
                    positiveRecall: GetAverageOfNonNaNScores(newMetrics.Select(x => x.PositiveRecall)),
                    negativePrecision: GetAverageOfNonNaNScores(newMetrics.Select(x => x.NegativePrecision)),
                    negativeRecall: GetAverageOfNonNaNScores(newMetrics.Select(x => x.NegativeRecall)),
                    f1Score: GetAverageOfNonNaNScores(newMetrics.Select(x => x.F1Score)),
                    auprc: GetAverageOfNonNaNScores(newMetrics.Select(x => x.AreaUnderPrecisionRecallCurve)),
                    // Return ConfusionMatrix from the fold closest to average score
                    confusionMatrix: (metricsClosestToAvg as BinaryClassificationMetrics).ConfusionMatrix);
                return result as TMetrics;
            }
 
            if (typeof(TMetrics) == typeof(MulticlassClassificationMetrics))
            {
                var newMetrics = metrics.Select(x => x as MulticlassClassificationMetrics);
                Contracts.Assert(newMetrics.All(m => m != null));
 
                var result = new MulticlassClassificationMetrics(
                    accuracyMicro: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MicroAccuracy)),
                    accuracyMacro: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MacroAccuracy)),
                    logLoss: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLoss)),
                    logLossReduction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLossReduction)),
                    topKPredictionCount: newMetrics.ElementAt(0).TopKPredictionCount,
                    topKAccuracies: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.TopKAccuracyForAllK)),
                    perClassLogLoss: (metricsClosestToAvg as MulticlassClassificationMetrics).PerClassLogLoss.ToArray(),
                    confusionMatrix: (metricsClosestToAvg as MulticlassClassificationMetrics).ConfusionMatrix);
                return result as TMetrics;
            }
 
            if (typeof(TMetrics) == typeof(RegressionMetrics))
            {
                var newMetrics = metrics.Select(x => x as RegressionMetrics);
                Contracts.Assert(newMetrics.All(m => m != null));
 
                var result = new RegressionMetrics(
                    l1: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MeanAbsoluteError)),
                    l2: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MeanSquaredError)),
                    rms: GetAverageOfNonNaNScores(newMetrics.Select(x => x.RootMeanSquaredError)),
                    lossFunction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LossFunction)),
                    rSquared: GetAverageOfNonNaNScores(newMetrics.Select(x => x.RSquared)));
                return result as TMetrics;
            }
 
            if (typeof(TMetrics) == typeof(RankingMetrics))
            {
                var newMetrics = metrics.Select(x => x as RankingMetrics);
                Contracts.Assert(newMetrics.All(m => m != null));
 
                var result = new RankingMetrics(
                    dcg: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.DiscountedCumulativeGains)),
                    ndcg: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.NormalizedDiscountedCumulativeGains)));
                return result as TMetrics;
            }
 
            throw new NotImplementedException($"Metric type {typeof(TMetrics)} is not supported");
        }
 
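        /// <summary>
        /// Averages each position across a collection of per-fold arrays (e.g. top-K accuracies),
        /// ignoring null arrays. Returns null if every array is null.
        /// </summary>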
        private static double[] GetAverageOfNonNaNScoresInNestedEnumerable(IEnumerable<IEnumerable<double>> results)
        {
            if (results.All(result => result == null))
            {
                // If all nested enumerables are null, we say the average is a null enumerable as well.
                // This is expected to happen on Multiclass metrics where the TopKAccuracyForAllK
                // array can be null if the topKPredictionCount isn't a valid number.
                // In that case all of the "results" enumerables will be null anyway, and so
                // returning null is the expected solution.
                return null;
            }
 
            // In case there are only some null elements, we'll ignore them.
            // Materialize the filtered results so the enumerable isn't re-evaluated on every access.
            var nonNullResults = results.Where(result => result != null).ToList();

            double[] arr = new double[nonNullResults[0].Count()];
            for (int i = 0; i < arr.Length; i++)
            {
                arr[i] = GetAverageOfNonNaNScores(nonNullResults.Select(x => x.ElementAt(i)));
            }
            return arr;
        }
 
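        /// <summary>
        /// Returns the average of the non-NaN values, or NaN if every value is NaN.
        /// </summary>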
        private static double GetAverageOfNonNaNScores(IEnumerable<double> results)
        {
            var newResults = results.Where(r => !double.IsNaN(r));
            // Return NaN iff all scores are NaN
            if (!newResults.Any())
                return double.NaN;
            // Return average of non-NaN scores otherwise
            return newResults.Average();
        }
 
        /// <summary>
        /// Returns the index of the value in <paramref name="values"/> that is closest to <paramref name="average"/>.
        /// If <paramref name="average"/> is NaN, the first index is returned; if it is positive or negative
        /// infinity, the index of the max or min value, respectively, is returned.
        /// </summary>
        private static int GetIndexClosestToAverage(IEnumerable<double> values, double average)
        {
            // Average will be NaN iff all values are NaN.
            // Return the first index in this case.
            if (double.IsNaN(average))
                return 0;
 
            // Return the max value's index if average is positive inf.
            if (double.IsPositiveInfinity(average))
                return values.ToList().IndexOf(values.Max());
 
            // Return the min value's index if average is negative inf.
            if (double.IsNegativeInfinity(average))
                return values.ToList().IndexOf(values.Min());
 
            int avgFoldIndex = -1;
            var smallestDistFromAvg = double.PositiveInfinity;
            var valuesList = values.ToList();
            for (var i = 0; i < valuesList.Count; i++)
            {
                var value = valuesList[i];
                if (double.IsNaN(value))
                    continue;
                var distFromAvg = Math.Abs(value - average);
                if (distFromAvg < smallestDistFromAvg)
                {
                    smallestDistFromAvg = distFromAvg;
                    avgFoldIndex = i;
                }
            }
            return avgFoldIndex;
        }
    }
}