|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
namespace Microsoft.ML.AutoML
{
internal class BestResultUtil
{
    /// <summary>
    /// Returns the binary-classification run with the best validation score for <paramref name="metric"/>,
    /// or <c>null</c> when no run has validation metrics.
    /// </summary>
    public static RunDetail<BinaryClassificationMetrics> GetBestRun(IEnumerable<RunDetail<BinaryClassificationMetrics>> results,
        BinaryClassificationMetric metric)
    {
        var metricsAgent = new BinaryMetricsAgent(null, metric);
        var metricInfo = new OptimizingMetricInfo(metric);
        return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
    }

    /// <summary>
    /// Returns the regression run with the best validation score for <paramref name="metric"/>,
    /// or <c>null</c> when no run has validation metrics.
    /// </summary>
    public static RunDetail<RegressionMetrics> GetBestRun(IEnumerable<RunDetail<RegressionMetrics>> results,
        RegressionMetric metric)
    {
        var metricsAgent = new RegressionMetricsAgent(null, metric);
        var metricInfo = new OptimizingMetricInfo(metric);
        return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
    }

    /// <summary>
    /// Returns the multiclass-classification run with the best validation score for <paramref name="metric"/>,
    /// or <c>null</c> when no run has validation metrics.
    /// </summary>
    public static RunDetail<MulticlassClassificationMetrics> GetBestRun(IEnumerable<RunDetail<MulticlassClassificationMetrics>> results,
        MulticlassClassificationMetric metric)
    {
        var metricsAgent = new MultiMetricsAgent(null, metric);
        var metricInfo = new OptimizingMetricInfo(metric);
        return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
    }

    /// <summary>
    /// Returns the ranking run with the best validation score for <paramref name="metric"/>,
    /// or <c>null</c> when no run has validation metrics.
    /// </summary>
    /// <param name="dcgTruncationLevel">Forwarded to the ranking metrics agent.</param>
    public static RunDetail<RankingMetrics> GetBestRun(IEnumerable<RunDetail<RankingMetrics>> results,
        RankingMetric metric, uint dcgTruncationLevel)
    {
        var metricsAgent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
        var metricInfo = new OptimizingMetricInfo(metric);
        return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
    }

    /// <summary>
    /// Returns the run whose validation metrics score best according to <paramref name="metricsAgent"/>.
    /// Runs without validation metrics are ignored.
    /// </summary>
    /// <param name="results">Candidate runs.</param>
    /// <param name="metricsAgent">Converts a run's validation metrics into a single score.</param>
    /// <param name="isMetricMaximizing"><c>true</c> if a higher score is better; <c>false</c> if lower is better.</param>
    /// <returns>
    /// The best run; the first candidate if no score beats the initial sentinel (e.g. every score is NaN);
    /// or <c>null</c> if no run has validation metrics.
    /// </returns>
    public static RunDetail<TMetrics> GetBestRun<TMetrics>(IEnumerable<RunDetail<TMetrics>> results,
        IMetricsAgent<TMetrics> metricsAgent, bool isMetricMaximizing)
    {
        // Materialize once. The filtered sequence is read several times below, and a deferred
        // query would re-run the filter (and any upstream deferred work) on every pass.
        var candidates = results.Where(r => r.ValidationMetrics != null).ToList();
        if (candidates.Count == 0)
        {
            return null;
        }

        var scores = candidates.Select(r => metricsAgent.GetScore(r.ValidationMetrics));
        var indexOfBestScore = GetIndexOfBestScore(scores, isMetricMaximizing);

        // indexOfBestScore is -1 when no score beats the sentinel (e.g. the optimization metric
        // is NaN for every model). In that case, return the first model.
        return candidates[indexOfBestScore != -1 ? indexOfBestScore : 0];
    }

