// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.TorchSharp.Utils;
using Microsoft.ML.Transforms;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
namespace Microsoft.ML.TorchSharp
{
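/// <summary>
/// The base estimator for all TorchSharp-based trainers. Fitting a derived trainer produces a
/// <see cref="TorchSharpBaseTransformer"/> that scores data with the trained TorchSharp module.
/// </summary>
/// <remarks>
/// Derived trainers are typically created through <see cref="MLContext"/> catalog extension methods
/// and used like any other estimator: call <c>Fit</c> on training data, then <c>Transform</c> with the
/// resulting transformer.
/// </remarks>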
public abstract class TorchSharpBaseTrainer : IEstimator<TorchSharpBaseTransformer>
{
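/// <summary>
/// Trains the model on <paramref name="input"/> and returns a <see cref="TorchSharpBaseTransformer"/>.
/// </summary>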
public abstract TorchSharpBaseTransformer Fit(IDataView input);
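/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>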
public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
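/// <summary>
/// Options shared by all TorchSharp-based trainers.
/// </summary>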
public abstract class Options : TransformInputBase
{
/// <summary>
/// The label column name.
/// </summary>
public string LabelColumnName = DefaultColumnNames.Label;
/// <summary>
/// The Score column name.
/// </summary>
public string ScoreColumnName = DefaultColumnNames.Score;
/// <summary>
/// The Prediction column name.
/// </summary>
public string PredictionColumnName = DefaultColumnNames.PredictedLabel;
/// <summary>
/// Number of samples to use for mini-batch training.
/// </summary>
public int BatchSize = 32;
/// <summary>
/// The starting learning rate ratio for the polynomial decay scheduler.
/// </summary>
public double StartLearningRateRatio = .1;
/// <summary>
/// The final learning rate ratio for the polynomial decay scheduler.
/// </summary>
public double FinalLearningRateRatio = .9;
/// <summary>
/// Coefficient of weight decay. Should be within [0, +Inf).
/// </summary>
public double WeightDecay = 0;
/// <summary>
/// Stop training when reaching this number of epochs.
/// </summary>
public int MaxEpoch = 100;
/// <summary>
/// The validation set used while training to improve model quality.
/// </summary>
public IDataView ValidationSet = null;
/// <summary>
/// Number of classes for the data.
/// </summary>
internal int NumberOfClasses;
}
}
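/// <summary>
/// The base estimator for TorchSharp-based trainers, parameterized by the label column type
/// (<typeparamref name="TLabelCol"/>) and the type used to collect training targets
/// (<typeparamref name="TTargetsCol"/>).
/// </summary>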
public abstract class TorchSharpBaseTrainer<TLabelCol, TTargetsCol> : TorchSharpBaseTrainer
{
private protected readonly IHost Host;
internal readonly Options Option;
internal TorchSharpBaseTrainer(IHostEnvironment env, Options options)
{
Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TorchSharpBaseTrainer));
Contracts.Assert(options.BatchSize > 0);
Contracts.Assert(options.MaxEpoch > 0);
Contracts.AssertValue(options.LabelColumnName);
Contracts.AssertValue(options.PredictionColumnName);
Option = options;
}
public override TorchSharpBaseTransformer Fit(IDataView input)
{
CheckInputSchema(SchemaShape.Create(input.Schema));
TorchSharpBaseTransformer<TLabelCol, TTargetsCol> transformer = default;
using (var ch = Host.Start("TrainModel"))
using (var pch = Host.StartProgressChannel("Training model"))
{
var header = new ProgressHeader(new[] { "Accuracy" }, null);
var trainer = CreateTrainer(this, ch, input);
pch.SetHeader(header, e => e.SetMetric(0, trainer.Accuracy));
for (int i = 0; i < Option.MaxEpoch; i++)
{
ch.Trace($"Starting epoch {i}");
Host.CheckAlive();
trainer.Train(Host, input);
ch.Trace($"Finished epoch {i}");
if (Option.ValidationSet != null)
trainer.Validate(pch, ch, i);
}
var labelCol = input.Schema.GetColumnOrNull(Option.LabelColumnName);
transformer = CreateTransformer(Host, Option, trainer.Model, new DataViewSchema.DetachedColumn(labelCol.Value));
trainer.Optimizer.Dispose();
transformer.GetOutputSchema(input.Schema);
}
return transformer;
}
private protected abstract void CheckInputSchema(SchemaShape inputSchema);
private protected abstract TorchSharpBaseTransformer<TLabelCol, TTargetsCol> CreateTransformer(IHost host, TorchSharpBaseTrainer<TLabelCol, TTargetsCol>.Options options, Module model, DataViewSchema.DetachedColumn labelColumn);
private protected abstract TrainerBase CreateTrainer(TorchSharpBaseTrainer<TLabelCol, TTargetsCol> parent, IChannel ch, IDataView input);
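/// <summary>
/// Holds the TorchSharp module, optimizer and learning-rate scheduler, and drives the per-epoch
/// training and validation loops.
/// </summary>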
internal abstract class TrainerBase
{
public Module Model;
public torch.Device Device;
public optim.Optimizer Optimizer;
public optim.lr_scheduler.LRScheduler LearningRateScheduler;
protected readonly TorchSharpBaseTrainer<TLabelCol, TTargetsCol> Parent;
public int Updates;
public float Accuracy;
public readonly int TrainingRowCount;
protected string ModelUrl;
public TrainerBase(TorchSharpBaseTrainer<TLabelCol, TTargetsCol> parent, IChannel ch, IDataView input, string modelUrl)
{
Parent = parent;
Updates = 0;
Accuracy = 0;
ModelUrl = modelUrl;
// Get row count and figure out num of unique labels
TrainingRowCount = GetRowCountAndSetLabelCount(input);
// Initialize the model and load pre-trained weights
Model = CreateModule(ch, input);
// Figure out if we are running on GPU or CPU
Device = TorchUtils.InitializeDevice(Parent.Host);
// Move to GPU if we are running there
if (Device.type == DeviceType.CUDA)
Model.cuda();
}
private protected abstract int GetRowCountAndSetLabelCount(IDataView input);
private protected abstract Module CreateModule(IChannel ch, IDataView input);
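/// <summary>
/// Ensures the pre-trained model file referenced by <paramref name="modelUrl"/> has been downloaded
/// to the ML.NET temp directory and returns the local path to it.
/// </summary>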
public string GetModelPath(string modelUrl)
{
var destDir = Path.Combine(((IHostEnvironmentInternal)Parent.Host).TempFilePath, "mlnet");
var destFileName = modelUrl.Split('/').Last();
Directory.CreateDirectory(destDir);
string relativeFilePath = Path.Combine(destDir, destFileName);
int timeout = 10 * 60 * 1000;
using (var ch = (Parent.Host as IHostEnvironment).Start("Ensuring model file is present."))
{
var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(Parent.Host, ch, modelUrl, destFileName, destDir, timeout);
ensureModel.Wait();
var errorResult = ResourceManagerUtils.GetErrorMessage(out var errorMessage, ensureModel.Result);
if (errorResult != null)
{
var directory = Path.GetDirectoryName(errorResult.FileName);
var name = Path.GetFileName(errorResult.FileName);
throw ch.Except($"{errorMessage}\nModel file '{name}' could not be downloaded to '{directory}'.");
}
}
return relativeFilePath;
}
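/// <summary>
/// Runs the model over <see cref="Options.ValidationSet"/> without gradients and reports the
/// accuracy for the given epoch.
/// </summary>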
public void Validate(IProgressChannel pch, IChannel ch, int epoch)
{
var validationSet = Parent.Option.ValidationSet;
Model.eval();
DataViewRowCursor cursor = GetRowCursor(validationSet);
InitializeDataGetters(validationSet, cursor);
var labelGetter = cursor.GetGetter<TLabelCol>(validationSet.Schema[Parent.Option.LabelColumnName]);
// Pre-allocate the memory so it's only done once (though this step needs to be optimized)
List<Tensor> inputTensors = new List<Tensor>(Parent.Option.BatchSize);
List<TTargetsCol> targets = new List<TTargetsCol>(Parent.Option.BatchSize);
int numCorrect = 0;
int numRows = 0;
var cursorValid = true;
while (cursorValid)
{
cursorValid = ValidateStep(cursor, labelGetter, ref inputTensors, ref targets, ref numCorrect, ref numRows);
}
Accuracy = numCorrect / (float)numRows;
pch.Checkpoint(Accuracy);
ch.Info($"Accuracy for epoch {epoch}: {Accuracy}");
Model.train();
}
private protected abstract void InitializeDataGetters(IDataView input, DataViewRowCursor cursor);
private bool ValidateStep(DataViewRowCursor cursor,
ValueGetter<TLabelCol> labelGetter,
ref List<Tensor> inputTensors,
ref List<TTargetsCol> targets,
ref int numCorrect,
ref int numRows)
{
// Make sure the lists are clear before use
inputTensors.Clear();
targets.Clear();
using var disposeScope = torch.NewDisposeScope();
var cursorValid = true;
for (int i = 0; i < Parent.Option.BatchSize && cursorValid; i++)
{
cursorValid = cursor.MoveNext();
if (cursorValid)
{
TLabelCol target = default;
labelGetter(ref target);
inputTensors.Add(PrepareRowTensor(ref target));
targets.Add(AddToTargets(target));
}
else
{
inputTensors.TrimExcess();
targets.TrimExcess();
if (inputTensors.Count == 0)
return cursorValid;
}
}
using (torch.no_grad())
{
var inputTensor = PrepareBatchTensor(ref inputTensors, device: Device);
var targetsTensor = CreateTargetsTensor(ref targets, device: Device);
RunModelAndUpdateValidationStats(ref inputTensor, ref targetsTensor, ref numCorrect);
numRows += inputTensors.Count;
}
return cursorValid;
}
private protected abstract void RunModelAndUpdateValidationStats(ref Tensor inputTensor, ref Tensor targetsTensor, ref int numCorrect);
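/// <summary>
/// Runs one training pass (a single epoch) over <paramref name="input"/>, processing the data in
/// mini-batches of <see cref="Options.BatchSize"/> rows.
/// </summary>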
public void Train(IHost host, IDataView input)
{
// Get the cursor and the correct columns based on the inputs
DataViewRowCursor cursor = GetRowCursor(input);
InitializeDataGetters(input, cursor);
var labelGetter = cursor.GetGetter<TLabelCol>(input.Schema[Parent.Option.LabelColumnName]);
// Pre-allocate the memory so it's only done once (though this step needs to be optimized)
List<Tensor> inputTensors = new List<Tensor>(Parent.Option.BatchSize);
List<TTargetsCol> targets = new List<TTargetsCol>(Parent.Option.BatchSize);
if (host is IHostEnvironmentInternal hostInternal)
{
torch.random.manual_seed(hostInternal.Seed ?? 1);
torch.cuda.manual_seed(hostInternal.Seed ?? 1);
}
else
{
torch.random.manual_seed(1);
torch.cuda.manual_seed(1);
}
var cursorValid = true;
while (cursorValid)
{
cursorValid = TrainStep(host, cursor, labelGetter, ref inputTensors, ref targets);
}
}
private bool TrainStep(IHost host,
DataViewRowCursor cursor,
ValueGetter<TLabelCol> labelGetter,
ref List<Tensor> inputTensors,
ref List<TTargetsCol> targets)
{
// Make sure the lists are clear before use
inputTensors.Clear();
targets.Clear();
using var disposeScope = torch.NewDisposeScope();
var cursorValid = true;
for (int i = 0; i < Parent.Option.BatchSize && cursorValid; i++)
{
host.CheckAlive();
cursorValid = cursor.MoveNext();
if (cursorValid)
{
TLabelCol target = default;
labelGetter(ref target);
inputTensors.Add(PrepareRowTensor(ref target));
targets.Add(AddToTargets(target));
}
else
{
inputTensors.TrimExcess();
targets.TrimExcess();
if (inputTensors.Count == 0)
return cursorValid;
}
}
Updates++;
host.CheckAlive();
Model.train();
Optimizer.zero_grad();
var targetsTensor = CreateTargetsTensor(ref targets, device: Device);
RunModelAndBackPropagate(ref inputTensors, ref targetsTensor);
host.CheckAlive();
OptimizeStep();
return cursorValid;
}
private protected abstract void RunModelAndBackPropagate(ref List<Tensor> inputTensors, ref Tensor targetsTensor);
private protected abstract torch.Tensor PrepareRowTensor(ref TLabelCol target);
private protected abstract torch.Tensor PrepareBatchTensor(ref List<Tensor> inputTensors, Device device);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private protected abstract TTargetsCol AddToTargets(TLabelCol target);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private protected abstract Tensor CreateTargetsTensor(ref List<TTargetsCol> targets, Device device);
private protected abstract DataViewRowCursor GetRowCursor(IDataView input);
private protected abstract torch.Tensor GetPredictions(torch.Tensor logits);
private protected abstract torch.Tensor GetTargets(torch.Tensor labels);
private protected abstract int GetNumCorrect(torch.Tensor predictions, torch.Tensor targets);
private protected virtual void OptimizeStep()
{
Optimizer.step();
LearningRateScheduler.step();
}
}
}
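/// <summary>
/// The base transformer produced by TorchSharp-based trainers.
/// </summary>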
public abstract class TorchSharpBaseTransformer : RowToRowTransformerBase, IDisposable
{
private protected TorchSharpBaseTransformer(IHost host) : base(host)
{
}
public abstract void Dispose();
}
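/// <summary>
/// The base transformer for TorchSharp-based trainers, holding the trained module, the options it was
/// trained with, and the label column it was trained on.
/// </summary>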
public abstract class TorchSharpBaseTransformer<TLabelCol, TTargetsCol> : TorchSharpBaseTransformer
{
private protected readonly Device Device;
private protected readonly Module Model;
internal readonly TorchSharpBaseTrainer.Options Options;
private protected readonly string ScoreColumnName;
public readonly DataViewSchema.DetachedColumn LabelColumn;
private bool _disposedValue;
internal TorchSharpBaseTransformer(IHostEnvironment env, TorchSharpBaseTrainer.Options options, Module model, DataViewSchema.DetachedColumn labelColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TorchSharpBaseTransformer)))
{
Device = TorchUtils.InitializeDevice(env);
Options = options;
LabelColumn = labelColumn;
ScoreColumnName = Options.ScoreColumnName;
Model = model;
if (Device.type == DeviceType.CUDA)
Model.cuda();
}
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
CheckInputSchema(inputSchema);
return GetOutputSchemaCore(inputSchema);
}
private protected abstract void CheckInputSchema(SchemaShape inputSchema);
private protected abstract SchemaShape GetOutputSchemaCore(SchemaShape inputSchema);
private protected abstract override void SaveModel(ModelSaveContext ctx);
private protected void SaveBaseModel(ModelSaveContext ctx, VersionInfo versionInfo)
{
Host.AssertValue(ctx);
ctx.CheckAtModel();
ctx.SetVersionInfo(versionInfo);
// *** Binary format ***
// int: id of the label column name
// int: id of the score column name
// int: id of the prediction column name
// int: number of classes
// BinaryStream: TS Model
ctx.SaveNonEmptyString(Options.LabelColumnName);
ctx.SaveStringOrNull(Options.ScoreColumnName);
ctx.SaveNonEmptyString(Options.PredictionColumnName);
ctx.Writer.Write(Options.NumberOfClasses);
ctx.SaveBinaryStream("TSModel", w =>
{
Model.save(w);
});
}
private protected abstract IRowMapper GetRowMapper(TorchSharpBaseTransformer<TLabelCol, TTargetsCol> parent, DataViewSchema schema);
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => GetRowMapper(this, schema);
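/// <summary>
/// The row mapper that runs the TorchSharp module over input rows to produce the output columns.
/// </summary>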
private protected abstract class TorchSharpBaseMapper : MapperBase
{
private protected readonly TorchSharpBaseTransformer<TLabelCol, TTargetsCol> Parent;
private protected readonly HashSet<int> InputColIndices;
private static readonly FuncInstanceMethodInfo1<TorchSharpBaseMapper, DataViewSchema.DetachedColumn, Delegate> _makeLabelAnnotationGetter
= FuncInstanceMethodInfo1<TorchSharpBaseMapper, DataViewSchema.DetachedColumn, Delegate>.Create(target => target.GetLabelAnnotations<int>);
private Delegate GetLabelAnnotations<T>(DataViewSchema.DetachedColumn labelCol)
{
return labelCol.Annotations.GetGetter<VBuffer<T>>(labelCol.Annotations.Schema[AnnotationUtils.Kinds.KeyValues]);
}
public TorchSharpBaseMapper(TorchSharpBaseTransformer<TLabelCol, TTargetsCol> parent, DataViewSchema inputSchema) :
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(TorchSharpBaseMapper)), inputSchema, parent)
{
Parent = parent;
InputColIndices = new HashSet<int>();
AddInputColumnIndices(inputSchema);
if (Host is IHostEnvironmentInternal hostInternal)
{
torch.random.manual_seed(hostInternal.Seed ?? 1);
torch.cuda.manual_seed(hostInternal.Seed ?? 1);
}
else
{
torch.random.manual_seed(1);
torch.cuda.manual_seed(1);
}
}
private protected abstract void AddInputColumnIndices(DataViewSchema inputSchema);
private protected ValueGetter<uint> GetScoreColumnSetId(DataViewSchema schema)
{
int c;
var max = schema.GetMaxAnnotationKind(out c, AnnotationUtils.Kinds.ScoreColumnSetId);
uint id = checked(max + 1);
return
(ref uint dst) => dst = id;
}
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
=> throw new NotImplementedException("This should never be called!");
private protected abstract Delegate CreateGetter(DataViewRow input, int iinfo, TensorCacher outputCacher);
public override Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
{
Host.AssertValue(input);
Contracts.Assert(input.Schema == base.InputSchema);
TensorCacher outputCacher = GetTensorCacher();
var ch = Host.Start("Make Getters");
Parent.Model.eval();
int n = OutputColumns.Value.Length;
var result = new Delegate[n];
for (int i = 0; i < n; i++)
{
if (!activeOutput(i))
continue;
result[i] = CreateGetter(input, i, outputCacher);
}
disposer = () =>
{
outputCacher.Dispose();
};
return result;
}
private protected abstract TensorCacher GetTensorCacher();
private protected abstract class TensorCacher : IDisposable
{
public long Position;
public TensorCacher()
{
Position = -1;
}
public abstract void Dispose();
public abstract void DisposeCore();
}
private protected abstract class TensorCacher<TOut> : TensorCacher
{
public TOut Result;
public TensorCacher() : base()
{
Result = default;
}
private bool _isDisposed;
public override void Dispose()
{
if (_isDisposed)
return;
DisposeCore();
_isDisposed = true;
}
}
private protected override void SaveModel(ModelSaveContext ctx) => Parent.SaveModel(ctx);
}
protected virtual void Dispose(bool disposing)
{
if (!_disposedValue)
{
if (disposing)
{
// TODO: dispose managed state (managed objects)
}
Model.Dispose();
_disposedValue = true;
}
}
~TorchSharpBaseTransformer()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: false);
}
public override void Dispose()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}
}