|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.TensorExtensionMethods;
using static TorchSharp.torch.optim;
using static TorchSharp.torch.optim.lr_scheduler;
using Microsoft.ML.TorchSharp.Utils;
using Microsoft.ML;
using System.IO;
using Microsoft.ML.Data.IO;
using Microsoft.ML.TorchSharp.Loss;
using Microsoft.ML.Transforms.Image;
using static Microsoft.ML.TorchSharp.AutoFormerV2.ObjectDetectionTrainer;
using Microsoft.ML.TorchSharp.AutoFormerV2;
using static Microsoft.ML.Data.AnnotationUtils;
[assembly: LoadableClass(typeof(ObjectDetectionTransformer), null, typeof(SignatureLoadModel),
ObjectDetectionTransformer.UserName, ObjectDetectionTransformer.LoaderSignature)]
[assembly: LoadableClass(typeof(IRowMapper), typeof(ObjectDetectionTransformer), null, typeof(SignatureLoadRowMapper),
ObjectDetectionTransformer.UserName, ObjectDetectionTransformer.LoaderSignature)]
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
public class ObjectDetectionTrainer : IEstimator<ObjectDetectionTransformer>
{
public sealed class Options : TransformInputBase
{
/// <summary>
/// The label column name.
/// </summary>
public string LabelColumnName = DefaultColumnNames.Label;
/// <summary>
/// The label column name.
/// </summary>
public string PredictedLabelColumnName = DefaultColumnNames.PredictedLabel;
/// <summary>
/// The Bounding Box column name.
/// </summary>
public string BoundingBoxColumnName = "BoundingBoxes";
/// <summary>
/// The Predicted Bounding Box column name.
/// </summary>
public string PredictedBoundingBoxColumnName = "PredictedBoundingBoxes";
/// <summary>
/// The Image column name.
/// </summary>
public string ImageColumnName = "Image";
/// <summary>
/// The Confidence column name.
/// </summary>
public string ScoreColumnName = DefaultColumnNames.Score;
/// <summary>
/// Gets or sets the IOU threshold for removing duplicate bounding boxes.
/// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "<Pending>")]
public double IOUThreshold = 0.5;
/// <summary>
/// Gets or sets the confidenct threshold for bounding box category.
/// </summary>
public double ScoreThreshold = 0.5;
/// <summary>
/// Gets or sets the epoch steps in learning rate scheduler to reduce learning rate.
/// </summary>
public List<int> Steps = new List<int> { 6 };
/// <summary>
/// Stop training when reaching this number of epochs.
/// </summary>
public int MaxEpoch = 10;
/// <summary>
/// The validation set used while training to improve model quality.
/// </summary>
public IDataView ValidationSet = null;
/// <summary>
/// Number of classes for the data.
/// </summary>
internal int NumberOfClasses;
/// <summary>
/// Gets or sets the initial learning rate in optimizer.
/// </summary>
public double InitLearningRate = 1.0;
/// <summary>
/// Gets or sets the weight decay in optimizer.
/// </summary>
public double WeightDecay = 0.0;
/// <summary>
/// How often to log the loss.
/// </summary>
public int LogEveryNStep = 50;
}
private protected readonly IHost Host;
internal readonly Options Option;
private const string ModelUrl = "models/autoformer_11m_torchsharp.bin";
internal ObjectDetectionTrainer(IHostEnvironment env, Options options)
{
Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTrainer));
Contracts.Assert(options.MaxEpoch > 0);
Contracts.AssertValue(options.BoundingBoxColumnName);
Contracts.AssertValue(options.LabelColumnName);
Contracts.AssertValue(options.ImageColumnName);
Contracts.AssertValue(options.ScoreColumnName);
Contracts.AssertValue(options.PredictedLabelColumnName);
Option = options;
}
internal ObjectDetectionTrainer(IHostEnvironment env,
string labelColumnName = DefaultColumnNames.Label,
string predictedLabelColumnName = DefaultColumnNames.PredictedLabel,
string scoreColumnName = DefaultColumnNames.Score,
string boundingBoxColumnName = "BoundingBoxes",
string predictedBoundingBoxColumnName = "PredictedBoundingBoxes",
string imageColumnName = "Image",
int maxEpoch = 10) :
this(env, new Options
{
LabelColumnName = labelColumnName,
PredictedLabelColumnName = predictedLabelColumnName,
ScoreColumnName = scoreColumnName,
BoundingBoxColumnName = boundingBoxColumnName,
PredictedBoundingBoxColumnName = predictedBoundingBoxColumnName,
ImageColumnName = imageColumnName,
MaxEpoch = maxEpoch
})
{
}
public ObjectDetectionTransformer Fit(IDataView input)
{
CheckInputSchema(SchemaShape.Create(input.Schema));
ObjectDetectionTransformer transformer = default;
using (var ch = Host.Start("TrainModel"))
using (var pch = Host.StartProgressChannel("Training model"))
{
var header = new ProgressHeader(new[] { "Loss" }, new[] { "total images" });
var trainer = new Trainer(this, ch, input);
pch.SetHeader(header,
e =>
{
e.SetProgress(0, trainer.Updates, trainer.RowCount);
e.SetMetric(0, trainer.LossValue);
});
for (int i = 0; i < Option.MaxEpoch; i++)
{
ch.Trace($"Starting epoch {i}");
Host.CheckAlive();
trainer.Train(Host, input, pch);
ch.Trace($"Finished epoch {i}");
}
var labelCol = input.Schema.GetColumnOrNull(Option.LabelColumnName);
transformer = new ObjectDetectionTransformer(Host, Option, trainer.Model, new DataViewSchema.DetachedColumn(labelCol.Value));
trainer.Optimizer.Dispose();
transformer.GetOutputSchema(input.Schema);
}
return transformer;
}
internal class Trainer
{
public AutoFormerV2 Model;
public torch.Device Device;
public Optimizer Optimizer;
public LRScheduler LearningRateScheduler;
protected readonly ObjectDetectionTrainer Parent;
public FocalLoss Loss;
public int Updates;
public float LossValue;
public readonly int RowCount;
private readonly IChannel _channel;
public Trainer(ObjectDetectionTrainer parent, IChannel ch, IDataView input)
{
Parent = parent;
Updates = 0;
LossValue = 0;
_channel = ch;
// Get row count and figure out num of unique labels
RowCount = GetRowCountAndSetLabelCount(input);
Device = TorchUtils.InitializeDevice(Parent.Host);
// Initialize the model and load pre-trained weights
Model = new AutoFormerV2(
Parent.Option.NumberOfClasses,
embedChannels: new List<int>() { 64, 128, 256, 448 },
depths: new List<int>() { 2, 2, 6, 2 },
numHeads: new List<int>() { 2, 4, 8, 14 },
device: Device);
Model.load(GetModelPath(), false);
// Figure out if we are running on GPU or CPU
Device = TorchUtils.InitializeDevice(Parent.Host);
// Move to GPU if we are running there
if (Device.type == DeviceType.CUDA)
Model.cuda();
// Get the parameters that need optimization and set up the optimizer
Optimizer = SGD(
Model.parameters(),
learningRate: Parent.Option.InitLearningRate,
weight_decay: Parent.Option.WeightDecay);
Loss = new FocalLoss();
LearningRateScheduler = MultiStepLR(Optimizer, Parent.Option.Steps);
}
private protected int GetRowCountAndSetLabelCount(IDataView input)
{
var labelCol = input.GetColumn<VBuffer<uint>>(Parent.Option.LabelColumnName);
var rowCount = 0;
var uniqueLabels = new HashSet<uint>();
foreach (var label in labelCol)
{
rowCount++;
label.DenseValues().ToList().ForEach(x => uniqueLabels.Add(x));
}
Parent.Option.NumberOfClasses = uniqueLabels.Count;
return rowCount;
}
private string GetModelPath()
{
var destDir = Path.Combine(((IHostEnvironmentInternal)Parent.Host).TempFilePath, "mlnet");
var destFileName = ModelUrl.Split('/').Last();
Directory.CreateDirectory(destDir);
string relativeFilePath = Path.Combine(destDir, destFileName);
int timeout = 10 * 60 * 1000;
using (var ch = (Parent.Host as IHostEnvironment).Start("Ensuring model file is present."))
{
var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(Parent.Host, ch, ModelUrl, destFileName, destDir, timeout);
ensureModel.Wait();
var errorResult = ResourceManagerUtils.GetErrorMessage(out var errorMessage, ensureModel.Result);
if (errorResult != null)
{
var directory = Path.GetDirectoryName(errorResult.FileName);
var name = Path.GetFileName(errorResult.FileName);
throw ch.Except($"{errorMessage}\nmodel file could not be downloaded!");
}
}
return relativeFilePath;
}
public void Train(IHost host, IDataView input, IProgressChannel pch)
{
// Get the cursor and the correct columns based on the inputs
DataViewRowCursor cursor = input.GetRowCursor(input.Schema[Parent.Option.LabelColumnName], input.Schema[Parent.Option.BoundingBoxColumnName], input.Schema[Parent.Option.ImageColumnName]);
var boundingBoxGetter = cursor.GetGetter<VBuffer<float>>(input.Schema[Parent.Option.BoundingBoxColumnName]);
var imageGetter = cursor.GetGetter<MLImage>(input.Schema[Parent.Option.ImageColumnName]);
var labelGetter = cursor.GetGetter<VBuffer<uint>>(input.Schema[Parent.Option.LabelColumnName]);
var cursorValid = true;
Updates = 0;
Model.train();
Model.FreezeBN();
if (host is IHostEnvironmentInternal hostInternal)
{
torch.random.manual_seed(hostInternal.Seed ?? 1);
torch.cuda.manual_seed(hostInternal.Seed ?? 1);
}
else
{
torch.random.manual_seed(1);
torch.cuda.manual_seed(1);
}
while (cursorValid)
{
cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter, pch);
}
LearningRateScheduler.step();
}
private bool TrainStep(IHost host,
DataViewRowCursor cursor,
ValueGetter<VBuffer<float>> boundingBoxGetter,
ValueGetter<MLImage> imageGetter,
ValueGetter<VBuffer<uint>> labelGetter,
IProgressChannel pch)
{
using var disposeScope = torch.NewDisposeScope();
var cursorValid = true;
Tensor imageTensor = default;
Tensor targetTensor = default;
host.CheckAlive();
cursorValid = cursor.MoveNext();
if (cursorValid)
{
(imageTensor, targetTensor) = PrepareData(labelGetter, imageGetter, boundingBoxGetter);
}
else
{
return cursorValid;
}
Updates++;
host.CheckAlive();
Optimizer.zero_grad();
var (classification, regression, anchors) = Model.forward(imageTensor);
var lossValue = Loss.forward(classification, regression, anchors, targetTensor);
lossValue.backward();
torch.nn.utils.clip_grad_norm_(Model.parameters(), 0.1);
Optimizer.step();
host.CheckAlive();
if (Updates % Parent.Option.LogEveryNStep == 0)
{
pch.Checkpoint(lossValue.ToDouble(), Updates);
_channel.Info($"Row: {Updates}, Loss: {lossValue.ToDouble()}");
}
return cursorValid;
}
private (Tensor image, Tensor Label) PrepareData(ValueGetter<VBuffer<uint>> labelGetter, ValueGetter<MLImage> imageGetter, ValueGetter<VBuffer<float>> boundingBoxGetter)
{
using (var _ = torch.NewDisposeScope())
{
MLImage image = default;
imageGetter(ref image);
var midTensor0 = torch.tensor(image.GetBGRPixels, device: Device);
var midTensor1 = midTensor0.@float();
var midTensor2 = midTensor1.reshape(1, image.Height, image.Width, 3);
var midTensor3 = midTensor2.transpose(0, 3);
var midTensor4 = midTensor3.reshape(3, image.Height, image.Width);
var chunks = midTensor4.chunk(3, 0);
List<Tensor> part = new List<Tensor>();
part.Add(chunks[2]);
part.Add(chunks[1]);
part.Add(chunks[0]);
using var midTensor = torch.cat(part, 0);
using var reMidTensor = midTensor.reshape(1, 3, image.Height, image.Width);
var padW = 32 - (image.Width % 32);
var padH = 32 - (image.Height % 32);
using var transMidTensor = torch.zeros(1, 3, image.Height + padH, image.Width + padW, device: Device);
transMidTensor[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..image.Height), RangeUtil.ToTensorIndex(..image.Width)] = reMidTensor / 255.0;
var imageTensor = Normalize(transMidTensor, Device);
VBuffer<uint> labels = default;
labelGetter(ref labels);
VBuffer<float> boxes = default;
boundingBoxGetter(ref boxes);
var labelValues = labels.GetValues();
var boxValues = boxes.GetValues();
Contracts.Assert(boxValues.Length == labelValues.Length * 4, "Must have 4 coordinates for each label");
int b = 0;
var labelTensor = torch.zeros(1, labels.Length, 5, dtype: ScalarType.Int64, device: Device);
for (int i = 0; i < labels.Length; i++)
{
long x0 = (long)boxValues[b++];
long y0 = (long)boxValues[b++];
long x1 = (long)boxValues[b++];
long y1 = (long)boxValues[b++];
// Our labels are 1 based, the TorchSharp model is 0 based so subtract 1 to they align correctly.
long cl = labelValues[i] - 1;
labelTensor[RangeUtil.ToTensorIndex(..), i, 0] = x0;
labelTensor[RangeUtil.ToTensorIndex(..), i, 1] = y0;
labelTensor[RangeUtil.ToTensorIndex(..), i, 2] = x1;
labelTensor[RangeUtil.ToTensorIndex(..), i, 3] = y1;
labelTensor[RangeUtil.ToTensorIndex(..), i, 4] = cl;
}
return (imageTensor.MoveToOuterDisposeScope(), labelTensor.MoveToOuterDisposeScope());
}
}
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "<Pending>")]
private static readonly double[] MEAN = { 0.406, 0.456, 0.485 };
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "<Pending>")]
private static readonly double[] STD = { 0.225, 0.224, 0.229 };
internal static Tensor Normalize(Tensor x, Device device)
{
using (var _ = torch.NewDisposeScope())
{
var meanTensor = MEAN.ToTensor(new long[4] { 1L, MEAN.Length, 1L, 1L }).to_type(ScalarType.Float32).to(device);
var stdTensor = STD.ToTensor(new long[4] { 1L, STD.Length, 1L, 1L }).to_type(ScalarType.Float32).to(device);
x = (x - meanTensor) / stdTensor;
return x.MoveToOuterDisposeScope();
}
}
}
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
CheckInputSchema(inputSchema);
var outColumns = inputSchema.ToDictionary(x => x.Name);
var metadata = new List<SchemaShape.Column>();
metadata.Add(new SchemaShape.Column(Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
TextDataViewType.Instance, false));
var scoreMetadata = new List<SchemaShape.Column>();
scoreMetadata.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar,
TextDataViewType.Instance, false));
scoreMetadata.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar,
TextDataViewType.Instance, false));
scoreMetadata.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar,
NumberDataViewType.UInt32, true));
scoreMetadata.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector,
TextDataViewType.Instance, false));
// Get label column for score column annotations. Already verified it exists.
inputSchema.TryFindColumn(Option.LabelColumnName, out var labelCol);
outColumns[Option.PredictedLabelColumnName] = new SchemaShape.Column(Option.PredictedLabelColumnName, SchemaShape.Column.VectorKind.VariableVector,
NumberDataViewType.UInt32, true, new SchemaShape(metadata.ToArray()));
outColumns[Option.PredictedBoundingBoxColumnName] = new SchemaShape.Column(Option.PredictedBoundingBoxColumnName, SchemaShape.Column.VectorKind.VariableVector,
NumberDataViewType.Single, false);
outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.VariableVector,
NumberDataViewType.Single, false, new SchemaShape(scoreMetadata.ToArray()));
return new SchemaShape(outColumns.Values);
}
private void CheckInputSchema(SchemaShape inputSchema)
{
// Verify that all required input columns are present, and are of the same type.
if (!inputSchema.TryFindColumn(Option.LabelColumnName, out var labelCol))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName);
if (labelCol.Kind != SchemaShape.Column.VectorKind.VariableVector || labelCol.ItemType.RawType != typeof(UInt32))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
new VectorDataViewType(new KeyDataViewType(typeof(uint), uint.MaxValue)).ToString(), labelCol.GetTypeString());
if (!inputSchema.TryFindColumn(Option.BoundingBoxColumnName, out var boundingBoxCol))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "BoundingBox", Option.BoundingBoxColumnName);
if (boundingBoxCol.Kind != SchemaShape.Column.VectorKind.VariableVector || boundingBoxCol.ItemType.RawType != typeof(Single))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "BoundingBox", Option.BoundingBoxColumnName,
new VectorDataViewType(NumberDataViewType.Single).ToString(), boundingBoxCol.GetTypeString());
if (!inputSchema.TryFindColumn(Option.ImageColumnName, out var imageCol))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "Image", Option.ImageColumnName);
if (imageCol.ItemType.RawType != typeof(MLImage))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "Image", Option.ImageColumnName,
new ImageDataViewType().ToString(), imageCol.GetTypeString());
}
}
public class ObjectDetectionTransformer : RowToRowTransformerBase, IDisposable
{
private protected readonly Device Device;
private protected readonly AutoFormerV2 Model;
internal readonly ObjectDetectionTrainer.Options Options;
public readonly SchemaShape.Column PredictedLabelColumnName;
public readonly SchemaShape.Column PredictedBoundingBoxColumn;
public readonly SchemaShape.Column ConfidenceColumn;
public readonly DataViewSchema.DetachedColumn LabelColumn;
internal const string LoadName = "ObjDetTrainer";
internal const string UserName = "Obj Detection Trainer";
internal const string ShortName = "OBJDETC";
internal const string Summary = "Object Detection";
internal const string LoaderSignature = "OBJDETC";
private static readonly FuncStaticMethodInfo1<object, Delegate> _decodeInitMethodInfo
= new FuncStaticMethodInfo1<object, Delegate>(DecodeInit<int>);
private bool _disposedValue;
internal ObjectDetectionTransformer(IHostEnvironment env, ObjectDetectionTrainer.Options options, AutoFormerV2 model, DataViewSchema.DetachedColumn labelColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTransformer)))
{
Device = TorchUtils.InitializeDevice(env);
Options = options;
LabelColumn = labelColumn;
PredictedLabelColumnName = new SchemaShape.Column(Options.PredictedLabelColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.UInt32, false);
PredictedBoundingBoxColumn = new SchemaShape.Column(Options.PredictedBoundingBoxColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
ConfidenceColumn = new SchemaShape.Column(Options.ScoreColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
Model = model;
Model.eval();
if (Device.type == DeviceType.CUDA)
Model.cuda();
}
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
CheckInputSchema(inputSchema);
var outColumns = inputSchema.ToDictionary(x => x.Name);
var labelAnnotationsColumn = new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.SlotNames].Type, false);
var predLabelMetadata = new SchemaShape(new SchemaShape.Column[] { labelAnnotationsColumn }
.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));
var scoreMetadata = new List<SchemaShape.Column>();
scoreMetadata.Add(new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar,
TextDataViewType.Instance, false));
scoreMetadata.Add(new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar,
TextDataViewType.Instance, false));
scoreMetadata.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar,
NumberDataViewType.UInt32, true));
scoreMetadata.Add(new SchemaShape.Column(Kinds.TrainingLabelValues, SchemaShape.Column.VectorKind.Vector,
TextDataViewType.Instance, false));
outColumns[Options.PredictedLabelColumnName] = new SchemaShape.Column(Options.PredictedLabelColumnName, SchemaShape.Column.VectorKind.VariableVector,
NumberDataViewType.UInt32, true, predLabelMetadata);
outColumns[Options.PredictedBoundingBoxColumnName] = new SchemaShape.Column(Options.PredictedBoundingBoxColumnName, SchemaShape.Column.VectorKind.VariableVector,
NumberDataViewType.Single, false);
outColumns[Options.ScoreColumnName] = new SchemaShape.Column(Options.ScoreColumnName, SchemaShape.Column.VectorKind.VariableVector,
NumberDataViewType.Single, false, new SchemaShape(scoreMetadata.ToArray()));
return new SchemaShape(outColumns.Values);
}
private void CheckInputSchema(SchemaShape inputSchema)
{
if (!inputSchema.TryFindColumn(Options.ImageColumnName, out var imageCol))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "Image", Options.ImageColumnName);
if (imageCol.ItemType != new ImageDataViewType())
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "Image", Options.ImageColumnName,
new ImageDataViewType().ToString(), imageCol.GetTypeString());
}
private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "OBJ-DETC",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(ObjectDetectionTransformer).Assembly.FullName);
}
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.AssertValue(ctx);
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
// int: id of label column name
// int: id of predicted label column name
// int: id of the BoundingBoxColumnName name
// int: id of the PredictedBoundingBoxColumnName name
// int: id of ImageColumnName name
// int: id of Score column name
// int: number of classes
// double: score threshold
// double: iou threshold
// LabelValues
// BinaryStream: TS Model
ctx.SaveNonEmptyString(Options.LabelColumnName);
ctx.SaveNonEmptyString(Options.PredictedLabelColumnName);
ctx.SaveNonEmptyString(Options.BoundingBoxColumnName);
ctx.SaveNonEmptyString(Options.PredictedBoundingBoxColumnName);
ctx.SaveNonEmptyString(Options.ImageColumnName);
ctx.SaveNonEmptyString(Options.ScoreColumnName);
ctx.Writer.Write(Options.NumberOfClasses);
ctx.Writer.Write(Options.ScoreThreshold);
ctx.Writer.Write(Options.IOUThreshold);
var labelColType = LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.KeyValues].Type as VectorDataViewType;
Microsoft.ML.Internal.Utilities.Utils.MarshalActionInvoke(SaveLabelValues<int>, labelColType.ItemType.RawType, ctx);
ctx.SaveBinaryStream("TSModel", w =>
{
Model.save(w);
});
}
private void SaveLabelValues<T>(ModelSaveContext ctx)
{
ValueGetter<VBuffer<T>> getter = LabelColumn.Annotations.GetGetter<VBuffer<T>>(LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.KeyValues]);
var val = default(VBuffer<T>);
getter(ref val);
BinarySaver saver = new BinarySaver(Host, new BinarySaver.Arguments());
int bytesWritten;
var labelColType = LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.KeyValues].Type as VectorDataViewType;
if (!saver.TryWriteTypeAndValue<VBuffer<T>>(ctx.Writer.BaseStream, labelColType, ref val, out bytesWritten))
throw Host.Except("We do not know how to serialize label names of type '{0}'", labelColType.ItemType);
}
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new ObjDetMapper(this, schema);
//Factory method for SignatureLoadRowMapper.
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
// Factory method for SignatureLoadModel.
private static ObjectDetectionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
// *** Binary format ***
// int: id of label column name
// int: id of predicted label column name
// int: id of the BoundingBoxColumnName name
// int: id of the PredictedBoundingBoxColumnName name
// int: id of ImageColumnName name
// int: id of Score column name
// int: number of classes
// double: score threshold
// double: iou threshold
// LabelValues
// BinaryStream: TS Model
var options = new Options()
{
LabelColumnName = ctx.LoadString(),
PredictedLabelColumnName = ctx.LoadString(),
BoundingBoxColumnName = ctx.LoadString(),
PredictedBoundingBoxColumnName = ctx.LoadString(),
ImageColumnName = ctx.LoadString(),
ScoreColumnName = ctx.LoadString(),
NumberOfClasses = ctx.Reader.ReadInt32(),
ScoreThreshold = ctx.Reader.ReadDouble(),
IOUThreshold = ctx.Reader.ReadDouble(),
};
var ch = env.Start("Load Model");
var model = new AutoFormerV2(options.NumberOfClasses,
embedChannels: new List<int>() { 64, 128, 256, 448 },
depths: new List<int>() { 2, 2, 6, 2 },
numHeads: new List<int>() { 2, 4, 8, 14 },
device: TorchUtils.InitializeDevice(env));
BinarySaver saver = new BinarySaver(env, new BinarySaver.Arguments());
DataViewType type;
object value;
env.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out type, out value));
var vecType = type as VectorDataViewType;
env.CheckDecode(vecType != null);
env.CheckDecode(value != null);
var labelGetter = Microsoft.ML.Internal.Utilities.Utils.MarshalInvoke(_decodeInitMethodInfo, vecType.ItemType.RawType, value);
var meta = new DataViewSchema.Annotations.Builder();
meta.Add(AnnotationUtils.Kinds.KeyValues, type, labelGetter);
var labelCol = new DataViewSchema.DetachedColumn(options.LabelColumnName, type, meta.ToAnnotations());
if (!ctx.TryLoadBinaryStream("TSModel", r => model.load(r)))
throw env.ExceptDecode();
return new ObjectDetectionTransformer(env, options, model, labelCol);
}
private static Delegate DecodeInit<T>(object value)
{
VBuffer<T> buffValue = (VBuffer<T>)value;
ValueGetter<VBuffer<T>> buffGetter = (ref VBuffer<T> dst) => buffValue.CopyTo(ref dst);
return buffGetter;
}
private class ObjDetMapper : MapperBase
{
private readonly ObjectDetectionTransformer _parent;
private readonly HashSet<int> _inputColIndices;
private static readonly FuncInstanceMethodInfo1<ObjDetMapper, DataViewSchema.DetachedColumn, Delegate> _makeLabelAnnotationGetter
= FuncInstanceMethodInfo1<ObjDetMapper, DataViewSchema.DetachedColumn, Delegate>.Create(target => target.GetLabelAnnotations<int>);
public ObjDetMapper(ObjectDetectionTransformer parent, DataViewSchema inputSchema) :
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(ObjDetMapper)), inputSchema, parent)
{
_parent = parent;
_inputColIndices = new HashSet<int>();
if (inputSchema.TryGetColumnIndex(parent.Options.ImageColumnName, out var col))
_inputColIndices.Add(col);
if (Host is IHostEnvironmentInternal hostInternal)
{
torch.random.manual_seed(hostInternal.Seed ?? 1);
torch.cuda.manual_seed(hostInternal.Seed ?? 1);
}
else
{
torch.random.manual_seed(1);
torch.cuda.manual_seed(1);
}
}
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var info = new DataViewSchema.DetachedColumn[3];
var keyType = _parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
var getter = Microsoft.ML.Internal.Utilities.Utils.MarshalInvoke(_makeLabelAnnotationGetter, this, keyType.ItemType.RawType, _parent.LabelColumn);
var meta = new DataViewSchema.Annotations.Builder();
meta.Add(AnnotationUtils.Kinds.ScoreColumnKind, TextDataViewType.Instance, (ref ReadOnlyMemory<char> value) => { value = AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification.AsMemory(); });
meta.Add(AnnotationUtils.Kinds.ScoreColumnSetId, AnnotationUtils.ScoreColumnSetIdType, GetScoreColumnSetId(InputSchema));
meta.Add(AnnotationUtils.Kinds.ScoreValueKind, TextDataViewType.Instance, (ref ReadOnlyMemory<char> value) => { value = AnnotationUtils.Const.ScoreValueKind.Score.AsMemory(); });
meta.Add(AnnotationUtils.Kinds.TrainingLabelValues, keyType, getter);
var labelBuilder = new DataViewSchema.Annotations.Builder();
labelBuilder.Add(AnnotationUtils.Kinds.KeyValues, keyType, getter);
info[0] = new DataViewSchema.DetachedColumn(_parent.Options.PredictedLabelColumnName, new VectorDataViewType(new KeyDataViewType(typeof(uint), _parent.Options.NumberOfClasses)), labelBuilder.ToAnnotations());
info[1] = new DataViewSchema.DetachedColumn(_parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single), meta.ToAnnotations());
info[2] = new DataViewSchema.DetachedColumn(_parent.Options.PredictedBoundingBoxColumnName, new VectorDataViewType(NumberDataViewType.Single));
return info;
}
private Delegate GetLabelAnnotations<T>(DataViewSchema.DetachedColumn labelCol)
{
return labelCol.Annotations.GetGetter<VBuffer<T>>(labelCol.Annotations.Schema[AnnotationUtils.Kinds.KeyValues]);
}
private ValueGetter<uint> GetScoreColumnSetId(DataViewSchema schema)
{
int c;
var max = schema.GetMaxAnnotationKind(out c, AnnotationUtils.Kinds.ScoreColumnSetId);
uint id = checked(max + 1);
return
(ref uint dst) => dst = id;
}
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
=> throw new NotImplementedException("This should never be called!");
private Delegate CreateGetter(DataViewRow input, int iinfo, TensorCacher outputCacher)
{
var ch = Host.Start("Make Getter");
if (iinfo == 0)
return MakePredictedLabelGetter(input, ch, outputCacher);
else if (iinfo == 1)
return MakeScoreGetter(input, ch, outputCacher);
else
return MakeBoundingBoxGetter(input, ch, outputCacher);
}
private Delegate MakeScoreGetter(DataViewRow input, IChannel ch, TensorCacher outputCacher)
{
ValueGetter<MLImage> getImage = default;
getImage = input.GetGetter<MLImage>(input.Schema[_parent.Options.ImageColumnName]);
MLImage image = default;
ValueGetter<VBuffer<float>> score = (ref VBuffer<float> dst) =>
{
using var disposeScope = torch.NewDisposeScope();
UpdateCacheIfNeeded(input.Position, outputCacher, ref image, ref getImage);
var editor = VBufferEditor.Create(ref dst, outputCacher.ScoresBuffer.Length);
for (var i = 0; i < outputCacher.ScoresBuffer.Length; i++)
{
editor.Values[i] = outputCacher.ScoresBuffer[i];
}
dst = editor.Commit();
};
return score;
}
private Delegate MakePredictedLabelGetter(DataViewRow input, IChannel ch, TensorCacher outputCacher)
{
ValueGetter<MLImage> getImage = default;
getImage = input.GetGetter<MLImage>(input.Schema[_parent.Options.ImageColumnName]);
MLImage image = default;
ValueGetter<VBuffer<UInt32>> predictedLabel = (ref VBuffer<UInt32> dst) =>
{
using var disposeScope = torch.NewDisposeScope();
UpdateCacheIfNeeded(input.Position, outputCacher, ref image, ref getImage);
var editor = VBufferEditor.Create(ref dst, outputCacher.PredictedLabelsBuffer.Length);
for (var i = 0; i < outputCacher.PredictedLabelsBuffer.Length; i++)
{
editor.Values[i] = outputCacher.PredictedLabelsBuffer[i];
}
dst = editor.Commit();
};
return predictedLabel;
}
private Delegate MakeBoundingBoxGetter(DataViewRow input, IChannel ch, TensorCacher outputCacher)
{
ValueGetter<MLImage> getImage = default;
getImage = input.GetGetter<MLImage>(input.Schema[_parent.Options.ImageColumnName]);
MLImage image = default;
ValueGetter<VBuffer<float>> score = (ref VBuffer<float> dst) =>
{
using var disposeScope = torch.NewDisposeScope();
UpdateCacheIfNeeded(input.Position, outputCacher, ref image, ref getImage);
var editor = VBufferEditor.Create(ref dst, outputCacher.BoxBuffer.Length);
for (var i = 0; i < outputCacher.BoxBuffer.Length; i++)
{
editor.Values[i] = outputCacher.BoxBuffer[i];
}
dst = editor.Commit();
};
return score;
}
public override Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
{
Host.AssertValue(input);
Contracts.Assert(input.Schema == base.InputSchema);
TensorCacher outputCacher = new TensorCacher(_parent.Options.NumberOfClasses);
var ch = Host.Start("Make Getters");
_parent.Model.eval();
int n = OutputColumns.Value.Length;
var result = new Delegate[n];
for (int i = 0; i < n; i++)
{
if (!activeOutput(i))
continue;
result[i] = CreateGetter(input, i, outputCacher);
}
disposer = () =>
{
outputCacher.Dispose();
};
return result;
}
private Tensor PrepInputTensors(ref MLImage image, ValueGetter<MLImage> imageGetter)
{
imageGetter(ref image);
using (var preprocessScope = torch.NewDisposeScope())
{
var midTensor0 = torch.tensor(image.GetBGRPixels, device: _parent.Device);
var midTensor1 = midTensor0.@float();
var midTensor2 = midTensor1.reshape(1, image.Height, image.Width, 3);
var midTensor3 = midTensor2.transpose(0, 3);
var midTensor4 = midTensor3.reshape(3, image.Height, image.Width);
var chunks = midTensor4.chunk(3, 0);
var part = new List<Tensor>();
part.Add(chunks[2]);
part.Add(chunks[1]);
part.Add(chunks[0]);
var midTensor = torch.cat(part, 0);
var reMidTensor = midTensor.reshape(1, 3, image.Height, image.Width);
var padW = 32 - (image.Width % 32);
var padH = 32 - (image.Height % 32);
var transMidTensor = torch.zeros(1, 3, image.Height + padH, image.Width + padW, device: _parent.Device);
transMidTensor[RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..), RangeUtil.ToTensorIndex(..image.Height), RangeUtil.ToTensorIndex(..image.Width)] = reMidTensor / 255.0;
var imageTensor = ObjectDetectionTrainer.Trainer.Normalize(transMidTensor, _parent.Device);
return imageTensor.MoveToOuterDisposeScope();
}
}
private (Tensor, Tensor, Tensor) PrepAndRunModel(Tensor inputTensor)
{
return _parent.Model.forward(inputTensor);
}
private protected class TensorCacher : IDisposable
{
public long Position;
public int MaxLength;
public UInt32[] PredictedLabelsBuffer;
public Single[] ScoresBuffer;
public Single[] BoxBuffer;
public TensorCacher(int maxLength)
{
Position = -1;
MaxLength = maxLength;
PredictedLabelsBuffer = default;
ScoresBuffer = default;
BoxBuffer = default;
}
private bool _isDisposed;
public void Dispose()
{
if (_isDisposed)
return;
_isDisposed = true;
}
}
private protected void UpdateCacheIfNeeded(long position, TensorCacher outputCache, ref MLImage image, ref ValueGetter<MLImage> getImage)
{
if (outputCache.Position != position)
{
var imageTensor = PrepInputTensors(ref image, getImage);
_parent.Model.eval();
(var pred, var score, var box) = PrepAndRunModel(imageTensor);
ImageUtils.Postprocess(imageTensor, pred, score, box, out outputCache.PredictedLabelsBuffer, out outputCache.ScoresBuffer, out outputCache.BoxBuffer, _parent.Options.ScoreThreshold, _parent.Options.IOUThreshold);
pred.Dispose();
score.Dispose();
box.Dispose();
outputCache.Position = position;
}
}
private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
return col => (activeOutput(0) || activeOutput(1) || activeOutput(2)) && _inputColIndices.Any(i => i == col);
}
}
protected virtual void Dispose(bool disposing)
{
if (!_disposedValue)
{
if (disposing)
{
}
Model.Dispose();
_disposedValue = true;
}
}
~ObjectDetectionTransformer()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: false);
}
public void Dispose()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}
}
|