File: Commands\TrainTestCommand.cs
Project: src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
[assembly: LoadableClass(TrainTestCommand.Summary, typeof(TrainTestCommand), typeof(TrainTestCommand.Arguments), typeof(SignatureCommand),
    "Train Test", TrainTestCommand.LoadName)]
 
namespace Microsoft.ML.Data
{
    [BestFriend]
    internal sealed class TrainTestCommand : DataCommand.ImplBase<TrainTestCommand.Arguments>
    {
        public sealed class Arguments : DataCommand.ArgumentsBase
        {
            [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The test data file", ShortName = "test", SortOrder = 1)]
            public string TestFile;
 
            [Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr", SignatureType = typeof(SignatureTrainer))]
            public IComponentFactory<ITrainer> Trainer;
 
            [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", NullName = "<Auto>", SortOrder = 101, SignatureType = typeof(SignatureDataScorer))]
            public IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> Scorer;
 
            [Argument(ArgumentType.Multiple, HelpText = "Evaluator to use", ShortName = "eval", NullName = "<Auto>", SortOrder = 102, SignatureType = typeof(SignatureMamlEvaluator))]
            public IComponentFactory<IMamlEvaluator> Evaluator;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Results summary filename", ShortName = "sf")]
            public string SummaryFilename;
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2)]
            public string FeatureColumn = DefaultColumnNames.Features;
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3)]
            public string LabelColumn = DefaultColumnNames.Label;
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4)]
            public string WeightColumn = DefaultColumnNames.Weight;
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 5)]
            public string GroupColumn = DefaultColumnNames.GroupId;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 6)]
            public string NameColumn = DefaultColumnNames.Name;
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Columns with custom kinds declared through key assignments, for example, col[Kind]=Name to assign column named 'Name' kind 'Kind'",
                Name = "CustomColumn", ShortName = "col", SortOrder = 10)]
            public KeyValuePair<string, string>[] CustomColumns;
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Normalize option for the feature column", ShortName = "norm")]
            public NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
 
            [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The validation data file", ShortName = "valid")]
            public string ValidationFile;
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether we should cache input training data", ShortName = "cache")]
            public bool? CacheData;
 
            [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
            public IComponentFactory<ICalibratorTrainer> Calibrator = new PlattCalibratorTrainerFactory();
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")]
            public int MaxCalibrationExamples = 1000000000;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "File to save per-instance predictions and metrics to",
                ShortName = "dout")]
            public string OutputDataFile;
 
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether we should load predictor from input model and use it as the initial model state", ShortName = "cont")]
            public bool ContinueTrain;
        }
 
        internal const string Summary = "Trains a predictor using the train file and then scores and evaluates the predictor using the test file.";
        public const string LoadName = "TrainTest";
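        // Example invocation (hypothetical file and trainer names; the components actually available depend on
        // the assemblies that are loaded):
        //   maml.exe TrainTest data=train.tsv test=test.tsv tr=AveragedPerceptron out=model.zip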
 
        public TrainTestCommand(IHostEnvironment env, Arguments args)
            : base(env, args, nameof(TrainTestCommand))
        {
            Utils.CheckOptionalUserDirectory(args.SummaryFilename, nameof(args.SummaryFilename));
            Utils.CheckOptionalUserDirectory(args.OutputDataFile, nameof(args.OutputDataFile));
            TrainUtils.CheckTrainer(Host, args.Trainer, args.DataFile);
            if (string.IsNullOrWhiteSpace(args.TestFile))
                throw Host.ExceptUserArg(nameof(args.TestFile), "Test file must be defined.");
        }
 
        public override void Run()
        {
            using (var ch = Host.Start(LoadName))
            using (var server = InitServer(ch))
            {
                var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments());
                string cmd = string.Format("maml.exe {0} {1}", LoadName, settings);
                ch.Info(cmd);
 
                SendTelemetry(Host);
 
                using (new TimerScope(Host, ch))
                {
                    RunCore(ch, cmd);
                }
            }
        }
 
        protected override void SendTelemetryCore(IPipe<TelemetryMessage> pipe)
        {
            SendTelemetryComponent(pipe, ImplOptions.Trainer);
            base.SendTelemetryCore(pipe);
        }
 
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);
 
            ch.Trace("Constructing trainer");
            ITrainer trainer = ImplOptions.Trainer.CreateComponent(Host);
 
            // With ContinueTrain (-cont), attempt to warm-start from the predictor stored in the input model file.
            IPredictor inputPredictor = null;
            if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
 
            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();
 
            var schema = trainPipe.Schema;
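            // Resolve each role column name: use the user-specified name when given, otherwise fall back to the
            // default name if it exists in the schema, otherwise null.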
            string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
                ImplOptions.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn),
                ImplOptions.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn),
                ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
                ImplOptions.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn),
                ImplOptions.NameColumn, DefaultColumnNames.Name);
 
            // Insert a normalizer ahead of the trainer when the trainer wants normalized features and
            // normalization has not been explicitly disabled.
            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, ImplOptions.NormalizeFeatures);
 
            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
            var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);
 
            RoleMappedData validData = null;
            if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile))
            {
                if (!trainer.Info.SupportsValidation)
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
                else
                {
                    ch.Trace("Constructing the validation pipeline");
                    IDataView validPipe = CreateRawLoader(dataFile: ImplOptions.ValidationFile);
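                    // Re-apply the transforms fit on the training pipeline, so the validation data flows
                    // through an identical schema.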
                    validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
                    validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
                }
            }
 
            // In addition to the training set, some trainers can accept two extra data sets, a validation set and a
            // test set, during the training phase. The major difference between them is that the training process may
            // indirectly use the validation set to improve the model, whereas the learned model should be totally
            // independent of the test set. As with the validation set, the trainer can report scores computed on the
            // test set.
            RoleMappedData testDataUsedInTrainer = null;
            if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile))
            {
                // In contrast to the if-else block for the validation set above, we do not warn when the trainer
                // cannot consume the test file, because the test file is required by the TrainTest command and is
                // still used for evaluation below.
                if (trainer.Info.SupportsTest)
                {
                    ch.Trace("Constructing the test pipeline");
                    IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: ImplOptions.TestFile);
                    testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, testPipeUsedInTrainer);
                    testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
                }
            }
 
            var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
                ImplOptions.Calibrator, ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData, inputPredictor, testDataUsedInTrainer);
 
            // Save the trained model (to the requested output file, or to a temporary file otherwise), then reload
            // its data loader so the test pipeline includes all transforms applied to the training data.
            ILegacyDataLoader testPipe;
            bool hasOutfile = !string.IsNullOrEmpty(ImplOptions.OutputModelFile);
            var tempFilePath = hasOutfile ? null : Path.Combine(((IHostEnvironmentInternal)Host).TempFilePath, Path.GetRandomFileName());

            using (var file = new SimpleFileHandle(ch, hasOutfile ? ImplOptions.OutputModelFile : tempFilePath, true, !hasOutfile))
            {
                TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
                ch.Trace("Constructing the testing pipeline");
                using (var stream = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(stream, ch))
                    testPipe = LoadLoader(rep, ImplOptions.TestFile, true);
            }
 
            // Score.
            ch.Trace("Scoring and evaluating");
            ch.Assert(ImplOptions.Scorer == null || ImplOptions.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(ImplOptions.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);
 
            // Evaluate.
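            // Use the explicitly configured evaluator when provided; otherwise pick one automatically
            // based on the score column's metadata.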
            var evaluator = ImplOptions.Evaluator?.CreateComponent(Host) ??
                EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
            var dataEval = new RoleMappedData(scorePipe, label, features,
                group, weight, name, customCols, opt: true);
            var metrics = evaluator.Evaluate(dataEval);
            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
                throw ch.Except("No overall metrics found");
            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);
            Dictionary<string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
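            // When an output data file (dout) is specified, save the per-instance predictions and metrics.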
            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
            {
                var perInst = evaluator.GetPerInstanceMetrics(dataEval);
                var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
                var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, idv);
            }
        }
    }
}