// File: Evaluators\MamlEvaluator.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.CommandLine;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// This interface is used by Maml components (the <see cref="EvaluateCommand"/>, the <see cref="CrossValidationCommand"/>
    /// and the <see cref="EvaluateTransform"/>) to evaluate, print and save the results.
    /// The input <see cref="RoleMappedData"/> to the <see cref="IEvaluator.Evaluate"/> and the <see cref="IEvaluator.GetPerInstanceMetrics"/> methods
    /// should be assumed to contain only the following column roles: label, group, weight and name. Any other columns needed for
    /// evaluation should be searched for by name in the <see cref="RoleMappedData.Schema"/>.
    /// </summary>
    [BestFriend]
    internal interface IMamlEvaluator : IEvaluator
    {
        /// <summary>
        /// Print the aggregate metrics of a single fold.
        /// </summary>
        /// <param name="ch">The channel the results are reported through.</param>
        /// <param name="metrics">The metric data views of the fold, keyed by metric kind.</param>
        void PrintFoldResults(IChannel ch, Dictionary<string, IDataView> metrics);

        /// <summary>
        /// Combine the overall metrics from multiple folds into a single data view.
        /// </summary>
        /// <param name="metrics">The overall-metrics data views to combine, one or more.</param>
        /// <returns>A single data view holding the combined overall metrics.</returns>
        IDataView GetOverallResults(params IDataView[] metrics);

        /// <summary>
        /// Handles custom metrics (such as p/r curves for binary classification, or group summary results for ranking) from one
        /// or more folds. Implementations of this method typically create a single data view for the custom metric and save it
        /// to a user specified file.
        /// </summary>
        /// <param name="ch">The channel used for reporting.</param>
        /// <param name="metrics">One metric dictionary per fold, each keyed by metric kind.</param>
        void PrintAdditionalMetrics(IChannel ch, params Dictionary<string, IDataView>[] metrics);

        /// <summary>
        /// Create a data view containing only the columns that are saved as per-instance results by Maml commands.
        /// </summary>
        /// <param name="perInstance">The per-instance data to select the saved columns from.</param>
        /// <returns>The data view that should be written as the per-instance results.</returns>
        IDataView GetPerInstanceDataViewToSave(RoleMappedData perInstance);
    }
 
    /// <summary>
    /// A base class implementation of <see cref="IMamlEvaluator"/>. The <see cref="IEvaluator.Evaluate"/> and <see cref="IEvaluator.GetPerInstanceMetrics"/>
    /// methods create a new <see cref="RoleMappedData"/> containing all the columns needed for evaluation, and call the corresponding
    /// methods on an <see cref="IEvaluator"/> of the appropriate type.
    /// </summary>
    internal abstract class MamlEvaluatorBase : IMamlEvaluator
    {
        /// <summary>
        /// Arguments common to Maml evaluators: the names of the label, weight, score and stratification columns.
        /// The <see cref="MamlEvaluatorBase"/> constructor copies these values into its own fields.
        /// </summary>
        public abstract class ArgumentsBase : EvaluateInputBase
        {
            // Standard columns.

            [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels.", ShortName = "lab")]
            public string LabelColumn;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Weight column name.", ShortName = "weight")]
            public string WeightColumn;

            // Score columns.

            [Argument(ArgumentType.AtMostOnce, HelpText = "Score column name.", ShortName = "score")]
            public string ScoreColumn;

            // Stratification columns.

            // May be specified multiple times; the evaluator treats a null array as "no stratification".
            [Argument(ArgumentType.Multiple, HelpText = "Stratification column name.", Name = "StratColumn", ShortName = "strat")]
            public string[] StratColumns;
        }
 
        // Column role used to tag the stratification columns when building the input role mapping.
        // NOTE(review): this static is not readonly; no assignment is visible in this file - confirm whether it can be made readonly.
        internal static RoleMappedSchema.ColumnRole Strat = "Strat";
        // Host registered in the constructor; used for argument checks and passed to the transforms created here.
        [BestFriend]
        private protected readonly IHost Host;

        // The score column kind supplied by the derived class; used when looking up the score column.
        [BestFriend]
        private protected readonly string ScoreColumnKind;
        // User-specified score column name (from ArgumentsBase.ScoreColumn); may be null.
        [BestFriend]
        private protected readonly string ScoreCol;
        // User-specified label column name (from ArgumentsBase.LabelColumn); may be null.
        [BestFriend]
        private protected readonly string LabelCol;
        // User-specified weight column name (from ArgumentsBase.WeightColumn); may be null.
        [BestFriend]
        private protected readonly string WeightCol;
        // User-specified stratification column names (from ArgumentsBase.StratColumns); may be null.
        [BestFriend]
        private protected readonly string[] StratCols;

        // The underlying evaluator that the IEvaluator members of this class delegate to.
        [BestFriend]
        private protected abstract IEvaluator Evaluator { get; }
 
        /// <summary>
        /// Initializes the shared evaluator state: registers a host with <paramref name="env"/> and captures
        /// the column names supplied in <paramref name="args"/>.
        /// </summary>
        /// <param name="args">The user-specified column names. Must not be null.</param>
        /// <param name="env">The host environment used to register <see cref="Host"/>. Must not be null.</param>
        /// <param name="scoreColumnKind">The score column kind used later when looking up the score column.</param>
        /// <param name="registrationName">The name under which the host is registered.</param>
        [BestFriend]
        private protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string scoreColumnKind, string registrationName)
        {
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register(registrationName);

            // Validate args only once the host exists, so a null argument is reported as an argument
            // error through the host instead of surfacing as a NullReferenceException below.
            Host.CheckValue(args, nameof(args));

            ScoreColumnKind = scoreColumnKind;
            ScoreCol = args.ScoreColumn;
            LabelCol = args.LabelColumn;
            WeightCol = args.WeightColumn;
            StratCols = args.StratColumns;
        }
 
        /// <summary>
        /// Evaluates <paramref name="data"/> by re-mapping it to the column roles required by this evaluator
        /// (including the stratification columns) and delegating to <see cref="Evaluator"/>.
        /// </summary>
        /// <param name="data">The scored data to evaluate. Must not be null.</param>
        /// <returns>The metric data views produced by the underlying evaluator.</returns>
        Dictionary<string, IDataView> IEvaluator.Evaluate(RoleMappedData data)
        {
            // Fail with an argument error rather than a NullReferenceException, matching the other
            // interface entry points of this class.
            Host.CheckValue(data, nameof(data));

            var evalData = new RoleMappedData(data.Data, GetInputColumnRoles(data.Schema, needStrat: true));
            return Evaluator.Evaluate(evalData);
        }
 
        /// <summary>
        /// Builds the role-to-column-name pairs used to construct the <see cref="RoleMappedData"/> handed to the
        /// underlying evaluator. The result is, in order: the name column (optional), the stratification columns
        /// (optional), and the pairs produced by <see cref="GetInputColumnRolesCore"/>.
        /// </summary>
        /// <param name="schema">The schema of the incoming data. Must not be null.</param>
        /// <param name="needStrat">Whether to include the stratification columns, when any were specified.</param>
        /// <param name="needName">Whether to include the name column, when <paramref name="schema"/> has one.</param>
        [BestFriend]
        private protected IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles(RoleMappedSchema schema, bool needStrat = false, bool needName = false)
        {
            Host.CheckValue(schema, nameof(schema));

            IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> result;
            if (needStrat && StratCols != null)
                result = StratCols.Select(stratCol => RoleMappedSchema.CreatePair(Strat, stratCol));
            else
                result = Enumerable.Empty<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();

            if (needName && schema.Name.HasValue)
            {
                var nameRole = RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Value.Name);
                result = AnnotationUtils.Prepend(result, nameRole);
            }

            // The evaluator-specific columns always come last.
            return result.Concat(GetInputColumnRolesCore(schema));
        }
 
        /// <summary>
        /// Yields the input columns needed by the evaluator. The base implementation yields the score column,
        /// then the label column, and finally the weight column when a non-empty weight column name is resolved.
        /// Override if additional columns are needed.
        /// </summary>
        /// <param name="schema">The schema of the data being evaluated.</param>
        [BestFriend]
        private protected virtual IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
        {
            // Score column: looked up from the user-specified name and the score column kind.
            var scoreInfo = EvaluateUtils.GetScoreColumn(
                Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn), ScoreColumnKind);
            yield return RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, scoreInfo.Name);

            // Label column: resolved from the user-specified name, the schema's label and the default label name.
            string labelName = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label);
            yield return RoleMappedSchema.ColumnRole.Label.Bind(labelName);

            // Weight column: there is no default, so nothing more is yielded when no name is resolved.
            string weightName = EvaluateUtils.GetColName(WeightCol, schema.Weight, null);
            if (string.IsNullOrEmpty(weightName))
                yield break;

            yield return RoleMappedSchema.ColumnRole.Weight.Bind(weightName);
        }
 
        /// <summary>
        /// Returns the overall metric columns reported by the underlying <see cref="Evaluator"/>.
        /// </summary>
        public virtual IEnumerable<MetricColumn> GetOverallMetricColumns()
            => Evaluator.GetOverallMetricColumns();
 
        /// <summary>
        /// Validates the arguments and forwards to <see cref="PrintFoldResultsCore"/>.
        /// </summary>
        /// <param name="ch">The channel the results are reported through. Must not be null.</param>
        /// <param name="metrics">The metric data views of the fold, keyed by metric kind. Must not be null.</param>
        void IMamlEvaluator.PrintFoldResults(IChannel ch, Dictionary<string, IDataView> metrics)
        {
            // Arguments are checked here so that the overridable core method only needs to assert them.
            Host.CheckValue(ch, nameof(ch));
            Host.CheckValue(metrics, nameof(metrics));
            PrintFoldResultsCore(ch, metrics);
        }
 
        /// <summary>
        /// Prints the overall metrics of a single fold: looks up the <c>MetricKinds.OverallMetrics</c> data view in
        /// <paramref name="metrics"/>, formats it with <c>MetricWriter.GetPerFoldResults</c>, and writes the weighted
        /// results (when there are any) followed by the unweighted results to <paramref name="ch"/>.
        /// Override if something else is needed.
        /// </summary>
        /// <param name="ch">The channel the results are written to.</param>
        /// <param name="metrics">The metric data views of the fold, keyed by metric kind.</param>
        [BestFriend]
        private protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
        {
            // Assert the channel through the host: using the channel as its own exception context
            // would make the assertion depend on the very value being validated.
            Host.AssertValue(ch);
            ch.AssertValue(metrics);

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out IDataView fold))
                throw ch.Except("No overall metrics found");

            string unweightedMetrics = MetricWriter.GetPerFoldResults(Host, fold, out string weightedMetrics);

            // Weighted metrics are only printed when the writer produced any.
            if (!string.IsNullOrEmpty(weightedMetrics))
                ch.Info(weightedMetrics);
            ch.Info(unweightedMetrics);
        }
 
        /// <summary>
        /// Combines the given overall-metrics data views with <see cref="CombineOverallMetricsCore"/> and passes
        /// the combined view through <see cref="GetOverallResultsCore"/>.
        /// </summary>
        /// <param name="metrics">The overall-metrics data views to combine. Must be non-empty.</param>
        /// <returns>The data view returned by <see cref="GetOverallResultsCore"/>.</returns>
        IDataView IMamlEvaluator.GetOverallResults(params IDataView[] metrics)
        {
            Host.CheckNonEmpty(metrics, nameof(metrics));
            return GetOverallResultsCore(CombineOverallMetricsCore(metrics));
        }
 
        /// <summary>
        /// Combines the overall metrics of several folds into one data view. The base implementation
        /// delegates to <c>EvaluateUtils.ConcatenateOverallMetrics</c>.
        /// </summary>
        /// <param name="metrics">The overall-metrics data views to combine.</param>
        [BestFriend]
        private protected virtual IDataView CombineOverallMetricsCore(IDataView[] metrics)
            => EvaluateUtils.ConcatenateOverallMetrics(Host, metrics);
 
        /// <summary>
        /// Post-processes the combined overall metrics. The base implementation returns
        /// <paramref name="overall"/> unchanged.
        /// </summary>
        /// <param name="overall">The combined overall-metrics data view.</param>
        [BestFriend]
        private protected virtual IDataView GetOverallResultsCore(IDataView overall)
            => overall;
 
        /// <summary>
        /// Validates the arguments and forwards to <see cref="PrintAdditionalMetricsCore"/>.
        /// </summary>
        /// <param name="ch">The channel used for reporting. Must not be null.</param>
        /// <param name="metrics">One metric dictionary per fold. Must be non-empty.</param>
        void IMamlEvaluator.PrintAdditionalMetrics(IChannel ch, params Dictionary<string, IDataView>[] metrics)
        {
            // Arguments are checked here, before the overridable core method is invoked.
            Host.CheckValue(ch, nameof(ch));
            Host.CheckNonEmpty(metrics, nameof(metrics));
            PrintAdditionalMetricsCore(ch, metrics);
        }
 
        /// <summary>
        /// Hook for handling custom metrics from one or more folds. The base implementation is empty and
        /// does nothing; override it in evaluators that have additional metrics to print or save.
        /// </summary>
        /// <param name="ch">The channel used for reporting.</param>
        /// <param name="metrics">One metric dictionary per fold, each keyed by metric kind.</param>
        [BestFriend]
        private protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)
        {
        }
 
        /// <summary>
        /// Re-maps <paramref name="scoredData"/> to the column roles required by this evaluator (without the
        /// stratification and name columns) and asks the underlying <see cref="Evaluator"/> for the
        /// per-instance metrics transform.
        /// </summary>
        /// <param name="scoredData">The scored data. Must not be null.</param>
        /// <returns>The per-instance metrics transform produced by the underlying evaluator.</returns>
        IDataTransform IEvaluator.GetPerInstanceMetrics(RoleMappedData scoredData)
        {
            // This is an interface entry point, so validate in all build flavors (an assert only fires in
            // debug builds and would leave a NullReferenceException in release builds).
            Host.CheckValue(scoredData, nameof(scoredData));

            var dataEval = new RoleMappedData(scoredData.Data, GetInputColumnRoles(scoredData.Schema));
            return Evaluator.GetPerInstanceMetrics(dataEval);
        }
 
        /// <summary>
        /// Reduces the per-instance data computed by the evaluator to the columns that Maml saves: the fold index
        /// (when present), an "Instance" name column, the weight column (when present) and the columns returned by
        /// <see cref="GetPerInstanceColumnsToSave"/>. The result is passed through <see cref="GetPerInstanceMetricsCore"/>.
        /// </summary>
        /// <param name="perInst">The per-instance data together with its role mapping.</param>
        private IDataView WrapPerInstance(RoleMappedData perInst)
        {
            const string instanceColumn = "Instance";

            var schema = perInst.Schema;
            IDataView view = perInst.Data;

            // Columns to copy under a new name, and the names of the columns that survive the final selection.
            var copies = new List<(string name, string source)>();
            var kept = new List<string>();

            // Data coming from cross-validation may carry a fold index column; keep it when it is there.
            if (schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out int foldIndexCol))
                kept.Add(MetricKinds.ColumnNames.FoldIndex);

            // The output always has an "Instance" column: a copy of the name column when one is mapped,
            // otherwise a counter generated by a GenerateNumberTransform.
            string nameColumn = schema.Name?.Name;
            if (nameColumn != null)
            {
                copies.Add((instanceColumn, nameColumn));
            }
            else
            {
                var options = new GenerateNumberTransform.Options
                {
                    Columns = new[] { new GenerateNumberTransform.Column() { Name = instanceColumn } },
                    UseCounter = true
                };
                view = new GenerateNumberTransform(Host, options, view);
            }
            kept.Add(instanceColumn);

            // The weight column is kept whenever one is mapped.
            string weightColumn = schema.Weight?.Name;
            if (weightColumn != null)
                kept.Add(weightColumn);

            // The remaining columns are chosen by the concrete evaluator.
            kept.AddRange(GetPerInstanceColumnsToSave(schema));

            // Apply the copies first so that "Instance" exists before the keep-selection runs.
            view = new ColumnCopyingTransformer(Host, copies.ToArray()).Transform(view);
            view = ColumnSelectingTransformer.CreateKeep(Host, view, kept.ToArray());
            return GetPerInstanceMetricsCore(view, schema);
        }
 
        /// <summary>
        /// Final hook applied to the per-instance data view before it is saved. The incoming view contains a name
        /// column (called Instance), the fold index and weight columns if they exist, and all the columns returned
        /// by <see cref="GetPerInstanceColumnsToSave"/>. The base implementation returns the view unchanged.
        /// Override only if additional processing is needed, such as dropping slots in the "top k scores" column
        /// in the multi-class case.
        /// </summary>
        /// <param name="perInst">The per-instance data view restricted to the columns to save.</param>
        /// <param name="schema">The role mapping of the per-instance data.</param>
        [BestFriend]
        private protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
            => perInst;
 
        /// <summary>
        /// Re-maps <paramref name="perInstance"/> to the column roles required by this evaluator, including the
        /// name column when there is one, and reduces it to the columns that are saved as per-instance results.
        /// </summary>
        /// <param name="perInstance">The per-instance data. Must not be null.</param>
        /// <returns>The data view to save.</returns>
        IDataView IMamlEvaluator.GetPerInstanceDataViewToSave(RoleMappedData perInstance)
        {
            Host.CheckValue(perInstance, nameof(perInstance));

            var roles = GetInputColumnRoles(perInstance.Schema, needName: true);
            return WrapPerInstance(new RoleMappedData(perInstance.Data, roles));
        }
 
        /// <summary>
        /// Returns the names of the columns that should be saved in the per-instance results file. These can include
        /// the columns generated by the corresponding <see cref="IRowMapper"/>, or any of the input columns used by
        /// it. The Name and Weight columns should not be included, since the base class includes them automatically.
        /// </summary>
        /// <param name="schema">The role mapping of the per-instance data.</param>
        /// <returns>The names of the columns to keep, in addition to the ones the base class adds itself.</returns>
        [BestFriend]
        private protected abstract IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema);
    }
}