File: Commands\EvaluateCommand.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Text.RegularExpressions;
using Microsoft.ML;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
// Declares EvaluateTransform as a loadable IDataTransform component that takes EvaluateTransform.Arguments,
// under the SignatureDataTransform signature, with the names "Evaluate Predictor" and "Evaluate".
// NOTE(review): how these entries are discovered and resolved is not visible in this file - confirm against LoadableClass.
[assembly: LoadableClass(EvaluateTransform.Summary, typeof(IDataTransform), typeof(EvaluateTransform), typeof(EvaluateTransform.Arguments), typeof(SignatureDataTransform),
    "Evaluate Predictor", "Evaluate")]

// Declares EvaluateCommand as a loadable component that takes EvaluateCommand.Arguments, under the
// SignatureCommand signature, reusing the same two names as the transform above.
[assembly: LoadableClass(EvaluateCommand.Summary, typeof(EvaluateCommand), typeof(EvaluateCommand.Arguments), typeof(SignatureCommand),
    "Evaluate Predictor", "Evaluate")]
 
namespace Microsoft.ML.Data
{
    // REVIEW: For simplicity (since this is currently the case),
    // we assume that all metrics are either numeric, or numeric vectors.
    /// <summary>
    /// This class contains information about an overall metric, namely its name and whether it is a vector
    /// metric or not. It also knows how to map a loaded column name to the metric's display name
    /// (see <see cref="GetNameMatch"/>).
    /// </summary>
    [BestFriend]
    internal sealed class MetricColumn
    {
        /// <summary>
        /// An enum specifying whether the metric should be maximized or minimized while sweeping. 'Info' should be
        /// used for metrics that are irrelevant to the model's quality (such as the number of positive/negative
        /// examples etc.).
        /// </summary>
        public enum Objective
        {
            Maximize,
            Minimize,
            Info,
        }

        /// <summary>The name that <see cref="GetNameMatch"/> compares its input against when no pattern is set.</summary>
        public readonly string LoadName;

        /// <summary>Whether the metric is a vector metric.</summary>
        public readonly bool IsVector;

        /// <summary>Whether the metric should be maximized, minimized, or is informational only.</summary>
        public readonly Objective MetricTarget;

        /// <summary>The name returned by <see cref="GetNameMatch"/> when the input matches <see cref="LoadName"/>.</summary>
        public readonly string Name;

        /// <summary>
        /// When true, <see cref="GetNameMatch"/> also accepts <see cref="LoadName"/> prefixed with "Weighted".
        /// </summary>
        public readonly bool CanBeWeighted;

        // Optional pattern matched against the input instead of the plain LoadName comparison.
        private readonly Regex _loadNamePattern;
        // Name of the capture group in _loadNamePattern whose value is substituted into _nameFormat.
        private readonly string _groupName;
        // Composite format string that receives the captured group as argument {0}.
        private readonly string _nameFormat;

        /// <summary>
        /// Creates a metric description.
        /// </summary>
        /// <param name="loadName">Name to match inputs against; must not be null.</param>
        /// <param name="name">Name reported on a match; must not be null.</param>
        /// <param name="target">Optimization direction of the metric.</param>
        /// <param name="canBeWeighted">Whether the "Weighted"-prefixed load name is also accepted.</param>
        /// <param name="isVector">Whether the metric is a vector metric.</param>
        /// <param name="namePattern">Optional regular expression used instead of the plain name comparison.</param>
        /// <param name="groupName">Capture group of <paramref name="namePattern"/> used to build the result name.</param>
        /// <param name="nameFormat">Format string applied to the captured group.</param>
        public MetricColumn(string loadName, string name, Objective target = Objective.Maximize, bool canBeWeighted = true,
            bool isVector = false, Regex namePattern = null, string groupName = null, string nameFormat = null)
        {
            Contracts.CheckValue(loadName, nameof(loadName));
            Contracts.CheckValue(name, nameof(name));

            LoadName = loadName;
            Name = name;
            MetricTarget = target;
            CanBeWeighted = canBeWeighted;
            IsVector = isVector;
            _loadNamePattern = namePattern;
            _groupName = groupName;
            _nameFormat = nameFormat;
        }

        /// <summary>
        /// Maps <paramref name="input"/> to the name of this metric.
        /// </summary>
        /// <param name="input">The candidate name to test.</param>
        /// <returns>
        /// Without a pattern: <see cref="Name"/> if <paramref name="input"/> equals <see cref="LoadName"/>, or
        /// (when <see cref="CanBeWeighted"/> is set) "Weighted" followed by <see cref="LoadName"/>; otherwise null.
        /// With a pattern: the name format applied to the configured capture group when the pattern matches;
        /// null when it does not match or when the group name or the format is missing.
        /// </returns>
        public string GetNameMatch(string input)
        {
            if (_loadNamePattern == null)
            {
                // Both accepted forms are compared ordinally (case-sensitively) so that the plain and the
                // "Weighted" variants follow the same rule. The static Equals also tolerates a null input.
                bool isMatch = string.Equals(input, LoadName, StringComparison.Ordinal)
                    || (CanBeWeighted && string.Equals(input, "Weighted" + LoadName, StringComparison.Ordinal));
                return isMatch ? Name : null;
            }

            // A pattern without both a group name and a format cannot produce a name.
            if (string.IsNullOrEmpty(_groupName) || string.IsNullOrEmpty(_nameFormat))
                return null;

            var match = _loadNamePattern.Match(input);
            if (!match.Success)
                return null;

            // Formatting the Group object inserts the text it captured.
            var s = match.Groups[_groupName];
            return string.Format(_nameFormat, s);
        }
    }
 