    /// <summary>
    /// Returns the cross-validation run whose fold scores, averaged, are best according to
    /// <paramref name="metricsAgent"/>. Runs with no fold results, or where no fold has validation
    /// metrics, are ignored.
    /// </summary>
    /// <param name="results">Candidate cross-validation runs.</param>
    /// <param name="metricsAgent">Converts a fold's validation metrics into a single score.</param>
    /// <param name="isMetricMaximizing"><c>true</c> if a higher score is better; <c>false</c> if lower is better.</param>
    /// <returns>
    /// The best run; the first candidate if no average beats the initial sentinel (e.g. every average is NaN);
    /// or <c>null</c> if no run qualifies.
    /// </returns>
    public static CrossValidationRunDetail<TMetrics> GetBestRun<TMetrics>(IEnumerable<CrossValidationRunDetail<TMetrics>> results,
        IMetricsAgent<TMetrics> metricsAgent, bool isMetricMaximizing)
    {
        // Materialize once so each run's fold average is computed a single time.
        var candidates = results
            .Where(r => r.Results != null && r.Results.Any(x => x.ValidationMetrics != null))
            .ToList();
        if (candidates.Count == 0)
        {
            return null;
        }

        // Every fold is averaged, including any fold whose ValidationMetrics is null.
        // NOTE(review): what GetScore returns for null metrics is defined by the metrics agent
        // (not visible here) — confirm that averaging such folds is intended.
        var scores = candidates.Select(r => r.Results.Average(x => metricsAgent.GetScore(x.ValidationMetrics)));
        var indexOfBestScore = GetIndexOfBestScore(scores, isMetricMaximizing);

        // indexOfBestScore is -1 when no average beats the sentinel (e.g. every average is NaN).
        // In that case, return the first model.
        return candidates[indexOfBestScore != -1 ? indexOfBestScore : 0];
    }

    /// <summary>
    /// Returns up to <paramref name="n"/> runs ordered from best to worst score. Runs without
    /// validation metrics are ignored.
    /// </summary>
    /// <param name="results">Candidate runs.</param>
    /// <param name="metricsAgent">Converts a run's validation metrics into a single score.</param>
    /// <param name="n">Maximum number of runs to return.</param>
    /// <param name="isMetricMaximizing"><c>true</c> if a higher score is better; <c>false</c> if lower is better.</param>
    /// <returns>
    /// Each selected run paired with its zero-based position among the runs that have validation
    /// metrics, or <c>null</c> if no run has validation metrics. Runs whose score is NaN are always
    /// ranked after every run with a numeric score. Runs with equal scores keep their input order.
    /// </returns>
    public static IEnumerable<(RunDetail<T>, int)> GetTopNRunResults<T>(IEnumerable<RunDetail<T>> results,
        IMetricsAgent<T> metricsAgent, int n, bool isMetricMaximizing)
    {
        var candidates = results.Where(r => r.ValidationMetrics != null).ToList();
        if (candidates.Count == 0)
        {
            return null;
        }

        // Score each candidate exactly once, remembering its position in the filtered list.
        var scored = candidates
            .Select((run, index) => (Run: run, Index: index, Score: metricsAgent.GetScore(run.ValidationMetrics)))
            .ToList();

        // The default double ordering treats NaN as smaller than every number. Without this key,
        // NaN-scored runs would rank at the top when minimizing. Sorting on IsNaN first (false < true)
        // keeps them last in both directions. OrderBy/ThenBy are stable, so ties keep input order.
        var nanLast = scored.OrderBy(t => double.IsNaN(t.Score));
        var ordered = isMetricMaximizing
            ? nanLast.ThenByDescending(t => t.Score)
            : nanLast.ThenBy(t => t.Score);

        return ordered.Take(n).Select(t => (t.Run, t.Index)).ToList();
    }

