// File: Commands\TestCommand.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
// Assembly-level registration of TestCommand: passes its summary text, the command type, its
// Arguments type and the SignatureCommand signature, followed by the strings "Test Predictor" and "Test".
// NOTE(review): presumably the two trailing strings are the display name and the load name used to
// select this command — confirm against LoadableClassAttribute.
[assembly: LoadableClass(TestCommand.Summary, typeof(TestCommand), typeof(TestCommand.Arguments), typeof(SignatureCommand),
    "Test Predictor", "Test")]
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Scores a data set with a previously saved model and then evaluates the scored output, all within
    /// one command. The effect is that of running <see cref="ScoreCommand"/> followed by
    /// <see cref="EvaluateCommand"/>, except that the intermediate scored data is never saved.
    /// </summary>
    [BestFriend]
    internal sealed class TestCommand : DataCommand.ImplBase<TestCommand.Arguments>
    {
        /// <summary>
        /// Options understood by <see cref="TestCommand"/>, on top of those declared by
        /// <see cref="DataCommand.ArgumentsBase"/>.
        /// </summary>
        public sealed class Arguments : DataCommand.ArgumentsBase
        {
            /// <summary>Name of the features column; defaults to <see cref="DefaultColumnNames.Features"/>.</summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Column to use for features",
                ShortName = "feat",
                SortOrder = 2)]
            public string FeatureColumn = DefaultColumnNames.Features;

            /// <summary>Name of the label column; defaults to <see cref="DefaultColumnNames.Label"/>.</summary>
            [Argument(ArgumentType.LastOccurrenceWins,
                HelpText = "Column to use for labels",
                ShortName = "lab",
                SortOrder = 3)]
            public string LabelColumn = DefaultColumnNames.Label;

            /// <summary>Name of the example-weight column; defaults to <see cref="DefaultColumnNames.Weight"/>.</summary>
            [Argument(ArgumentType.LastOccurrenceWins,
                HelpText = "Column to use for example weight",
                ShortName = "weight",
                SortOrder = 4)]
            public string WeightColumn = DefaultColumnNames.Weight;

            /// <summary>Name of the grouping column; defaults to <see cref="DefaultColumnNames.GroupId"/>.</summary>
            [Argument(ArgumentType.LastOccurrenceWins,
                HelpText = "Column to use for grouping",
                ShortName = "group",
                SortOrder = 5)]
            public string GroupColumn = DefaultColumnNames.GroupId;

            /// <summary>Name of the name column; defaults to <see cref="DefaultColumnNames.Name"/>.</summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Name column name",
                ShortName = "name",
                SortOrder = 6)]
            public string NameColumn = DefaultColumnNames.Name;

            /// <summary>
            /// Additional (kind, column name) assignments supplied by the user. They are validated by
            /// <c>TrainUtils.CheckAndGenerateCustomColumns</c> before being used.
            /// </summary>
            [Argument(ArgumentType.LastOccurrenceWins,
                HelpText = "Columns with custom kinds declared through key assignments, for example, col[Kind]=Name to assign column named 'Name' kind 'Kind'",
                Name = "CustomColumn",
                ShortName = "col",
                SortOrder = 10)]
            public KeyValuePair<string, string>[] CustomColumns;

            /// <summary>
            /// Factory for the scorer. May be left unset, in which case the unset value is handed to
            /// <c>ScoreUtils.GetScorer</c> as is.
            /// </summary>
            [Argument(ArgumentType.Multiple,
                HelpText = "Scorer to use",
                NullName = "<Auto>",
                SortOrder = 101,
                SignatureType = typeof(SignatureDataScorer))]
            public IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> Scorer;

            /// <summary>
            /// Factory for the evaluator. When unset, the evaluator is obtained from
            /// <c>EvaluateUtils.GetEvaluator</c> using the schema of the scored data.
            /// </summary>
            [Argument(ArgumentType.Multiple,
                HelpText = "Evaluator to use",
                ShortName = "eval",
                NullName = "<Auto>",
                SortOrder = 102,
                SignatureType = typeof(SignatureMamlEvaluator))]
            public IComponentFactory<IMamlEvaluator> Evaluator;

            /// <summary>Optional file name handed to the overall-metrics writer.</summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Results summary filename",
                ShortName = "sf")]
            public string SummaryFilename;

            /// <summary>
            /// Optional file that receives the per-instance results; nothing is saved when this is
            /// null, empty or white space.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "File to save per-instance predictions and metrics to",
                ShortName = "dout")]
            public string OutputDataFile;
        }

        internal const string Summary = "Scores and evaluates a data file.";

        /// <summary>
        /// Creates the command and validates its options: an input model file must have been given, and
        /// the optional summary and per-instance output paths are passed through
        /// <c>Utils.CheckOptionalUserDirectory</c>.
        /// </summary>
        /// <param name="env">The host environment the command runs in.</param>
        /// <param name="args">The parsed command options.</param>
        public TestCommand(IHostEnvironment env, Arguments args)
            : base(env, args, nameof(TestCommand))
        {
            bool hasInputModel = !string.IsNullOrEmpty(ImplOptions.InputModelFile);
            Host.CheckUserArg(hasInputModel, nameof(ImplOptions.InputModelFile), "The input model file is required.");

            Utils.CheckOptionalUserDirectory(ImplOptions.SummaryFilename, nameof(ImplOptions.SummaryFilename));
            Utils.CheckOptionalUserDirectory(ImplOptions.OutputDataFile, nameof(ImplOptions.OutputDataFile));
        }

        /// <summary>
        /// Entry point of the command: opens a channel and server, logs the effective settings, sends
        /// telemetry, and then executes <see cref="RunCore"/> inside a <c>TimerScope</c>.
        /// </summary>
        public override void Run()
        {
            const string commandName = "Test";

            using (var channel = Host.Start(commandName))
            using (var server = InitServer(channel))
            {
                // Log the command line that corresponds to the current options.
                var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments());
                channel.Info("maml.exe {0} {1}", commandName, settings);

                SendTelemetry(Host);

                using (new TimerScope(Host, channel))
                {
                    RunCore(channel);
                }
            }
        }

        /// <summary>
        /// Loads the model objects, binds the role columns, scores the data, evaluates it, reports the
        /// metrics and, when an output data file was requested, saves the per-instance results.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings, results and error reporting.</param>
        private void RunCore(IChannel ch)
        {
            ch.Trace("Constructing data pipeline");
            LoadModelObjects(ch, true, out IPredictor predictor, true, out RoleMappedSchema trainSchema, out ILegacyDataLoader loader);
            ch.AssertValue(predictor);
            ch.AssertValueOrNull(trainSchema);
            ch.AssertValue(loader);

            ch.Trace("Binding columns");
            var schema = loader.Schema;

            // Resolves one role column against the loader schema.
            string BindColumn(string argName, string userName, string defaultName)
                => TrainUtils.MatchNameOrDefaultOrNull(ch, schema, argName, userName, defaultName);

            string label = BindColumn(nameof(ImplOptions.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label);
            string features = BindColumn(nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
            string group = BindColumn(nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
            string weight = BindColumn(nameof(ImplOptions.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight);
            string name = BindColumn(nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);

            // Scoring stage.
            ch.Trace("Scoring and evaluating");
            var scorerFactory = ImplOptions.Scorer;
            ch.Assert(scorerFactory == null || scorerFactory is ICommandLineComponentFactory, "TestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(scorerFactory, predictor, loader, features, group, customCols, Host, trainSchema);

            // Evaluation stage: use the user-specified evaluator if there is one, otherwise pick one
            // based on the schema of the scored data.
            var evaluator = ImplOptions.Evaluator?.CreateComponent(Host)
                ?? EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
            var data = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols);
            var metrics = evaluator.Evaluate(data);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overallMetrics))
                throw ch.Except("No overall metrics found");

            overallMetrics = evaluator.GetOverallResults(overallMetrics);
            MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overallMetrics, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);

            SendTelemetryMetric(new Dictionary<string, IDataView>[] { metrics });

            // Per-instance output is optional; nothing more to do when no file was requested.
            if (string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
                return;

            var perInstance = evaluator.GetPerInstanceMetrics(data);
            var perInstanceData = new RoleMappedData(perInstance, label, null, group, weight, name, customCols);
            var viewToSave = evaluator.GetPerInstanceDataViewToSave(perInstanceData);
            MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, viewToSave);
        }
    }
}