// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(TrainCommand.Summary, typeof(TrainCommand), typeof(TrainCommand.Arguments), typeof(SignatureCommand),
"Train Predictor", "Train")]
namespace Microsoft.ML.Data
{
using ColumnRole = RoleMappedSchema.ColumnRole;
[BestFriend]
internal enum NormalizeOption
{
No,
Warn,
Auto,
Yes
}
[BestFriend]
internal sealed class TrainCommand : DataCommand.ImplBase<TrainCommand.Arguments>
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
// REVIEW: We need some better way to handle auto/none, possibly with
// the hypothetical Maybe<string> structure.
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2)]
public string FeatureColumn = DefaultColumnNames.Features;
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3)]
public string LabelColumn = DefaultColumnNames.Label;
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4)]
public string WeightColumn = DefaultColumnNames.Weight;
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 5)]
public string GroupColumn = DefaultColumnNames.GroupId;
[Argument(ArgumentType.AtMostOnce, HelpText = "Name column name", ShortName = "name", SortOrder = 6)]
public string NameColumn = DefaultColumnNames.Name;
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Columns with custom kinds declared through key assignments, for example, col[Kind]=Name to assign column named 'Name' kind 'Kind'",
Name = "CustomColumn", ShortName = "col", SortOrder = 10)]
public KeyValuePair<string, string>[] CustomColumns;
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Normalize option for the feature column", ShortName = "norm")]
public NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
[Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr", SignatureType = typeof(SignatureTrainer))]
public IComponentFactory<ITrainer> Trainer;
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The validation data file", ShortName = "valid")]
public string ValidationFile;
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The test data file", ShortName = "test")]
public string TestFile;
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether we should cache input training data", ShortName = "cache")]
public bool? CacheData;
[Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "<None>", SignatureType = typeof(SignatureCalibrator))]
public IComponentFactory<ICalibratorTrainer> Calibrator = new PlattCalibratorTrainerFactory();
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")]
public int MaxCalibrationExamples = 1000000000;
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether we should load predictor from input model and use it as the initial model state", ShortName = "cont")]
public bool ContinueTrain;
}
internal const string Summary = "Trains a predictor.";
private readonly IComponentFactory<ITrainer> _trainer;
private readonly string _labelColumn;
private readonly string _featureColumn;
private readonly string _groupColumn;
private readonly string _weightColumn;
private readonly string _nameColumn;
public TrainCommand(IHostEnvironment env, Arguments args)
: base(env, args, nameof(TrainCommand))
{
Host.CheckNonWhiteSpace(args.OutputModelFile, nameof(args.OutputModelFile));
TrainUtils.CheckTrainer(Host, args.Trainer, args.DataFile);
_trainer = args.Trainer;
_labelColumn = args.LabelColumn;
_featureColumn = args.FeatureColumn;
_groupColumn = args.GroupColumn;
_weightColumn = args.WeightColumn;
_nameColumn = args.NameColumn;
}
public override void Run()
{
string command = "Train";
using (var ch = Host.Start(command))
using (var server = InitServer(ch))
{
var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments());
string cmd = string.Format("maml.exe {0} {1}", command, settings);
ch.Info(cmd);
SendTelemetry(Host);
using (new TimerScope(Host, ch))
{
RunCore(ch, cmd);
}
}
}
protected override void SendTelemetryCore(IPipe<TelemetryMessage> pipe)
{
SendTelemetryComponent(pipe, _trainer);
base.SendTelemetryCore(pipe);
}
private void RunCore(IChannel ch, string cmd)
{
Host.AssertValue(ch);
Host.AssertNonEmpty(cmd);
ch.Trace("Constructing trainer");
ITrainer trainer = _trainer.CreateComponent(Host);
IPredictor inputPredictor = null;
if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
ch.Trace("Constructing data pipeline");
IDataView view = CreateLoader();
var schema = view.Schema;
var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label);
var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features);
var group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId);
var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight);
var name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name);
TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref view, feature, ImplOptions.NormalizeFeatures);
ch.Trace("Binding columns");
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);
// REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
RoleMappedData validData = null;
if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile))
{
if (!trainer.Info.SupportsValidation)
{
ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
}
else
{
ch.Trace("Constructing the validation pipeline");
IDataView validPipe = CreateRawLoader(dataFile: ImplOptions.ValidationFile);
validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, validPipe);
validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
}
}
// In addition to the training set, some trainers can accept two extra data sets, validation set and test set,
// in training phase. The major difference between validation set and test set is that training process may
// indirectly use validation set to improve the model but the learned model should totally independent of test set.
// Similar to validation set, the trainer can report the scores computed using test set.
RoleMappedData testDataUsedInTrainer = null;
if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile))
{
// In contrast to the if-else block for validation above, we do not throw a warning if test file is provided
// because this is TrainTest command.
if (trainer.Info.SupportsTest)
{
ch.Trace("Constructing the test pipeline");
IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: ImplOptions.TestFile);
testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, testPipeUsedInTrainer);
testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
}
}
var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
ImplOptions.Calibrator, ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData, inputPredictor, testDataUsedInTrainer);
using (var file = Host.CreateOutputFile(ImplOptions.OutputModelFile))
TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
}
}
[BestFriend]
internal static class TrainUtils
{
public static void CheckTrainer(IExceptionContext ectx, IComponentFactory<ITrainer> trainer, string dataFile)
{
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(trainer, nameof(TrainCommand.Arguments.Trainer), "A trainer is required.");
if (string.IsNullOrWhiteSpace(dataFile))
throw ectx.ExceptUserArg(nameof(TrainCommand.Arguments.DataFile), "Data file must be defined.");
}
/// <summary>
/// If user name is null or empty, return null.
/// Else, if the user name is found in the schema, return the user name.
/// Else, if the user name equals the default name return null.
/// Else, throw an error.
/// </summary>
public static string MatchNameOrDefaultOrNull(IExceptionContext ectx, DataViewSchema schema, string argName, string userName, string defaultName)
{
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(schema, nameof(schema));
ectx.CheckNonEmpty(argName, nameof(argName));
ectx.CheckValueOrNull(userName);
ectx.CheckValue(defaultName, nameof(defaultName));
if (string.IsNullOrWhiteSpace(userName))
return null;
int col;
if (schema.TryGetColumnIndex(userName, out col))
return userName;
if (userName == defaultName)
return null;
#pragma warning disable MSML_ContractsNameUsesNameof
throw ectx.ExceptUserArg(argName, $"Could not find column '{userName}'");
#pragma warning restore MSML_ContractsNameUsesNameof
}
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer,
IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples)
{
return TrainCore(env, ch, data, trainer, null, calibrator, maxCalibrationExamples, false);
}
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null, RoleMappedData testData = null)
{
return TrainCore(env, ch, data, trainer, validData, calibrator, maxCalibrationExamples, cacheData, inputPredictor, testData);
}
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null, RoleMappedData testData = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
ch.CheckValue(data, nameof(data));
ch.CheckValue(trainer, nameof(trainer));
ch.CheckValueOrNull(validData);
ch.CheckValueOrNull(inputPredictor);
AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
ch.Trace("Training");
if (validData != null)
AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
{
ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
": Trainer does not support incremental training.");
inputPredictor = null;
}
ch.Assert(validData == null || trainer.Info.SupportsValidation);
var predictor = trainer.Train(new TrainContext(data, validData, testData, inputPredictor));
var caliTrainer = calibrator?.CreateComponent(env);
return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, caliTrainer, maxCalibrationExamples, trainer, predictor, data);
}
public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor)
{
Contracts.AssertValue(env);
Contracts.AssertValue(ch);
if (!string.IsNullOrEmpty(inputModelFile))
{
ch.Trace("Constructing predictor from input model");
using (var file = env.OpenInputFile(inputModelFile))
using (var strm = file.OpenReadStream())
using (var rep = RepositoryReader.Open(strm, ch))
{
ch.Trace("Loading predictor");
return ModelLoadContext.LoadModelOrNull<IPredictor, SignatureLoadModel>(env, out inputPredictor, rep, ModelFileUtils.DirPredictor);
}
}
inputPredictor = null;
return false;
}
/// <summary>
/// Save the model to the output path.
/// The method saves the loader and the transformations of dataPipe and saves optionally predictor
/// and command. It also uses featureColumn, if provided, to extract feature names.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="ch">The communication channel to use.</param>
/// <param name="output">The output file handle.</param>
/// <param name="predictor">The predictor.</param>
/// <param name="data">The training examples.</param>
/// <param name="command">The command string.</param>
public static void SaveModel(IHostEnvironment env, IChannel ch, IFileHandle output,
IPredictor predictor, RoleMappedData data, string command = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
ch.CheckParam(output != null && output.CanWrite, nameof(output));
ch.CheckValueOrNull(predictor);
ch.CheckValue(data, nameof(data));
ch.CheckValueOrNull(command);
using (var stream = output.CreateWriteStream())
SaveModel(env, ch, stream, predictor, data, command);
}
/// <summary>
/// Save the model to the stream.
/// The method saves the loader and the transformations of dataPipe and saves optionally predictor
/// and command. It also uses featureColumn, if provided, to extract feature names.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="ch">The communication channel to use.</param>
/// <param name="outputStream">The output model stream.</param>
/// <param name="predictor">The predictor.</param>
/// <param name="data">The training examples.</param>
/// <param name="command">The command string.</param>
public static void SaveModel(IHostEnvironment env, IChannel ch, Stream outputStream, IPredictor predictor, RoleMappedData data, string command = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
ch.CheckValue(outputStream, nameof(outputStream));
ch.CheckValueOrNull(predictor);
ch.CheckValue(data, nameof(data));
ch.CheckValueOrNull(command);
using (var ch2 = env.Start("SaveModel"))
using (var pch = env.StartProgressChannel("Saving model"))
{
using (var rep = RepositoryWriter.CreateNew(outputStream, ch2))
{
if (predictor != null)
{
ch2.Trace("Saving predictor");
ModelSaveContext.SaveModel(rep, predictor, ModelFileUtils.DirPredictor);
}
ch2.Trace("Saving loader and transformations");
var dataPipe = data.Data;
if (dataPipe is ILegacyDataLoader)
ModelSaveContext.SaveModel(rep, dataPipe, ModelFileUtils.DirDataLoaderModel);
else
SaveDataPipe(env, rep, dataPipe);
// REVIEW: Handle statistics.
// ModelSaveContext.SaveModel(rep, dataStats, DirDataStats);
if (!string.IsNullOrWhiteSpace(command))
{
using (var ent = rep.CreateEntry(ModelFileUtils.DirTrainingInfo, "Command.txt"))
using (var writer = Utils.OpenWriter(ent.Stream))
writer.WriteLine(command);
}
ModelFileUtils.SaveRoleMappings(env, ch, data.Schema, rep);
rep.Commit();
}
}
}
/// <summary>
/// Save the data pipeline defined by dataPipe. If blankLoader is true or the root IDataView is not an IDataLoader,
/// this persists the root as a BinaryLoader having the same schema.
/// </summary>
public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositoryWriter, IDataView dataPipe, bool blankLoader = false)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(repositoryWriter, nameof(repositoryWriter));
env.CheckValue(dataPipe, nameof(dataPipe));
IDataView pipeStart;
var xfs = BacktrackPipe(dataPipe, out pipeStart);
Action<ModelSaveContext> saveAction;
if (!blankLoader && pipeStart is ILegacyDataLoader loader)
saveAction = loader.Save;
else
{
// The serialized pipe must start with a loader. If the original data view is not a loader,
// we replace it with a binary loader with the correct schema.
saveAction = ctx => BinaryLoader.SaveInstance(env, ctx, pipeStart.Schema);
}
using (var ctx = ModelFileUtils.GetDataModelSavingContext(repositoryWriter))
{
LegacyCompositeDataLoader.SavePipe(env, ctx, saveAction, xfs);
ctx.Done();
}
}
/// <summary>
/// Traces back the .Source chain of the transformation pipe <paramref name="dataPipe"/> up to the moment it no longer can.
/// Returns all the transforms of <see cref="IDataView"/> and the first data view (a non-transform).
/// </summary>
/// <param name="dataPipe">The transformation pipe to traverse.</param>
/// <param name="pipeStart">The beginning data view of the transform chain</param>
/// <returns>The list of the transforms</returns>
private static List<IDataTransform> BacktrackPipe(IDataView dataPipe, out IDataView pipeStart)
{
Contracts.AssertValue(dataPipe);
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
pipeStart = dataPipe;
transforms.Reverse();
return transforms;
}
// Returns true if a normalizer was added.
public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
ch.CheckValue(trainer, nameof(trainer));
ch.CheckValue(view, nameof(view));
ch.CheckValueOrNull(featureColumn);
ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures),
"Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");
if (autoNorm == NormalizeOption.No)
{
ch.Info("Not adding a normalizer.");
return false;
}
if (string.IsNullOrEmpty(featureColumn))
return false;
int featCol;
var schema = view.Schema;
if (schema.TryGetColumnIndex(featureColumn, out featCol))
{
if (autoNorm != NormalizeOption.Yes)
{
if (!trainer.Info.NeedNormalization || schema[featCol].IsNormalized())
{
ch.Info("Not adding a normalizer.");
return false;
}
if (autoNorm == NormalizeOption.Warn)
{
ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
return false;
}
}
ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");
IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
=> NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);
if (view is ILegacyDataLoader loader)
view = LegacyCompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
else
view = ApplyNormalizer(env, view);
return true;
}
return false;
}
private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer trainer, ref RoleMappedData data, bool? cacheData)
{
Contracts.AssertValue(env, nameof(env));
env.AssertValue(ch, nameof(ch));
ch.AssertValue(trainer, nameof(trainer));
ch.AssertValue(data, nameof(data));
bool shouldCache = cacheData ?? !(data.Data is BinaryLoader) && trainer.Info.WantCaching;
if (shouldCache)
{
ch.Trace("Caching");
var prefetch = data.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
var cacheView = new CacheDataView(env, data.Data, prefetch);
// Because the prefetching worked, we know that these are valid columns.
data = new RoleMappedData(cacheView, data.Schema.GetColumnRoleNames());
}
else
ch.Trace("Not caching");
return shouldCache;
}
public static IEnumerable<KeyValuePair<ColumnRole, string>> CheckAndGenerateCustomColumns(IExceptionContext ectx, KeyValuePair<string, string>[] customColumnArg)
{
Contracts.CheckValueOrNull(ectx);
if (customColumnArg == null)
return Enumerable.Empty<KeyValuePair<ColumnRole, string>>();
foreach (var kindName in customColumnArg)
{
ectx.CheckUserArg(!string.IsNullOrWhiteSpace(kindName.Value), nameof(TrainCommand.Arguments.CustomColumns), "Names for columns with custom kind must not be empty");
if (string.IsNullOrWhiteSpace(kindName.Key))
throw ectx.ExceptUserArg(nameof(TrainCommand.Arguments.CustomColumns), "Custom column with name '{0}' needs a kind. Use col[<Kind>]={0}", kindName.Value);
}
return customColumnArg.Select(kindName => new ColumnRole(kindName.Key).Bind(kindName.Value));
}
}
}