    // REVIEW: Move this interface to MLCore when IDataTransform is moved there.
    /// <summary>
    /// This is an interface for evaluation. It has two evaluation methods, <see cref="Evaluate"/> and
    /// <see cref="GetPerInstanceMetrics"/>, plus <see cref="GetOverallMetricColumns"/> which describes the overall metrics.
    /// Both evaluation methods take a <see cref="RoleMappedData"/> as input. The <see cref="RoleMappedData"/> is assumed
    /// to contain all the column roles needed for evaluation, including the score column.
    /// </summary>
    [BestFriend]
    internal interface IEvaluator
    {
        /// <summary>
        /// Compute the aggregate metrics. Return a dictionary from the metric kind
        /// (overall/per-fold/confusion matrix/PR-curves etc.), to a data view containing the metric.
        /// </summary>
        /// <param name="data">The data to evaluate, with the column roles needed for evaluation.</param>
        /// <returns>A dictionary keyed by metric kind, each value being a data view holding that metric.</returns>
        Dictionary<string, IDataView> Evaluate(RoleMappedData data);

        /// <summary>
        /// Return an <see cref="IDataTransform"/> containing the per-instance results.
        /// </summary>
        /// <param name="data">The data to evaluate, with the column roles needed for evaluation.</param>
        /// <returns>A transform containing the per-instance results.</returns>
        IDataTransform GetPerInstanceMetrics(RoleMappedData data);

        /// <summary>
        /// Get all the overall metrics returned by this evaluator.
        /// </summary>
        /// <returns>One <see cref="MetricColumn"/> description per overall metric.</returns>
        IEnumerable<MetricColumn> GetOverallMetricColumns();
    }
 
    /// <summary>
    /// Signature for creating an <see cref="IEvaluator"/>.
    /// </summary>
    [BestFriend]
    internal delegate void SignatureEvaluator();

    /// <summary>
    /// Signature used by the "Evaluator to use" arguments in this file, whose component factories
    /// create an <see cref="IMamlEvaluator"/>.
    /// </summary>
    [BestFriend]
    internal delegate void SignatureMamlEvaluator();
 