    /// <summary>
    /// Returns the index of the best score: the first maximum if <paramref name="isMetricMaximizing"/>,
    /// otherwise the first minimum. NaN scores never win. Returns -1 if no score strictly beats the
    /// starting sentinel (negative infinity when maximizing, positive infinity when minimizing),
    /// which includes an empty sequence or one made only of NaNs.
    /// </summary>
    public static int GetIndexOfBestScore(IEnumerable<double> scores, bool isMetricMaximizing)
    {
        return isMetricMaximizing ? GetIndexOfMaxScore(scores) : GetIndexOfMinScore(scores);
    }

    /// <summary>
    /// Converts a sweep trial result into a <see cref="RunDetail{TMetrics}"/>. The trainer name is
    /// the sweepable pipeline's string form for the trial's parameter.
    /// </summary>
    public static RunDetail<TMetrics> ToRunDetail<TMetrics>(MLContext context, TrialResult<TMetrics> result, SweepablePipeline pipeline)
        where TMetrics : class
    {
        var parameter = result.TrialSettings.Parameter;
        var trainerName = pipeline.ToString(parameter);
        var modelContainer = new ModelContainer(context, result.Model);
        var detail = new RunDetail<TMetrics>(trainerName, result.Pipeline, null, modelContainer, result.Metrics, result.Exception);

        // NOTE(review): if DurationInMilliseconds is an integral type, this division truncates
        // to whole seconds — confirm its type.
        detail.RuntimeInSeconds = result.DurationInMilliseconds / 1000;
        return detail;
    }

    /// <summary>
    /// Converts a cross-validation sweep trial result into a <see cref="CrossValidationRunDetail{TMetrics}"/>.
    /// Each fold becomes one train result that carries the trial's exception. The trainer name is the
    /// sweepable pipeline's string form for the trial's parameter.
    /// </summary>
    public static CrossValidationRunDetail<TMetrics> ToCrossValidationRunDetail<TMetrics>(MLContext context, TrialResult<TMetrics> result, SweepablePipeline pipeline)
        where TMetrics : class
    {
        var parameter = result.TrialSettings.Parameter;
        var trainerName = pipeline.ToString(parameter);

        // Materialize the per-fold results. A lazy Select here would build new ModelContainer and
        // TrainResult instances every time the detail's results are enumerated.
        var crossValidationResult = result.CrossValidationMetrics
            .Select(m => new TrainResult<TMetrics>(new ModelContainer(context, m.Model), m.Metrics, result.Exception))
            .ToList();
        var detail = new CrossValidationRunDetail<TMetrics>(trainerName, result.Pipeline, null, crossValidationResult);

        // NOTE(review): if DurationInMilliseconds is an integral type, this division truncates
        // to whole seconds — confirm its type.
        detail.RuntimeInSeconds = result.DurationInMilliseconds / 1000;
        return detail;
    }

    /// <summary>
    /// Single pass over <paramref name="scores"/>. Returns the index of the first strictly smallest
    /// score, or -1 if none is below positive infinity.
    /// </summary>
    private static int GetIndexOfMinScore(IEnumerable<double> scores)
    {
        var minScore = double.PositiveInfinity;
        var minIndex = -1;
        var index = 0;
        foreach (var score in scores)
        {
            // Strict '<': ties keep the earliest index, and NaN (all comparisons false) never wins.
            if (score < minScore)
            {
                minScore = score;
                minIndex = index;
            }

            index++;
        }

        return minIndex;
    }

    /// <summary>
    /// Single pass over <paramref name="scores"/>. Returns the index of the first strictly largest
    /// score, or -1 if none is above negative infinity.
    /// </summary>
    private static int GetIndexOfMaxScore(IEnumerable<double> scores)
    {
        var maxScore = double.NegativeInfinity;
        var maxIndex = -1;
        var index = 0;
        foreach (var score in scores)
        {
            // Strict '>': ties keep the earliest index, and NaN (all comparisons false) never wins.
            if (score > maxScore)
            {
                maxScore = score;
                maxIndex = index;
            }

            index++;
        }

        return maxIndex;
    }
}
}
|