File: API\RankingExperiment.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers.LightGbm;
 
namespace Microsoft.ML.AutoML
{
    public sealed class RankingExperimentSettings : ExperimentSettings
    {
        /// <summary>
        /// Metric that AutoML will try to optimize over the course of the experiment.
        /// </summary>
        /// <value>The default value is <see cref="RankingMetric" />.</value>
        public RankingMetric OptimizingMetric { get; set; }
 
        /// <summary>
        /// Collection of trainers the AutoML experiment can leverage.
        /// </summary>
        /// <value>
        /// The default value is a collection auto-populated with all possible trainers (all values of <see cref="RankingTrainer" />).
        /// </value>
        public ICollection<RankingTrainer> Trainers { get; }
 
        /// <summary>
        /// Maximum truncation level for computing (N)DCG
        /// </summary>
        /// <value>
        /// The default value is 10.
        /// </value>
        public uint OptimizationMetricTruncationLevel { get; set; }
 
        /// <summary>
        /// Initializes a new instance of <see cref="RankingExperimentSettings"/>.
        /// </summary>
        public RankingExperimentSettings()
        {
            OptimizingMetric = RankingMetric.Ndcg;
            Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType<RankingTrainer>().ToList();
            OptimizationMetricTruncationLevel = 10;
        }
    }
    public enum RankingMetric
    {
        /// <summary>
        /// See <see cref="RankingMetrics.NormalizedDiscountedCumulativeGains"/>.
        /// </summary>
        Ndcg,
        /// <summary>
        /// See <see cref="RankingMetrics.DiscountedCumulativeGains"/>.
        /// </summary>
        Dcg
    }
    /// <summary>
    /// Enumeration of ML.NET ranking trainers used by AutoML.
    /// </summary>
    public enum RankingTrainer
    {
        /// <summary>
        /// See <see cref="LightGbmRankingTrainer"/>.
        /// </summary>
        LightGbmRanking,
        /// <summary>
        /// See <see cref="FastTreeRankingTrainer"/>.
        /// </summary>
        FastTreeRanking
    }
 
    /// <summary>
    /// Extension methods that operate over ranking experiment run results.
    /// </summary>
    public static class RankingExperimentResultExtensions
    {
        /// <summary>
        /// Select the best run from an enumeration of experiment runs.
        /// </summary>
        /// <param name="results">Enumeration of AutoML experiment run results.</param>
        /// <param name="metric">Metric to consider when selecting the best run.</param>
        /// <param name="optimizationMetricTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 10.</param>
        /// <returns>The best experiment run.</returns>
        public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, uint optimizationMetricTruncationLevel = 10)
        {
            var metricsAgent = new RankingMetricsAgent(null, metric, optimizationMetricTruncationLevel);
            var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
            return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
        }
 
        /// <summary>
        /// Select the best run from an enumeration of experiment cross validation runs.
        /// </summary>
        /// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
        /// <param name="metric">Metric to consider when selecting the best run.</param>
        /// <param name="optimizationMetricTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 10.</param>
        /// <returns>The best experiment run.</returns>
        public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, uint optimizationMetricTruncationLevel = 10)
        {
            var metricsAgent = new RankingMetricsAgent(null, metric, optimizationMetricTruncationLevel);
            var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
            return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
        }
    }
 
    /// <summary>
    /// AutoML experiment on ranking datasets.
    /// </summary>
    /// <example>
    /// <format type="text/markdown">
    /// <![CDATA[
    ///  [!code-csharp[RankingExperiment](~/../docs/samples/docs/samples/Microsoft.ML.AutoML.Samples/RankingExperiment.cs)]
    /// ]]></format>
    /// </example>
    public sealed class RankingExperiment : ExperimentBase<RankingMetrics, RankingExperimentSettings>
    {
        internal RankingExperiment(MLContext context, RankingExperimentSettings settings)
            : base(context,
                  new RankingMetricsAgent(context, settings.OptimizingMetric, settings.OptimizationMetricTruncationLevel),
                  new OptimizingMetricInfo(settings.OptimizingMetric),
                  settings,
                  TaskKind.Ranking,
                  TrainerExtensionUtil.GetTrainerNames(settings.Trainers))
        {
        }
 
        private protected override CrossValidationRunDetail<RankingMetrics> GetBestCrossValRun(IEnumerable<CrossValidationRunDetail<RankingMetrics>> results)
        {
            return BestResultUtil.GetBestRun(results, MetricsAgent, OptimizingMetricInfo.IsMaximizing);
        }
 
        private protected override RunDetail<RankingMetrics> GetBestRun(IEnumerable<RunDetail<RankingMetrics>> results)
        {
            return BestResultUtil.GetBestRun(results, MetricsAgent, OptimizingMetricInfo.IsMaximizing);
        }
    }
}