    /// <summary>
    /// Exposes evaluation as a data transform: the factory binds the role columns, creates the evaluator
    /// and returns the evaluator's per-instance metrics over the input.
    /// </summary>
    internal static class EvaluateTransform
    {
        /// <summary>
        /// Arguments of the transform: the role columns to bind and the evaluator to use.
        /// </summary>
        public sealed class Arguments
        {
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3)]
            public string LabelColumn = DefaultColumnNames.Label;

            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4)]
            public string WeightColumn = DefaultColumnNames.Weight;

            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 5)]
            public string GroupColumn = DefaultColumnNames.GroupId;

            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Columns with custom kinds declared through key assignments, for example, col[Kind]=Name to assign column named 'Name' kind 'Kind'",
                Name = "CustomColumn", ShortName = "col", SortOrder = 10)]
            public KeyValuePair<string, string>[] CustomColumns;

            [Argument(ArgumentType.Multiple, HelpText = "Evaluator to use", ShortName = "eval", SignatureType = typeof(SignatureMamlEvaluator))]
            public IComponentFactory<IMamlEvaluator> Evaluator;
        }

        internal const string Summary = "Runs a previously trained predictor on the data.";

        // Factory method for SignatureDataTransform.
        /// <summary>
        /// Binds the label, group, weight and custom columns of <paramref name="input"/>, creates the evaluator
        /// (the one requested in <paramref name="args"/>, or the one chosen from the input schema when none is
        /// requested) and returns its per-instance metrics.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="args">The transform arguments; must not be null.</param>
        /// <param name="input">The data to evaluate; must not be null.</param>
        /// <returns>The per-instance metrics transform produced by the evaluator.</returns>
        private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            // Validate the remaining inputs up front so a missing value fails the contract check here
            // instead of surfacing as a null dereference at input.Schema or args.LabelColumn below.
            env.CheckValue(args, nameof(args));
            env.CheckValue(input, nameof(input));

            using (var ch = env.Register("EvaluateTransform").Start("Create Transform"))
            {
                ch.Trace("Binding columns");
                var schema = input.Schema;
                string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
                    args.LabelColumn, DefaultColumnNames.Label);
                string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn),
                    args.GroupColumn, DefaultColumnNames.GroupId);
                string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
                    args.WeightColumn, DefaultColumnNames.Weight);
                var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, args.CustomColumns);

                ch.Trace("Creating evaluator");
                // Fall back to an evaluator picked from the schema when none was specified.
                IMamlEvaluator eval = args.Evaluator?.CreateComponent(env) ??
                    EvaluateUtils.GetEvaluator(env, schema);

                // No feature column and no name column are bound for the transform.
                var data = new RoleMappedData(input, label, null, group, weight, null, customCols);
                return eval.GetPerInstanceMetrics(data);
            }
        }
    }
 
    /// <summary>
    /// Command that loads a data file, evaluates it, prints the resulting metrics and optionally
    /// saves the per-instance results to a file.
    /// </summary>
    internal sealed class EvaluateCommand : DataCommand.ImplBase<EvaluateCommand.Arguments>
    {
        /// <summary>
        /// Arguments of the command: role columns, the evaluator and the optional output files.
        /// </summary>
        public sealed class Arguments : DataCommand.ArgumentsBase
        {
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3)]
            public string LabelColumn = DefaultColumnNames.Label;

            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4)]
            public string WeightColumn = DefaultColumnNames.Weight;

            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 5)]
            public string GroupColumn = DefaultColumnNames.GroupId;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 6)]
            public string NameColumn = DefaultColumnNames.Name;

            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Columns with custom kinds declared through key assignments, for example, col[Kind]=Name to assign column named 'Name' kind 'Kind'",
                Name = "CustomColumn", ShortName = "col", SortOrder = 10)]
            public KeyValuePair<string, string>[] CustomColumns;

            [Argument(ArgumentType.Multiple, HelpText = "Evaluator to use", ShortName = "eval", SignatureType = typeof(SignatureMamlEvaluator))]
            public IComponentFactory<IMamlEvaluator> Evaluator;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Results summary filename", ShortName = "sf")]
            public string SummaryFilename;

            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "File to save per-instance predictions and metrics to",
                ShortName = "dout")]
            public string OutputDataFile;
        }

        internal const string Summary = "Evaluates the metrics for a scored data file.";

        /// <summary>
        /// Creates the command and checks the two optional output paths.
        /// </summary>
        public EvaluateCommand(IHostEnvironment env, Arguments args)
            : base(env, args, nameof(EvaluateCommand))
        {
            Utils.CheckOptionalUserDirectory(ImplOptions.SummaryFilename, nameof(ImplOptions.SummaryFilename));
            Utils.CheckOptionalUserDirectory(ImplOptions.OutputDataFile, nameof(ImplOptions.OutputDataFile));
        }

        /// <summary>
        /// Runs the evaluation inside an "Evaluate" channel.
        /// </summary>
        public override void Run()
        {
            using (var channel = Host.Start("Evaluate"))
            {
                RunCore(channel);
            }
        }

        /// <summary>
        /// Loads the data, binds the role columns, evaluates, prints the metrics and, when an output
        /// data file was requested, saves the per-instance results.
        /// </summary>
        private void RunCore(IChannel ch)
        {
            Host.AssertValue(ch);

            ch.Trace("Creating loader");
            IDataView loaded = CreateAndSaveLoader(
                (env, source) => new IO.BinaryLoader(env, new IO.BinaryLoader.Arguments(), source));

            ch.Trace("Binding columns");
            var loadedSchema = loaded.Schema;
            string labelCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loadedSchema, nameof(Arguments.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label);
            string groupCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loadedSchema, nameof(Arguments.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
            string weightCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loadedSchema, nameof(Arguments.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight);
            string nameCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loadedSchema, nameof(Arguments.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
            var customRoles = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);

            ch.Trace("Creating evaluator");
            // Use the requested evaluator when one was given, otherwise pick one from the schema.
            IMamlEvaluator mamlEval = ImplOptions.Evaluator?.CreateComponent(Host);
            if (mamlEval == null)
                mamlEval = EvaluateUtils.GetEvaluator(Host, loaded.Schema);

            var roleMapped = new RoleMappedData(loaded, labelCol, null, groupCol, weightCol, nameCol, customRoles);
            var results = mamlEval.Evaluate(roleMapped);
            MetricWriter.PrintWarnings(ch, results);
            mamlEval.PrintFoldResults(ch, results);

            IDataView overallMetrics;
            bool hasOverall = results.TryGetValue(MetricKinds.OverallMetrics, out overallMetrics);
            if (!hasOverall)
                throw ch.Except("No overall metrics found");

            overallMetrics = mamlEval.GetOverallResults(overallMetrics);
            MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overallMetrics, 1);
            mamlEval.PrintAdditionalMetrics(ch, results);

            // Per-instance output is produced only when an output file was requested.
            if (string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
                return;

            var perInstance = mamlEval.GetPerInstanceMetrics(roleMapped);
            var perInstanceMapped = new RoleMappedData(perInstance, labelCol, null, groupCol, weightCol, nameCol, customRoles);
            var toSave = mamlEval.GetPerInstanceDataViewToSave(perInstanceMapped);
            MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, toSave);
        }
    }
}