// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Google.Protobuf;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.TensorFlow;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using Microsoft.ML.Vision;
using Tensorflow;
using Tensorflow.Summaries;
using static Microsoft.ML.Data.TextLoader;
using static Microsoft.ML.TensorFlow.TensorFlowUtils;
using static Tensorflow.Binding;
using Column = Microsoft.ML.Data.TextLoader.Column;
[assembly: LoadableClass(ImageClassificationTrainer.Summary, typeof(ImageClassificationTrainer),
typeof(ImageClassificationTrainer.Options),
new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) },
ImageClassificationTrainer.UserName,
ImageClassificationTrainer.LoadName,
ImageClassificationTrainer.ShortName)]
[assembly: LoadableClass(typeof(ImageClassificationModelParameters), null, typeof(SignatureLoadModel),
"Image classification predictor", ImageClassificationModelParameters.LoaderSignature)]
namespace Microsoft.ML.Vision
{
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> for training a Deep Neural Network(DNN) to classify images.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [ImageClassification](xref:Microsoft.ML.VisionCatalog.ImageClassification(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,System.String,Microsoft.ML.IDataView)).
///
/// ### Input and Output Columns
/// The input label column data must be [key](xref:Microsoft.ML.Data.KeyDataViewType) type and the feature column must be a variable-sized vector of <xref:System.Byte>.
///
/// This trainer outputs the following columns:
///
/// | Output Column Name | Column Type | Description|
/// | -- | -- | -- |
/// | `Score` | Vector of<xref:System.Single> | The scores of all classes.Higher value means higher probability to fall into the associated class. If the i-th element has the largest value, the predicted label index would be i.Note that i is zero-based index. |
/// | `PredictedLabel` | [key](xref:Microsoft.ML.Data.KeyDataViewType) type | The predicted label's index. If its value is i, the actual label would be the i-th category in the key-valued input label type. |
///
/// ### Trainer Characteristics
/// | | |
/// | -- | -- |
/// | Machine learning task | Multiclass classification |
/// | Is normalization required? | No |
/// | Is caching required? | No |
/// | Required NuGet in addition to Microsoft.ML | Microsoft.ML.Vision and SciSharp.TensorFlow.Redist / SciSharp.TensorFlow.Redist-Windows-GPU / SciSharp.TensorFlow.Redist-Linux-GPU |
/// | Exportable to ONNX | No |
///
/// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]
///
/// ### Training Algorithm Details
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained model such as Resnet50 for the purpose
/// of classifying images. The technique was inspired from [TensorFlow's retrain image classification tutorial](https://www.tensorflow.org/hub/tutorials/image_retraining)
/// ]]>
/// </format>
/// </remarks>
public sealed class ImageClassificationTrainer :
TrainerEstimatorBase<MulticlassPredictionTransformer<ImageClassificationModelParameters>,
ImageClassificationModelParameters>
{
// Component registration metadata consumed by the LoadableClass assembly attributes above.
internal const string LoadName = "ImageClassificationTrainer";
internal const string UserName = "Image Classification Trainer";
internal const string ShortName = "IMGCLSS";
internal const string Summary = "Trains a DNN model to classify images.";
/// <summary>
/// Pre-trained model architecture used as the base for transfer learning.
/// </summary>
public enum Architecture
{
    /// <summary>ResNet v2 with 101 layers (resnet_v2_101_299.meta).</summary>
    ResnetV2101,
    /// <summary>Inception v3 (inception_v3.meta).</summary>
    InceptionV3,
    /// <summary>MobileNet v2 (mobilenet_v2.meta).</summary>
    MobilenetV2,
    /// <summary>ResNet v2 with 50 layers (resnet_v2_50_299.meta). Default architecture.</summary>
    ResnetV250
};
/// <summary>
/// Dictionary mapping model architecture to the meta graph file name of the
/// corresponding pre-trained model.
/// </summary>
internal static IReadOnlyDictionary<Architecture, string> ModelFileName = new Dictionary<Architecture, string>
{
    { Architecture.ResnetV2101, @"resnet_v2_101_299.meta" },
    { Architecture.InceptionV3, @"inception_v3.meta" },
    { Architecture.MobilenetV2, @"mobilenet_v2.meta" },
    { Architecture.ResnetV250, @"resnet_v2_50_299.meta" }
};
/// <summary>
/// Dictionary mapping model architecture to the image input size supported.
/// Item1 is consumed as the height and Item2 as the width when the JPEG decode
/// graph is built (see InitializeTrainingGraph / AddJpegDecoding).
/// </summary>
internal static IReadOnlyDictionary<Architecture, Tuple<int, int>> ImagePreprocessingSize =
    new Dictionary<Architecture, Tuple<int, int>>
    {
        { Architecture.ResnetV2101, new Tuple<int, int>(299,299) },
        { Architecture.InceptionV3, new Tuple<int, int>(299,299) },
        { Architecture.MobilenetV2, new Tuple<int, int>(224,224) },
        { Architecture.ResnetV250, new Tuple<int, int>(299,299) }
    };
/// <summary>
/// Indicates the metric to be monitored to decide Early Stopping criteria.
/// </summary>
public enum EarlyStoppingMetric
{
    /// <summary>Monitor accuracy; improvement means the value increases (see EarlyStopping ctor).</summary>
    Accuracy,
    /// <summary>Monitor cross-entropy loss; improvement means the value decreases.</summary>
    Loss
}
/// <summary>
/// DNN training metrics reported per batch/epoch during the training phase.
/// </summary>
public sealed class TrainMetrics
{
    /// <summary>
    /// Indicates the dataset on which metrics are being reported.
    /// <see cref="ImageClassificationMetrics.Dataset"/>
    /// </summary>
    public ImageClassificationMetrics.Dataset DatasetUsed { get; set; }

    /// <summary>
    /// The number of batches processed in an epoch.
    /// </summary>
    public int BatchProcessedCount { get; set; }

    /// <summary>
    /// The training epoch index for which this metric is reported.
    /// </summary>
    public int Epoch { get; set; }

    /// <summary>
    /// Accuracy of the batch on this <see cref="Epoch"/>. Higher the better.
    /// </summary>
    public float Accuracy { get; set; }

    /// <summary>
    /// Cross-Entropy (loss) of the batch on this <see cref="Epoch"/>. Lower
    /// the better.
    /// </summary>
    public float CrossEntropy { get; set; }

    /// <summary>
    /// Learning Rate used for this <see cref="Epoch"/>. Changes for learning rate scheduling.
    /// </summary>
    public float LearningRate { get; set; }

    /// <summary>
    /// String representation of the metrics.
    /// </summary>
    public override string ToString()
    {
        // Both datasets share the same prefix; only the train dataset additionally
        // reports the learning rate.
        var common = $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, " +
            $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}, Cross-Entropy: {CrossEntropy,10}";

        return DatasetUsed == ImageClassificationMetrics.Dataset.Train
            ? common + $", Learning Rate: {LearningRate,10}"
            : common;
    }
}
/// <summary>
/// Metrics for image featurization values. The input image is passed through
/// the network and features are extracted from second or last layer to
/// train a custom full connected layer that serves as classifier.
/// </summary>
public sealed class BottleneckMetrics
{
    /// <summary>
    /// Indicates the dataset on which metrics are being reported.
    /// <see cref="ImageClassificationMetrics.Dataset"/>
    /// </summary>
    public ImageClassificationMetrics.Dataset DatasetUsed { get; set; }

    /// <summary>
    /// Index of the input image.
    /// </summary>
    public int Index { get; set; }

    /// <summary>
    /// String representation of the metrics.
    /// </summary>
    public override string ToString()
    {
        return $"Phase: Bottleneck Computation, Dataset used: {DatasetUsed.ToString(),10}, Image Index: {Index,3}";
    }
}
/// <summary>
/// Early Stopping feature stops training when the monitored quantity stops improving.
/// Modeled after https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/
/// tensorflow/python/keras/callbacks.py#L1143
/// </summary>
public sealed class EarlyStopping
{
    /// <summary>
    /// Best value of metric seen so far.
    /// </summary>
    private float _bestMetricValue;

    /// <summary>
    /// Current counter for number of epochs where there has been no improvement.
    /// </summary>
    private int _wait;

    /// <summary>
    /// The metric to be monitored (eg Accuracy, Loss).
    /// </summary>
    private readonly EarlyStoppingMetric _metric;

    /// <summary>
    /// Minimum change in the monitored quantity to be considered as an improvement.
    /// </summary>
    public float MinDelta { get; set; }

    /// <summary>
    /// Number of epochs to wait after no improvement is seen consecutively
    /// before stopping the training.
    /// </summary>
    public int Patience { get; set; }

    /// <summary>
    /// Whether the monitored quantity is to be increasing (eg. Accuracy, CheckIncreasing = true)
    /// or decreasing (eg. Loss, CheckIncreasing = false).
    /// </summary>
    public bool CheckIncreasing { get; set; }

    /// <param name="minDelta">Minimum absolute change in the metric that counts as an improvement.</param>
    /// <param name="patience">Consecutive non-improving epochs tolerated before stopping.</param>
    /// <param name="metric">The metric to monitor (Accuracy or Loss).</param>
    /// <param name="checkIncreasing">Requested monitoring direction. NOTE: this value is
    /// immediately overridden below based on <paramref name="metric"/>, so it currently has
    /// no observable effect; it is kept for interface compatibility.</param>
    public EarlyStopping(float minDelta = 0.01f, int patience = 20, EarlyStoppingMetric metric = EarlyStoppingMetric.Accuracy, bool checkIncreasing = true)
    {
        _bestMetricValue = 0.0f;
        _wait = 0;
        _metric = metric;
        MinDelta = Math.Abs(minDelta);
        Patience = patience;
        CheckIncreasing = checkIncreasing;

        // Set CheckIncreasing according to the metric being monitored:
        // accuracy should increase, loss should decrease (so its "best" starts at +infinity-ish).
        if (metric == EarlyStoppingMetric.Accuracy)
            CheckIncreasing = true;
        else if (metric == EarlyStoppingMetric.Loss)
        {
            CheckIncreasing = false;
            _bestMetricValue = Single.MaxValue;
        }
    }

    /// <summary>
    /// To be called at the end of every epoch to check if training should stop.
    /// For increasing metric(eg.: Accuracy), if metric stops increasing, stop training if
    /// value of metric doesn't increase within 'patience' number of epochs.
    /// For decreasing metric(eg.: Loss), stop training if value of metric doesn't decrease
    /// within 'patience' number of epochs.
    /// Any change in the value of metric of less than 'minDelta' is not considered a change.
    /// </summary>
    public bool ShouldStop(TrainMetrics currentMetrics)
    {
        float currentMetricValue = _metric == EarlyStoppingMetric.Accuracy ? currentMetrics.Accuracy : currentMetrics.CrossEntropy;

        // Signed improvement in the direction the metric is expected to move;
        // this collapses the previously duplicated increasing/decreasing branches.
        float improvement = CheckIncreasing
            ? currentMetricValue - _bestMetricValue
            : _bestMetricValue - currentMetricValue;

        if (improvement < MinDelta)
        {
            // No meaningful improvement this epoch; stop once patience is exhausted.
            _wait += 1;
            if (_wait >= Patience)
                return true;
        }
        else
        {
            _wait = 0;
            _bestMetricValue = currentMetricValue;
        }

        return false;
    }
}
/// <summary>
/// Metrics for image classification bottleneck phase and training.
/// Train metrics may be null when bottleneck phase is running, so have check!
/// </summary>
public sealed class ImageClassificationMetrics
{
    /// <summary>
    /// Indicates the kind of the dataset of which metric is reported.
    /// </summary>
    public enum Dataset
    {
        Train,
        Validation
    }

    /// <summary>
    /// Contains train time metrics.
    /// </summary>
    public TrainMetrics Train { get; set; }

    /// <summary>
    /// Contains pre-train time metrics. These contains metrics on image
    /// featurization.
    /// </summary>
    public BottleneckMetrics Bottleneck { get; set; }

    /// <summary>
    /// String representation of the metrics: the train metrics when available,
    /// otherwise the bottleneck metrics.
    /// </summary>
    public override string ToString() => Train?.ToString() ?? Bottleneck.ToString();
}
/// <summary>
/// Options class for <see cref="ImageClassificationTrainer"/>.
/// </summary>
public sealed class Options : TrainerInputBaseWithLabel
{
    /// <summary>
    /// Number of samples to use for mini-batch training. The default value for BatchSize is 10.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Number of samples to use for mini-batch training.", SortOrder = 9)]
    public int BatchSize = 10;

    /// <summary>
    /// Number of training iterations. The default value for Epoch is 200.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations.", SortOrder = 10)]
    public int Epoch = 200;

    /// <summary>
    /// Learning rate to use during optimization. The default value for Learning Rate is 0.01.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)]
    public float LearningRate = 0.01f;

    /// <summary>
    /// Early stopping technique parameters to be used to terminate training when training metric stops improving. By default EarlyStopping is turned on and the monitoring metric is Accuracy.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)]
    public EarlyStopping EarlyStoppingCriteria = new EarlyStopping();

    /// <summary>
    /// Specifies the model architecture to be used in the case of image classification training using transfer learning. The default Architecture is Resnet_v2_50.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Model architecture to be used in transfer learning for image classification.", SortOrder = 15)]
    public Architecture Arch = Architecture.ResnetV250;

    /// <summary>
    /// Name of the tensor that will contain the output scores of the last layer when transfer learning is done. The default tensor name is "Score".
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Softmax tensor of the last layer in transfer learning.", SortOrder = 15)]
    public string ScoreColumnName = "Score";

    /// <summary>
    /// Name of the tensor that will contain the predicted label from output scores of the last layer when transfer learning is done. The default tensor name is "PredictedLabel".
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Argmax tensor of the last layer in transfer learning.", SortOrder = 15)]
    public string PredictedLabelColumnName = "PredictedLabel";

    /// <summary>
    /// Final model and checkpoint files/folder prefix for storing graph files. The default prefix is "custom_retrained_model_based_on_".
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Final model and checkpoint files/folder prefix for storing graph files.", SortOrder = 15)]
    public string FinalModelPrefix = "custom_retrained_model_based_on_";

    /// <summary>
    /// Callback to report statistics on accuracy/cross entropy during training phase. Metrics Callback is set to null by default.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Callback to report metrics during training and validation phase.", SortOrder = 15)]
    public Action<ImageClassificationMetrics> MetricsCallback = null;

    /// <summary>
    /// Indicates the path where the image bottleneck cache files and trained model are saved, default is a new temporary directory.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory.", SortOrder = 15)]
    public string WorkspacePath = null;

    /// <summary>
    /// Indicates to evaluate the model on train set after every epoch. Test on trainset is set to true by default.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to evaluate the model on train set after every epoch.", SortOrder = 15)]
    public bool TestOnTrainSet = true;

    /// <summary>
    /// Indicates to not re-compute cached bottleneck trainset values if already available in the bin folder. This parameter is set to false by default.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute trained cached bottleneck values if already available in the bin folder.", SortOrder = 15)]
    public bool ReuseTrainSetBottleneckCachedValues = false;

    /// <summary>
    /// Indicates to not re-compute cached bottleneck validationset values if already available in the bin folder. This parameter is set to false by default.
    /// </summary>
    // Fixed garbled help text: was "validataionset cached bottleneck validationset values".
    [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute validationset cached bottleneck values if already available in the bin folder.", SortOrder = 15)]
    public bool ReuseValidationSetBottleneckCachedValues = false;

    /// <summary>
    /// Validation set.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Validation set.", SortOrder = 15)]
    public IDataView ValidationSet;

    /// <summary>
    /// When validation set is not passed then a fraction of train set is used as validation. To disable this
    /// behavior set <see cref="ValidationSetFraction"/> to null. Accepts value between 0 and 1.0, default
    /// value is 0.1 or 10% of the trainset.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Validation fraction.", SortOrder = 15)]
    public float? ValidationSetFraction = 0.1f;

    /// <summary>
    /// Indicates the file name within the workspace to store trainset bottleneck values for caching, default file name is "trainSetBottleneckFile.csv".
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file name to store trainset bottleneck values for caching.", SortOrder = 15)]
    public string TrainSetBottleneckCachedValuesFileName = "trainSetBottleneckFile.csv";

    /// <summary>
    /// Indicates the file name within the workspace to store validationset bottleneck values for caching, default file name is "validationSetBottleneckFile.csv".
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file name to store validationset bottleneck values for caching.", SortOrder = 15)]
    public string ValidationSetBottleneckCachedValuesFileName = "validationSetBottleneckFile.csv";

    /// <summary>
    /// A class that performs learning rate scheduling. The default learning rate scheduler is exponential learning rate decay.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)]
    public LearningRateScheduler LearningRateScheduler = new ExponentialLRDecay();
}
/// <summary> Return the type of prediction task.</summary>
private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;

// This trainer requires neither feature normalization nor caching.
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);

/// <summary>
/// Auxiliary information about the trainer in terms of its capabilities
/// and requirements.
/// </summary>
public override TrainerInfo Info => _info;

// Trainer options, fixed in the constructor.
private readonly Options _options;

// TensorFlow session hosting the training graph. Set in InitializeTrainingGraph;
// ownership is handed off to the model parameters at the end of TrainModelCore.
private Session _session;

// Graph nodes for the transfer-learning layer and evaluation step.
// NOTE(review): most of these are assigned by helpers (e.g. AddTransferLearningLayer,
// AddEvaluationStep) whose bodies are outside this chunk — verify roles there.
private Operation _trainStep;
private Tensor _bottleneckTensor;
private Tensor _learningRateInput;
private Tensor _softMaxTensor;
private Tensor _crossEntropy;
private Tensor _labelTensor;
private Tensor _evaluationStep;
private Tensor _prediction;
private Tensor _bottleneckInput;

// Endpoints of the JPEG decode/resize sub-graph created by AddJpegDecoding,
// plus their tensor names (cached because the names are serialized into the model).
private Tensor _jpegData;
private Tensor _resizedImage;
private string _jpegDataTensorName;
private string _resizedImageTensorName;

// Architecture-dependent tensor names chosen in the constructor.
private readonly string _inputTensorName;
private string _softmaxTensorName;

// Checkpoint path inside the workspace, derived from FinalModelPrefix + model file name.
private readonly string _checkpointPath;
// Name of the graph operation whose output is used as the bottleneck (feature) tensor.
private readonly string _bottleneckOperationName;

// True when a LearningRateScheduler was supplied in the options.
private readonly bool _useLRScheduling;
// True when the workspace was an auto-generated temp directory (eligible for cleanup).
private readonly bool _cleanupWorkspace;

// Number of label classes, derived from the label column's key count.
private int _classCount;
private Graph Graph => _session.graph;

// Directory under the host temp path ("MLNET") where model resources are cached.
private readonly string _resourcePath;
// File inside the workspace recording the training set size ("TrainingSetSize.txt").
private readonly string _sizeFile;
/// <summary>
/// Initializes a new instance of <see cref="ImageClassificationTrainer"/>
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="scoreColumn">The name of score column.</param>
/// <param name="predictedLabelColumn">The name of the predicted label column.</param>
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
internal ImageClassificationTrainer(IHostEnvironment env,
    string labelColumn = DefaultColumnNames.Label,
    string featureColumn = DefaultColumnNames.Features,
    string scoreColumn = DefaultColumnNames.Score,
    string predictedLabelColumn = DefaultColumnNames.PredictedLabel,
    IDataView validationSet = null)
    // Delegate to the Options-based constructor; all other settings keep their defaults.
    : this(env, new Options()
    {
        FeatureColumnName = featureColumn,
        LabelColumnName = labelColumn,
        ScoreColumnName = scoreColumn,
        PredictedLabelColumnName = predictedLabelColumn,
        ValidationSet = validationSet
    })
{
}
/// <summary>
/// Initializes a new instance of <see cref="ImageClassificationTrainer"/>
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="options">Trainer options. Empty workspace path and cache file names are
/// replaced by defaults; a default logging callback is installed when none is supplied.</param>
internal ImageClassificationTrainer(IHostEnvironment env, Options options)
    : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName),
        new SchemaShape.Column(options.FeatureColumnName, SchemaShape.Column.VectorKind.VariableVector,
            NumberDataViewType.Byte, false),
        TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName))
{
    Host.CheckValue(options, nameof(options));
    Host.CheckNonEmpty(options.FeatureColumnName, nameof(options.FeatureColumnName));
    Host.CheckNonEmpty(options.LabelColumnName, nameof(options.LabelColumnName));
    Host.CheckNonEmpty(options.ScoreColumnName, nameof(options.ScoreColumnName));
    Host.CheckNonEmpty(options.PredictedLabelColumnName, nameof(options.PredictedLabelColumnName));

    tf.compat.v1.disable_eager_execution();
    _resourcePath = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, "MLNET");

    if (string.IsNullOrEmpty(options.WorkspacePath))
    {
        options.WorkspacePath = GetTemporaryDirectory(env);
        _cleanupWorkspace = true;
    }

    if (!Directory.Exists(_resourcePath))
    {
        Directory.CreateDirectory(_resourcePath);
    }

    // If the user cleared either cache file name, restore the default value.
    // BUGFIX: the previous code read the defaults from '_options', but '_options' is not
    // assigned until later in this constructor, so a null/empty file name caused a
    // NullReferenceException. Take the defaults from a fresh Options instance instead.
    var defaults = new Options();
    if (string.IsNullOrEmpty(options.TrainSetBottleneckCachedValuesFileName))
    {
        options.TrainSetBottleneckCachedValuesFileName = defaults.TrainSetBottleneckCachedValuesFileName;
    }

    if (string.IsNullOrEmpty(options.ValidationSetBottleneckCachedValuesFileName))
    {
        options.ValidationSetBottleneckCachedValuesFileName = defaults.ValidationSetBottleneckCachedValuesFileName;
    }

    if (options.MetricsCallback == null)
    {
        // Default callback: trace every reported metric through the host channel.
        var logger = Host.Start(nameof(ImageClassificationTrainer));
        options.MetricsCallback = (ImageClassificationMetrics metric) => { logger.Trace(metric.ToString()); };
    }

    _options = options;
    _useLRScheduling = _options.LearningRateScheduler != null;
    _checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix +
        ModelFileName[_options.Arch]);
    _sizeFile = Path.Combine(_options.WorkspacePath, "TrainingSetSize.txt");

    // Configure bottleneck tensor based on the model.
    switch (_options.Arch)
    {
        case Architecture.ResnetV2101:
            _bottleneckOperationName = "resnet_v2_101/SpatialSqueeze";
            _inputTensorName = "input";
            break;
        case Architecture.InceptionV3:
            _bottleneckOperationName = "InceptionV3/Logits/SpatialSqueeze";
            _inputTensorName = "input";
            break;
        case Architecture.MobilenetV2:
            _bottleneckOperationName = "import/MobilenetV2/Logits/Squeeze";
            _inputTensorName = "import/input";
            break;
        case Architecture.ResnetV250:
            _bottleneckOperationName = "resnet_v2_50/SpatialSqueeze";
            _inputTensorName = "input";
            break;
    }
}
/// <summary>
/// Loads the pre-trained meta graph and augments it with the JPEG decode sub-graph,
/// the transfer-learning classification layer and the evaluation step.
/// Populates the tensor fields and <see cref="_classCount"/> from the label column.
/// </summary>
/// <param name="input">Training data; its label column must be of key type with 2+ classes.</param>
private void InitializeTrainingGraph(IDataView input)
{
    var labelColumn = input.Schema.GetColumnOrNull(_options.LabelColumnName).Value;
    var labelType = labelColumn.Type;
    var labelCount = labelType.GetKeyCount();
    if (labelCount <= 0)
    {
        // Removed redundant (string) casts: Name and ToString() are already strings.
        throw Host.ExceptSchemaMismatch(nameof(input.Schema), "label", labelColumn.Name, "Key",
            labelType.ToString());
    }

    var msg = $"Only one class found in the {_options.LabelColumnName} column. To build a multiclass classification model, the number of classes needs to be 2 or greater";
    Contracts.CheckParam(labelCount > 1, nameof(labelCount), msg);
    _classCount = (int)labelCount;

    var imageSize = ImagePreprocessingSize[_options.Arch];
    _session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session;
    _session.graph.as_default();
    (_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3);
    _jpegDataTensorName = _jpegData.name;
    _resizedImageTensorName = _resizedImage.name;

    // Add transfer learning layer.
    AddTransferLearningLayer(_options.LabelColumnName, _options.ScoreColumnName, _options.LearningRate,
        _useLRScheduling, _classCount);

    // Initialize the variables.
    new Runner(_session, operations: new IntPtr[] { tf.global_variables_initializer() }).Run();

    // Add evaluation layer.
    (_evaluationStep, _) = AddEvaluationStep(_softMaxTensor, _labelTensor);
    _softmaxTensorName = _softMaxTensor.name;
}
/// <summary>
/// Describes the output schema: a Single-vector score column and a UInt32 key
/// predicted-label column carrying the label key-values annotation.
/// </summary>
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    bool found = inputSchema.TryFindColumn(_options.LabelColumnName, out _);
    Contracts.Assert(found);

    // Annotation advertising the key-value (class name) mapping on the predicted label.
    var keyValuesAnnotation = new List<SchemaShape.Column>
    {
        new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
            TextDataViewType.Instance, false)
    };

    return new[]
    {
        new SchemaShape.Column(_options.ScoreColumnName, SchemaShape.Column.VectorKind.Vector,
            NumberDataViewType.Single, false),
        new SchemaShape.Column(_options.PredictedLabelColumnName, SchemaShape.Column.VectorKind.Scalar,
            NumberDataViewType.UInt32, true, new SchemaShape(keyValuesAnnotation.ToArray()))
    };
}
/// <summary>
/// Wraps the trained model parameters in a multiclass prediction transformer
/// wired to the configured feature, label, score and predicted-label columns.
/// </summary>
private protected override MulticlassPredictionTransformer<ImageClassificationModelParameters> MakeTransformer(
    ImageClassificationModelParameters model, DataViewSchema trainSchema)
{
    return new MulticlassPredictionTransformer<ImageClassificationModelParameters>(Host, model, trainSchema,
        FeatureColumn.Name, LabelColumn.Name, _options.ScoreColumnName, _options.PredictedLabelColumnName);
}
/// <summary>
/// Core training: builds the graph, featurizes (caches bottleneck values for) the train
/// and validation sets, trains the classification layer, and returns model parameters
/// that take ownership of the TF session.
/// </summary>
private protected override ImageClassificationModelParameters TrainModelCore(TrainContext trainContext)
{
    // Workspace directory is cleaned after training run. However, the pipeline can be re-used by calling
    // fit() again after transform(), in which case we must ensure workspace directory exists. This scenario
    // is typical in the case of cross-validation.
    if (!Directory.Exists(_options.WorkspacePath))
    {
        Directory.CreateDirectory(_options.WorkspacePath);
    }

    InitializeTrainingGraph(trainContext.TrainingSet.Data);
    CheckTrainingParameters(_options);

    // A validation set supplied via the train context takes precedence over the one in options.
    var validationSet = trainContext.ValidationSet?.Data ?? _options.ValidationSet;
    var imageProcessor = new ImageProcessor(_session, _jpegDataTensorName, _resizedImageTensorName);
    string trainSetBottleneckCachedValuesFilePath = Path.Combine(_options.WorkspacePath,
        _options.TrainSetBottleneckCachedValuesFileName);
    string validationSetBottleneckCachedValuesFilePath = Path.Combine(_options.WorkspacePath,
        _options.ValidationSetBottleneckCachedValuesFileName);

    // A validation set is only needed when early stopping or metric reporting is requested.
    bool needValidationSet = _options.EarlyStoppingCriteria != null || _options.MetricsCallback != null;
    // Validation bottlenecks may already be cached on disk from a prior run.
    bool validationSetPresent = _options.ReuseValidationSetBottleneckCachedValues &&
        File.Exists(validationSetBottleneckCachedValuesFilePath + "_features.bin") &&
        File.Exists(validationSetBottleneckCachedValuesFilePath + "_labels.bin");
    bool generateValidationSet = needValidationSet && !validationSetPresent;

    if (generateValidationSet && _options.ValidationSet != null)
    {
        // Featurize the explicitly-provided validation set and cache the bottleneck values.
        CacheFeaturizedImagesToDisk(validationSet, _options.LabelColumnName,
            _options.FeatureColumnName, imageProcessor, _inputTensorName, _bottleneckTensor.name,
            validationSetBottleneckCachedValuesFilePath,
            ImageClassificationMetrics.Dataset.Validation, _options.MetricsCallback);

        generateValidationSet = false;
        validationSetPresent = true;
    }

    if (!_options.ReuseTrainSetBottleneckCachedValues ||
        !(File.Exists(trainSetBottleneckCachedValuesFilePath + "_features.bin") &&
        File.Exists(trainSetBottleneckCachedValuesFilePath + "_labels.bin")))
    {
        // Featurize the train set; when a validation set is still needed, carve a
        // ValidationSetFraction-sized slice out of it during this pass.
        CacheFeaturizedImagesToDisk(trainContext.TrainingSet.Data, _options.LabelColumnName,
            _options.FeatureColumnName, imageProcessor,
            _inputTensorName, _bottleneckTensor.name, trainSetBottleneckCachedValuesFilePath,
            ImageClassificationMetrics.Dataset.Train, _options.MetricsCallback,
            generateValidationSet ? _options.ValidationSetFraction : null);

        validationSetPresent = validationSetPresent ||
            (generateValidationSet && _options.ValidationSetFraction.HasValue);

        generateValidationSet = needValidationSet && !validationSetPresent;
    }

    if (generateValidationSet && _options.ReuseTrainSetBottleneckCachedValues &&
        !_options.ReuseValidationSetBottleneckCachedValues)
    {
        // Not sure if it makes sense to support this scenario.
    }

    Contracts.Assert(!generateValidationSet, "Validation set needed but cannot generate.");
    TrainAndEvaluateClassificationLayer(trainSetBottleneckCachedValuesFilePath,
        validationSetPresent && (_options.EarlyStoppingCriteria != null || _options.MetricsCallback != null) ?
        validationSetBottleneckCachedValuesFilePath : null);

    // Leave the ownership of _session so that it is not disposed/closed when this object goes out of scope
    // since it will be used by ImageClassificationModelParameters class (new owner that will take care of
    // disposing).
    var session = _session;
    _session = null;
    return new ImageClassificationModelParameters(Host, session, _classCount, _jpegDataTensorName,
        _resizedImageTensorName, _inputTensorName, _softmaxTensorName);
}
/// <summary>
/// Validates options against the constructed graph: the label tensor's operation must
/// exist, and early stopping requires either a validation set or train-set testing.
/// </summary>
/// <exception cref="System.Exception">Thrown via <c>Host.ExceptParam</c> when validation fails.</exception>
private void CheckTrainingParameters(Options options)
{
    Host.CheckNonWhiteSpace(options.LabelColumnName, nameof(options.LabelColumnName));

    // Tensor names have the form "op:outputIndex"; strip the index before the op lookup.
    if (_session.graph.OperationByName(_labelTensor.name.Split(':')[0]) == null)
    {
        // BUGFIX: the concatenated message parts lacked a separating space ("does notexist").
        throw Host.ExceptParam(nameof(_labelTensor.name), $"'{_labelTensor.name}' does not " +
            $"exist in the model");
    }

    if (options.EarlyStoppingCriteria != null && options.ValidationSet == null &&
        options.TestOnTrainSet == false)
    {
        // BUGFIX: the concatenated message parts lacked a separating space ("unable tofind").
        throw Host.ExceptParam(nameof(options.EarlyStoppingCriteria), $"Early stopping enabled but unable to " +
            $"find a validation set and/or train set testing disabled. Please disable early stopping " +
            $"or either provide a validation set or enable train set training.");
    }
}
/// <summary>
/// Builds the JPEG preprocessing sub-graph: decode JPEG bytes, convert to float32
/// in [0,1], add a batch dimension and bilinearly resize to (height, width).
/// </summary>
/// <returns>Tuple of (JPEG-bytes input placeholder, resized image tensor).</returns>
private (Tensor, Tensor) AddJpegDecoding(int height, int width, int depth)
{
    // Placeholder fed with the raw JPEG byte string.
    var jpegData = tf.placeholder(tf.@string, name: "DecodeJPGInput");
    var decoded = tf.image.decode_jpeg(jpegData, channels: depth);
    // Convert from full range of uint8 to range [0,1] of float32.
    var asFloat = tf.image.convert_image_dtype(decoded, tf.float32);
    // Prepend a batch dimension of 1 so the resize op gets a 4-D tensor.
    var batched = tf.expand_dims(asFloat, 0);
    var targetSize = tf.cast(tf.stack(new int[] { height, width }), dtype: tf.int32);
    var resized = tf.image.resize_bilinear(batched, targetSize, false, name: "ResizeTensor");
    return (jpegData, resized);
}
/// <summary>
/// Encodes the raw bytes of an image buffer as a scalar TF_STRING tensor using the
/// TensorFlow 1.x C-API string layout: an 8-byte offset-table entry followed by the
/// TF_StringEncode'd (length-prefixed) payload — hence the "+ 8" in the allocation
/// and the WriteInt64(tensor, 0) storing the single element's offset.
/// </summary>
private static Tensor EncodeByteAsString(VBuffer<byte> buffer)
{
    int length = buffer.Length;
    var size = c_api.TF_StringEncodedSize((ulong)length);
    // Scalar (rank-0) string tensor: empty dims; room for the offset entry plus payload.
    var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, Array.Empty<long>(), 0, ((ulong)size + 8));
    IntPtr tensor = c_api.TF_TensorData(handle);
    // Offset of the first (only) string within the data region.
    Marshal.WriteInt64(tensor, 0);
    var status = new Status();
    unsafe
    {
        // Encode the string just past the 8-byte offset table.
        fixed (byte* src = buffer.GetValues())
            c_api.TF_StringEncode(src, (ulong)length, (byte*)(tensor + sizeof(Int64)), size, status.Handle);
    }
    status.Check(true);
    status.Dispose();
    // The returned Tensor takes ownership of the allocated handle.
    return new Tensor(handle);
}
/// <summary>
/// Helper that runs the JPEG decode/resize sub-graph to turn raw image bytes
/// into the resized float tensor consumed by the featurizer.
/// </summary>
internal sealed class ImageProcessor
{
    // Reusable runner bound to the JPEG-input and resized-image tensors of the session.
    private readonly Runner _imagePreprocessingRunner;

    public ImageProcessor(Session session, string jpegDataTensorName, string resizeImageTensorName)
    {
        _imagePreprocessingRunner = new Runner(session, new[] { jpegDataTensorName },
            new[] { resizeImageTensorName });
    }

    /// <summary>
    /// Decodes and resizes one image. Returns null when TensorFlow fails to decode
    /// the buffer (unknown or corrupt image format); callers skip such images.
    /// </summary>
    public Tensor ProcessImage(in VBuffer<byte> imageBuffer)
    {
        using (var imageTensor = EncodeByteAsString(imageBuffer))
        {
            try
            {
                return _imagePreprocessingRunner.AddInput(imageTensor, 0).Run()[0];
            }
            catch (TensorflowException e)
            {
                // Catch the exception for images of unknown format.
                // -2146233088 == 0x80131500 (COR_E_EXCEPTION) — presumably the generic
                // managed-exception HRESULT surfaced for decode failures; TODO confirm.
                if (e.HResult == -2146233088)
                    return null;
                else
                    throw;
            }
        }
    }
}
/// <summary>
/// Runs every image in <paramref name="input"/> through the pre-trained model's
/// bottleneck, shuffles the resulting (label, feature-vector) pairs, and writes
/// them to cache files on disk. When <paramref name="validationFraction"/> is
/// given, the first slice of the shuffled pairs is written to a separate
/// validation-set cache file.
/// </summary>
/// <param name="input">Data view with the label and image columns.</param>
/// <param name="labelColumnName">Name of the label column; must be of raw type uint.</param>
/// <param name="imageColumnName">Name of the column holding raw image bytes.</param>
/// <param name="imageProcessor">Decodes/resizes raw bytes into an input tensor.</param>
/// <param name="inputTensorName">Graph tensor fed with the processed image.</param>
/// <param name="outputTensorName">Graph tensor producing the bottleneck features.</param>
/// <param name="cacheFilePath">Base path for the cache files written for <paramref name="dataset"/>.</param>
/// <param name="dataset">Which dataset (train/validation) these bottlenecks belong to, for metrics.</param>
/// <param name="metricsCallback">Optional per-image progress callback.</param>
/// <param name="validationFraction">Optional fraction (0..1) of rows to divert to the validation cache.</param>
private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName, string imageColumnName,
    ImageProcessor imageProcessor, string inputTensorName, string outputTensorName, string cacheFilePath,
    ImageClassificationMetrics.Dataset dataset, Action<ImageClassificationMetrics> metricsCallback,
    float? validationFraction = null)
{
    var labelColumn = input.Schema[labelColumnName];
    // The trainer requires a key-typed (uint) label; anything else is a schema mismatch.
    if (labelColumn.Type.RawType != typeof(uint))
    {
        throw Host.ExceptSchemaMismatch(nameof(labelColumn), "Label",
            labelColumnName, typeof(uint).ToString(),
            labelColumn.Type.RawType.ToString());
    }
    var imageColumn = input.Schema[imageColumnName];
    Runner runner = new Runner(_session, new[] { inputTensorName }, new[] { outputTensorName });
    List<(long, float[])> featurizedImages = new List<(long, float[])>();
    // Only activate the two columns we actually read.
    using (var cursor = input.GetRowCursor(
        input.Schema.Where(c => c.Index == labelColumn.Index || c.Index == imageColumn.Index)))
    {
        var labelGetter = cursor.GetGetter<uint>(labelColumn);
        var imageGetter = cursor.GetGetter<VBuffer<byte>>(imageColumn);
        uint label = uint.MaxValue;
        VBuffer<byte> image = default;
        ImageClassificationMetrics metrics = new ImageClassificationMetrics();
        metrics.Bottleneck = new BottleneckMetrics();
        metrics.Bottleneck.DatasetUsed = dataset;
        while (cursor.MoveNext())
        {
            // Honor cancellation between images (featurization can be long-running).
            CheckAlive();
            labelGetter(ref label);
            imageGetter(ref image);
            if (image.Length <= 0)
                continue; //Empty Image
            // ProcessImage returns null for undecodable images; those rows are skipped.
            var imageTensor = imageProcessor.ProcessImage(image);
            if (imageTensor != null)
            {
                runner.AddInput(imageTensor, 0);
                var featurizedImage = runner.Run()[0];
                // label - 1: stores the key label as a zero-based index.
                featurizedImages.Add((label - 1, featurizedImage.ToArray<float>()));
                featurizedImage.Dispose();
                imageTensor.Dispose();
                // Bottleneck.Index ends up as the count of successfully featurized images.
                metrics.Bottleneck.Index++;
                metricsCallback?.Invoke(metrics);
            }
        }
        // Shuffle by sorting on random keys so the train/validation split below
        // and the batches formed later are not ordered by the input data.
        featurizedImages = featurizedImages.OrderBy(x => Host.Rand.Next(0, metrics.Bottleneck.Index)).ToList();
        int featureLength = featurizedImages.Count > 0 ? featurizedImages[0].Item2.Length : 0;
        int validationSetCount = 0;
        if (validationFraction.HasValue)
        {
            Contracts.Assert(validationFraction >= 0 && validationFraction <= 1);
            validationSetCount = (int)(metrics.Bottleneck.Index * validationFraction);
            // First slice of the shuffled pairs goes to the validation cache...
            CreateFeaturizedCacheFile(
                Path.Combine(_options.WorkspacePath, _options.ValidationSetBottleneckCachedValuesFileName),
                validationSetCount, featureLength, featurizedImages.Take(validationSetCount));
        }
        // ...and the remainder to the requested cache file.
        CreateFeaturizedCacheFile(cacheFilePath, metrics.Bottleneck.Index - validationSetCount, featureLength,
            featurizedImages.Skip(validationSetCount));
    }
}
/// <summary>
/// Writes featurized images to three cache files rooted at
/// <paramref name="cacheFilePath"/>: a text file with one "label,f0,f1,..."
/// line per image, a "_labels.bin" file of little-endian int64 labels, and a
/// "_features.bin" file with an [exampleCount, featureLength] int32 header
/// followed by the float32 feature vectors.
/// </summary>
/// <param name="cacheFilePath">Base path; binary files get "_features.bin"/"_labels.bin" suffixes.</param>
/// <param name="examples">Number of items in <paramref name="featurizedImages"/>.</param>
/// <param name="featureLength">Length of every feature vector.</param>
/// <param name="featurizedImages">(zero-based label, feature vector) pairs to persist.</param>
private void CreateFeaturizedCacheFile(string cacheFilePath, int examples, int featureLength,
    IEnumerable<(long, float[])> featurizedImages)
{
    Contracts.Assert(examples == featurizedImages.Count());
    Contracts.Assert(featurizedImages.All(x => x.Item2.Length == featureLength));
    using Stream featuresWriter = File.Open(cacheFilePath + "_features.bin", FileMode.Create);
    using Stream labelWriter = File.Open(cacheFilePath + "_labels.bin", FileMode.Create);
    using TextWriter writer = File.CreateText(cacheFilePath);
    // Header: example count, then per-example feature length.
    featuresWriter.Write(BitConverter.GetBytes(examples), 0, sizeof(int));
    featuresWriter.Write(BitConverter.GetBytes(featureLength), 0, sizeof(int));
    // Single-element staging array so the label can be viewed as raw bytes.
    long[] labels = new long[1];
    var labelsSpan = MemoryMarshal.Cast<long, byte>(labels);
    foreach (var row in featurizedImages)
    {
        // Honor cancellation while writing what can be a large cache.
        CheckAlive();
        writer.WriteLine(row.Item1 + "," + string.Join(",", row.Item2));
        labels[0] = row.Item1;
        // Bulk span writes replace the previous per-byte WriteByte loops;
        // identical bytes are produced with far fewer stream calls.
        labelWriter.Write(labelsSpan);
        // row.Item2.Length == featureLength (asserted above), so the full
        // reinterpreted span is exactly featureLength * sizeof(float) bytes.
        featuresWriter.Write(MemoryMarshal.Cast<float, byte>(row.Item2));
    }
}
/// <summary>
/// Trains the final classification layer from the cached bottleneck files and,
/// when a validation cache is supplied, evaluates after each epoch (feeding the
/// early-stopping criteria if configured). On completion the trained graph is
/// frozen to disk and the temporary workspace is cleaned up.
/// </summary>
/// <param name="trainBottleneckFilePath">Base path of the training bottleneck cache (required).</param>
/// <param name="validationSetBottleneckFilePath">Base path of the validation cache, or null for no validation.</param>
private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
    string validationSetBottleneckFilePath)
{
    Contracts.Assert(validationSetBottleneckFilePath == null ||
        (File.Exists(validationSetBottleneckFilePath + "_labels.bin") &&
        File.Exists(validationSetBottleneckFilePath + "_features.bin")));
    Contracts.Assert(trainBottleneckFilePath != null &&
        File.Exists(trainBottleneckFilePath + "_labels.bin") &&
        File.Exists(trainBottleneckFilePath + "_features.bin"));
    bool validationNeeded = validationSetBottleneckFilePath != null;
    // Early stopping is driven by validation metrics, so it requires a validation set.
    Contracts.Assert(_options.EarlyStoppingCriteria == null || validationNeeded);
    using (Stream trainSetLabelReader = File.Open(trainBottleneckFilePath + "_labels.bin", FileMode.Open))
    using (Stream trainSetFeatureReader = File.Open(trainBottleneckFilePath + "_features.bin", FileMode.Open))
    {
        Stream validationSetLabelReader = null;
        Stream validationSetFeatureReader = null;
        GCHandle featureBufferHandle = default;
        GCHandle labelBufferHandle = default;
        try
        {
            validationSetLabelReader = validationNeeded ?
                File.Open(validationSetBottleneckFilePath + "_labels.bin", FileMode.Open) : null;
            validationSetFeatureReader = validationNeeded ?
                File.Open(validationSetBottleneckFilePath + "_features.bin", FileMode.Open) : null;
            int batchSize = _options.BatchSize;
            int epochs = _options.Epoch;
            float learningRate = _options.LearningRate;
            Action<ImageClassificationMetrics> statisticsCallback = _options.MetricsCallback;
            Runner validationEvalRunner = null;
            List<string> runnerInputTensorNames = new List<string>();
            List<string> runnerOutputTensorNames = new List<string>();
            runnerInputTensorNames.Add(_bottleneckInput.name);
            runnerInputTensorNames.Add(_labelTensor.name);
            // The learning-rate placeholder is only fed when a scheduler is in use.
            if (_options.LearningRateScheduler != null)
                runnerInputTensorNames.Add(_learningRateInput.name);
            // Only fetch accuracy/cross-entropy on the train set when someone will consume them.
            if (statisticsCallback != null && _options.TestOnTrainSet)
            {
                runnerOutputTensorNames.Add(_evaluationStep.name);
                runnerOutputTensorNames.Add(_crossEntropy.name);
            }
            if (validationNeeded)
            {
                validationEvalRunner = new Runner(_session, new[] { _bottleneckInput.name, _labelTensor.name },
                    new[] { _evaluationStep.name, _crossEntropy.name });
            }
            Runner runner = new Runner(_session, runnerInputTensorNames.ToArray(),
                runnerOutputTensorNames.Count > 0 ? runnerOutputTensorNames.ToArray() : null,
                new[] { _trainStep.name });
            // Summary merge and FileWriter record the graph for TensorBoard; their
            // return values are not otherwise used here.
            Tensor merged = tf.summary.merge_all();
            FileWriter trainWriter = tf.summary.FileWriter(Path.Combine(_options.WorkspacePath, "train"),
                _session.graph);
            Saver trainSaver = tf.train.Saver();
            trainSaver.save(_session, _checkpointPath);
            ImageClassificationMetrics metrics = new ImageClassificationMetrics();
            metrics.Train = new TrainMetrics();
            float accuracy = 0;
            float crossentropy = 0;
            var labelTensorShape = _labelTensor.TensorShape.dims.Select(x => (long)x).ToArray();
            var featureTensorShape = _bottleneckInput.TensorShape.dims.Select(x => (long)x).ToArray();
            // Feature-file header: [exampleCount:int32][featureLength:int32], then float32 data.
            byte[] buffer = new byte[sizeof(int)];
            trainSetFeatureReader.ReadExactly(buffer, 0, 4);
            int trainingExamples = BitConverter.ToInt32(buffer, 0);
            trainSetFeatureReader.ReadExactly(buffer, 0, 4);
            int featureFileRecordSize = sizeof(float) * BitConverter.ToInt32(buffer, 0);
            const int featureFileStartOffset = sizeof(int) * 2;
            var labelBufferSizeInBytes = sizeof(long) * batchSize;
            var featureBufferSizeInBytes = featureFileRecordSize * batchSize;
            byte[] featuresBuffer = new byte[featureBufferSizeInBytes];
            byte[] labelBuffer = new byte[labelBufferSizeInBytes];
            // Pin the batch buffers so tensors can be built directly over them
            // without copying; freed in the finally block below.
            featureBufferHandle = GCHandle.Alloc(featuresBuffer, GCHandleType.Pinned);
            IntPtr featureBufferPtr = featureBufferHandle.AddrOfPinnedObject();
            labelBufferHandle = GCHandle.Alloc(labelBuffer, GCHandleType.Pinned);
            IntPtr labelBufferPtr = labelBufferHandle.AddrOfPinnedObject();
            DnnTrainState trainState = new DnnTrainState
            {
                BatchSize = _options.BatchSize,
                BatchesPerEpoch = trainingExamples / _options.BatchSize
            };
            for (int epoch = 0; epoch < epochs; epoch += 1)
            {
                CheckAlive();
                // Train.
                TrainAndEvaluateClassificationLayerCore(epoch, learningRate, featureFileStartOffset,
                    metrics, labelTensorShape, featureTensorShape, batchSize,
                    trainSetLabelReader, trainSetFeatureReader, labelBuffer, featuresBuffer,
                    labelBufferSizeInBytes, featureBufferSizeInBytes, featureFileRecordSize,
                    _options.LearningRateScheduler, trainState, runner, featureBufferPtr, labelBufferPtr,
                    (outputTensors, metrics) =>
                    {
                        if (_options.TestOnTrainSet && statisticsCallback != null)
                        {
                            // Accumulate per-batch accuracy/cross-entropy; averaged below.
                            outputTensors[0].ToScalar(ref accuracy);
                            outputTensors[1].ToScalar(ref crossentropy);
                            metrics.Train.Accuracy += accuracy;
                            metrics.Train.CrossEntropy += crossentropy;
                            outputTensors[0].Dispose();
                            outputTensors[1].Dispose();
                        }
                    });
                if (_options.TestOnTrainSet && statisticsCallback != null)
                {
                    metrics.Train.Epoch = epoch;
                    metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount;
                    metrics.Train.CrossEntropy /= metrics.Train.BatchProcessedCount;
                    metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Train;
                    statisticsCallback(metrics);
                }
                if (!validationNeeded)
                    continue;
                // Evaluate.
                TrainAndEvaluateClassificationLayerCore(epoch, learningRate, featureFileStartOffset,
                    metrics, labelTensorShape, featureTensorShape, batchSize,
                    validationSetLabelReader, validationSetFeatureReader, labelBuffer, featuresBuffer,
                    labelBufferSizeInBytes, featureBufferSizeInBytes, featureFileRecordSize, null,
                    trainState, validationEvalRunner, featureBufferPtr, labelBufferPtr,
                    (outputTensors, metrics) =>
                    {
                        outputTensors[0].ToScalar(ref accuracy);
                        outputTensors[1].ToScalar(ref crossentropy);
                        metrics.Train.Accuracy += accuracy;
                        metrics.Train.CrossEntropy += crossentropy;
                        outputTensors[0].Dispose();
                        outputTensors[1].Dispose();
                    });
                if (statisticsCallback != null)
                {
                    metrics.Train.Epoch = epoch;
                    metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount;
                    metrics.Train.CrossEntropy /= metrics.Train.BatchProcessedCount;
                    metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Validation;
                    statisticsCallback(metrics);
                }
                //Early stopping check
                if (_options.EarlyStoppingCriteria != null)
                {
                    if (_options.EarlyStoppingCriteria.ShouldStop(metrics.Train))
                        break;
                }
            }
            trainSaver.save(_session, _checkpointPath);
        }
        finally
        {
            // Release pinned buffers and validation readers even if training
            // throws or is cancelled mid-epoch; the previous code leaked these
            // (handles pinned forever, files left open) on the error path.
            if (featureBufferHandle.IsAllocated)
                featureBufferHandle.Free();
            if (labelBufferHandle.IsAllocated)
                labelBufferHandle.Free();
            validationSetLabelReader?.Dispose();
            validationSetFeatureReader?.Dispose();
        }
    }
    UpdateTransferLearningModelOnDisk(_classCount);
    TryCleanupTemporaryWorkspace();
}
/// <summary>
/// Runs one full pass (an epoch) over a bottleneck cache: reads batches of
/// labels and features from the binary readers into the pinned buffers, wraps
/// them in tensors without copying, and executes <paramref name="runner"/>
/// (train step or evaluation) per batch, letting
/// <paramref name="metricsAggregator"/> fold in any fetched outputs.
/// </summary>
/// <param name="epoch">Zero-based epoch index, recorded into <paramref name="trainState"/>.</param>
/// <param name="learningRate">Fixed learning rate; overridden per batch when a scheduler is given.</param>
/// <param name="featureFileStartOffset">Byte offset of the first feature record (past the header).</param>
/// <param name="metrics">Accumulator whose Train statistics are reset and updated by this pass.</param>
/// <param name="labelTensorShape">Label tensor shape; element 0 is set to the current batch size.</param>
/// <param name="featureTensorShape">Feature tensor shape; element 0 is set to the current batch size.</param>
/// <param name="batchSize">Maximum examples per batch.</param>
/// <param name="trainSetLabelReader">Reader over the "_labels.bin" file for this pass.</param>
/// <param name="trainSetFeatureReader">Reader over the "_features.bin" file for this pass.</param>
/// <param name="labelBufferBytes">Pinned managed buffer backing <paramref name="labelBufferPtr"/>.</param>
/// <param name="featuresBufferBytes">Pinned managed buffer backing <paramref name="featureBufferPtr"/>.</param>
/// <param name="labelBufferSizeInBytes">Capacity of the label buffer (batchSize * sizeof(long)).</param>
/// <param name="featureBufferSizeInBytes">Capacity of the feature buffer (batchSize * record size).</param>
/// <param name="featureFileRecordSize">Bytes per feature record (featureLength * sizeof(float)).</param>
/// <param name="learningRateScheduler">Optional scheduler; null uses the fixed rate.</param>
/// <param name="trainState">Batch/epoch bookkeeping consumed by the scheduler.</param>
/// <param name="runner">Runner to execute per batch.</param>
/// <param name="featureBufferPtr">Pinned address of <paramref name="featuresBufferBytes"/>.</param>
/// <param name="labelBufferPtr">Pinned address of <paramref name="labelBufferBytes"/>.</param>
/// <param name="metricsAggregator">Receives the runner's output tensors and <paramref name="metrics"/>.</param>
private void TrainAndEvaluateClassificationLayerCore(int epoch, float learningRate,
    int featureFileStartOffset, ImageClassificationMetrics metrics,
    long[] labelTensorShape, long[] featureTensorShape, int batchSize, Stream trainSetLabelReader,
    Stream trainSetFeatureReader, byte[] labelBufferBytes, byte[] featuresBufferBytes,
    int labelBufferSizeInBytes, int featureBufferSizeInBytes, int featureFileRecordSize,
    LearningRateScheduler learningRateScheduler, DnnTrainState trainState, Runner runner,
    IntPtr featureBufferPtr, IntPtr labelBufferPtr, Action<Tensor[], ImageClassificationMetrics> metricsAggregator)
{
    int labelFileBytesRead;
    int featuresFileBytesRead;
    // Start every pass assuming full batches; the leading dimension is shrunk
    // below for a final partial batch. (The original code set this twice.)
    labelTensorShape[0] = featureTensorShape[0] = batchSize;
    metrics.Train.Accuracy = 0;
    metrics.Train.CrossEntropy = 0;
    metrics.Train.BatchProcessedCount = 0;
    metrics.Train.LearningRate = learningRate;
    trainState.CurrentBatchIndex = 0;
    trainState.CurrentEpoch = epoch;
    // Rewind both files; the feature file is positioned past its int32 header.
    trainSetLabelReader.Seek(0, SeekOrigin.Begin);
    trainSetFeatureReader.Seek(featureFileStartOffset, SeekOrigin.Begin);
    while ((labelFileBytesRead = trainSetLabelReader.TryReadBlock(labelBufferBytes, 0, labelBufferSizeInBytes)) > 0 &&
        (featuresFileBytesRead = trainSetFeatureReader.TryReadBlock(featuresBufferBytes, 0, featureBufferSizeInBytes)) > 0)
    {
        Contracts.Assert(labelFileBytesRead <= labelBufferSizeInBytes);
        Contracts.Assert(featuresFileBytesRead <= featureBufferSizeInBytes);
        Contracts.Assert(labelFileBytesRead % sizeof(long) == 0);
        Contracts.Assert(featuresFileBytesRead % featureFileRecordSize == 0);
        Contracts.Assert(labelFileBytesRead / sizeof(long) == featuresFileBytesRead / featureFileRecordSize);
        // Partial final batch: shrink the leading (batch) dimension to match.
        if (labelFileBytesRead < labelBufferSizeInBytes)
        {
            featureTensorShape[0] = featuresFileBytesRead / featureFileRecordSize;
            labelTensorShape[0] = labelFileBytesRead / sizeof(long);
        }
        Contracts.Assert(featureTensorShape[0] <= featuresBufferBytes.Length / featureFileRecordSize);
        Contracts.Assert(labelTensorShape[0] <= labelBufferBytes.Length / sizeof(long));
        if (learningRateScheduler != null)
        {
            // Add learning rate as a placeholder only when learning rate scheduling is used.
            metrics.Train.LearningRate = learningRateScheduler.GetLearningRate(trainState);
            runner.AddInput(new Tensor(metrics.Train.LearningRate, TF_DataType.TF_FLOAT), 2);
        }
        // Tensors are constructed directly over the pinned buffers — no copy.
        var outputTensors = runner.AddInput(new Tensor(featureBufferPtr, featureTensorShape, TF_DataType.TF_FLOAT, featuresFileBytesRead), 0)
            .AddInput(new Tensor(labelBufferPtr, labelTensorShape, TF_DataType.TF_INT64, labelFileBytesRead), 1)
            .Run();
        metrics.Train.BatchProcessedCount += 1;
        metricsAggregator(outputTensors, metrics);
        trainState.CurrentBatchIndex += 1;
    }
}
/// <summary>
/// Checks whether the host has cancelled the operation; if so, deletes the
/// temporary workspace (best effort) before letting the cancellation propagate.
/// Called between images/epochs so long-running work remains cancellable.
/// </summary>
private void CheckAlive()
{
    try
    {
        Host.CheckAlive();
    }
    catch (OperationCanceledException)
    {
        // Clean up cached bottleneck files before surfacing the cancellation.
        TryCleanupTemporaryWorkspace();
        throw;
    }
}
/// <summary>
/// Best-effort removal of the temporary workspace directory, if this trainer
/// owns it and it still exists. Failures are deliberately swallowed.
/// </summary>
private void TryCleanupTemporaryWorkspace()
{
    // Nothing to do unless we own the workspace and it is still present.
    if (!_cleanupWorkspace || !Directory.Exists(_options.WorkspacePath))
        return;
    try
    {
        Directory.Delete(_options.WorkspacePath, true);
    }
    catch (Exception)
    {
        //We do not want to stop pipeline due to failed cleanup.
    }
}
/// <summary>
/// Builds a separate inference session: reloads the pre-trained meta graph,
/// re-attaches the (non-training) classification layer, restores the trained
/// weights from the checkpoint, and adds evaluation and JPEG-decoding ops.
/// Side effects: sets _jpegData and _resizedImage to the decoding ops of this graph.
/// </summary>
/// <param name="classCount">Number of classes in the final layer.</param>
/// <returns>(session, label tensor, evaluation-step tensor, prediction tensor).</returns>
private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(int classCount)
{
    var evalGraph = LoadMetaGraph(Path.Combine(_resourcePath, ModelFileName[_options.Arch]));
    var evalSess = tf.Session(graph: evalGraph);
    Tensor bottleneckTensor = evalGraph.OperationByName(_bottleneckOperationName);
    evalGraph.as_default();
    // isTraining: false — the retrain ops are added in inference form only.
    var (_, _, groundTruthInput, finalTensor) = AddFinalRetrainOps(classCount, _options.LabelColumnName,
        _options.ScoreColumnName, bottleneckTensor, false, _options.LearningRateScheduler != null, _options.LearningRate);
    // Restore the weights trained in the training session from the checkpoint.
    tf.train.Saver().restore(evalSess, _checkpointPath);
    var (evaluationStep, prediction) = AddEvaluationStep(finalTensor, groundTruthInput);
    var imageSize = ImagePreprocessingSize[_options.Arch];
    (_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3);
    return (evalSess, _labelTensor, evaluationStep, prediction);
}
/// <summary>
/// Adds accuracy-evaluation ops to the current graph: an argmax prediction
/// (stored in _prediction), an equality check against the ground-truth labels,
/// and their mean as the accuracy scalar (also recorded as a summary).
/// </summary>
/// <param name="resultTensor">Softmax/score output of the classification layer.</param>
/// <param name="groundTruthTensor">Tensor holding the true labels.</param>
/// <returns>(accuracy tensor, prediction tensor).</returns>
private (Tensor, Tensor) AddEvaluationStep(Tensor resultTensor, Tensor groundTruthTensor)
{
    Tensor accuracyTensor = null;
    Tensor isCorrect = null;
    tf_with(tf.name_scope("accuracy"), delegate
    {
        tf_with(tf.name_scope("correct_prediction"), delegate
        {
            _prediction = tf.argmax(resultTensor, 1);
            isCorrect = tf.equal(_prediction, groundTruthTensor);
        });
        tf_with(tf.name_scope("accuracy"), delegate
        {
            accuracyTensor = tf.reduce_mean(tf.cast(isCorrect, tf.float32));
        });
    });
    tf.summary.scalar("accuracy", accuracyTensor);
    return (accuracyTensor, _prediction);
}
/// <summary>
/// Freezes the trained graph (variables folded into constants), writes it next
/// to the checkpoint as a ".pb" file, and replaces the live training session
/// with a new session loaded from that frozen model.
/// </summary>
/// <param name="classCount">Number of classes for the rebuilt evaluation graph.</param>
private void UpdateTransferLearningModelOnDisk(int classCount)
{
    var (sess, _, _, _) = BuildEvaluationSession(classCount);
    var graph = sess.graph;
    // Keep only the ops reachable from the outputs callers need at scoring time.
    var outputGraphDef = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), new string[] { _softMaxTensor.name.Split(':')[0],
        _prediction.name.Split(':')[0], _jpegData.name.Split(':')[0], _resizedImage.name.Split(':')[0] });
    string frozenModelPath = _checkpointPath + ".pb";
    // Use the variable rather than rebuilding the same path expression.
    File.WriteAllBytes(frozenModelPath, outputGraphDef.ToByteArray());
    // Swap the training session for one backed by the frozen model.
    _session.graph.Dispose();
    _session.Dispose();
    _session = LoadTFSessionByModelFilePath(Host, frozenModelPath, false);
    sess.graph.Dispose();
    sess.Dispose();
}
/// <summary>
/// Attaches TensorBoard summaries (mean, stddev, max, min, histogram) to the
/// given variable under a "summaries" name scope.
/// </summary>
private void VariableSummaries(ResourceVariable var)
{
    tf_with(tf.name_scope("summaries"), delegate
    {
        var meanValue = tf.reduce_mean(var);
        tf.summary.scalar("mean", meanValue);
        Tensor stdDev = null;
        tf_with(tf.name_scope("stddev"), delegate
        {
            // Standard deviation: sqrt(E[(x - mean)^2]).
            stdDev = tf.sqrt(tf.reduce_mean(tf.square(var - meanValue)));
        });
        tf.summary.scalar("stddev", stdDev);
        tf.summary.scalar("max", tf.reduce_max(var));
        tf.summary.scalar("min", tf.reduce_min(var));
        tf.summary.histogram("histogram", var);
    });
}
/// <summary>
/// Adds the retrainable classification head on top of the pre-trained model's
/// bottleneck tensor: input placeholders, a single fully connected layer with
/// softmax, and — when <paramref name="isTraining"/> is true — the
/// cross-entropy loss and gradient-descent train step.
/// Side effects: sets _bottleneckInput (training only), _learningRateInput
/// (when scheduling), _labelTensor, _softMaxTensor and _trainStep fields.
/// </summary>
/// <param name="classCount">Number of output classes.</param>
/// <param name="labelColumn">Name given to the label placeholder.</param>
/// <param name="scoreColumnName">Name given to the softmax output.</param>
/// <param name="bottleneckTensor">Feature tensor from the pre-trained model.</param>
/// <param name="isTraining">Whether to also build loss and training ops.</param>
/// <param name="useLearningRateScheduler">Whether the learning rate is fed via a placeholder.</param>
/// <param name="learningRate">Fixed learning rate used when no scheduler is configured.</param>
/// <returns>(train step, cross-entropy, label tensor, softmax) — first two are null when not training.</returns>
private (Operation, Tensor, Tensor, Tensor) AddFinalRetrainOps(int classCount, string labelColumn,
    string scoreColumnName, Tensor bottleneckTensor, bool isTraining, bool useLearningRateScheduler,
    float learningRate)
{
    var bottleneckTensorDims = bottleneckTensor.TensorShape.dims;
    var (batch_size, bottleneck_tensor_size) = (bottleneckTensorDims[0], bottleneckTensorDims[1]);
    tf_with(tf.name_scope("input"), scope =>
    {
        if (isTraining)
        {
            // During training, bottlenecks are fed from the cache files through
            // this placeholder rather than computed by the base model.
            _bottleneckInput = tf.placeholder_with_default(
                bottleneckTensor,
                shape: bottleneckTensorDims,
                name: "BottleneckInputPlaceholder");
            if (useLearningRateScheduler)
                _learningRateInput = tf.placeholder(tf.float32, null, name: "learningRateInputPlaceholder");
        }
        _labelTensor = tf.placeholder(tf.int64, new TensorShape(batch_size), name: labelColumn);
    });
    string layerName = "final_retrain_ops";
    Tensor logits = null;
    tf_with(tf.name_scope(layerName), scope =>
    {
        // Single fully connected layer: logits = bottleneck * W + b.
        ResourceVariable layerWeights = null;
        tf_with(tf.name_scope("weights"), delegate
        {
            // Small-stddev truncated normal initialization for the weights.
            var initialValue = tf.truncated_normal(new int[] { bottleneck_tensor_size, classCount },
                stddev: 0.001f);
            layerWeights = tf.Variable(initialValue, name: "final_weights");
            VariableSummaries(layerWeights);
        });
        ResourceVariable layerBiases = null;
        tf_with(tf.name_scope("biases"), delegate
        {
            TensorShape shape = new TensorShape(classCount);
            layerBiases = tf.Variable(tf.zeros(shape), name: "final_biases");
            VariableSummaries(layerBiases);
        });
        tf_with(tf.name_scope("Wx_plus_b"), delegate
        {
            var matmul = tf.matmul(isTraining ? _bottleneckInput : bottleneckTensor, layerWeights);
            logits = matmul + layerBiases;
            tf.summary.histogram("pre_activations", logits);
        });
    });
    _softMaxTensor = tf.nn.softmax(logits, name: scoreColumnName);
    tf.summary.histogram("activations", _softMaxTensor);
    // Inference graphs stop here: no loss or optimizer ops are added.
    if (!isTraining)
        return (null, null, _labelTensor, _softMaxTensor);
    Tensor crossEntropyMean = null;
    tf_with(tf.name_scope("cross_entropy"), delegate
    {
        crossEntropyMean = tf.losses.sparse_softmax_cross_entropy(
            labels: _labelTensor, logits: logits);
    });
    tf.summary.scalar("cross_entropy", crossEntropyMean);
    tf_with(tf.name_scope("train"), delegate
    {
        // Learning rate comes from the placeholder when scheduling, otherwise fixed.
        var optimizer = useLearningRateScheduler ? tf.train.GradientDescentOptimizer(_learningRateInput) :
            tf.train.GradientDescentOptimizer(learningRate);
        _trainStep = optimizer.minimize(crossEntropyMean);
    });
    return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor);
}
/// <summary>
/// Locates the pre-trained model's bottleneck operation and attaches the
/// retrainable classification layer (in training form) on top of it, storing
/// the resulting ops in the trainer's fields.
/// </summary>
private void AddTransferLearningLayer(string labelColumn,
    string scoreColumnName, float learningRate, bool useLearningRateScheduling, int classCount)
{
    _bottleneckTensor = Graph.OperationByName(_bottleneckOperationName);
    var retrainOps = AddFinalRetrainOps(classCount, labelColumn, scoreColumnName, _bottleneckTensor, true,
        useLearningRateScheduling, learningRate);
    (_trainStep, _crossEntropy, _labelTensor, _softMaxTensor) = retrainOps;
}
/// <summary>
/// Downloads the architecture's pre-trained meta graph if not already cached
/// locally, then loads it into a TensorFlow session.
/// </summary>
private TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch)
{
    // Download timeout: ten minutes, in milliseconds.
    const int downloadTimeoutMs = 10 * 60 * 1000;
    string fileName = ModelFileName[arch];
    string fullPath = Path.Combine(_resourcePath, fileName);
    DownloadIfNeeded(env, @"meta\" + fileName, _resourcePath, fileName, downloadTimeoutMs);
    return new TensorFlowSessionWrapper(GetSession(env, fullPath, true), fullPath);
}
/// <summary>
/// Trains a <see cref="ImageClassificationTrainer"/> using both training and validation data,
/// returns a <see cref="ImageClassificationModelParameters"/>.
/// </summary>
/// <param name="trainData">The training data set.</param>
/// <param name="validationData">The validation data set, used for per-epoch metrics and early stopping.</param>
/// <returns>The transformer wrapping the trained image classification model.</returns>
public MulticlassPredictionTransformer<ImageClassificationModelParameters> Fit(
    IDataView trainData, IDataView validationData) => TrainTransformer(trainData, validationData);
}
/// <summary>
/// Image Classification predictor. This class encapsulates the trained Deep Neural Network(DNN) model
/// and is used to score images.
/// </summary>
public sealed class ImageClassificationModelParameters : ModelParametersBase<VBuffer<float>>, IValueMapper, IDisposable
{
    // Guards against double-dispose of the TensorFlow session/graph.
    private bool _isDisposed;
    internal const string LoaderSignature = "ImageClassificationPred";
    private static VersionInfo GetVersionInfo()
    {
        return new VersionInfo(
            modelSignature: "IMAGPRED",
            verWrittenCur: 0x00010001, // Initial
            verReadableCur: 0x00010001,
            verWeCanReadBack: 0x00010001,
            loaderSignature: LoaderSignature,
            loaderAssemblyName: typeof(ImageClassificationModelParameters).Assembly.FullName);
    }
    // Input: variable-length vector of raw image bytes; output: per-class scores.
    private readonly VectorDataViewType _inputType;
    private readonly VectorDataViewType _outputType;
    private readonly int _classCount;
    // Tensor names wiring the preprocessing subgraph and the scoring subgraph.
    private readonly string _imagePreprocessorTensorInput;
    private readonly string _imagePreprocessorTensorOutput;
    private readonly string _graphInputTensor;
    private readonly string _graphOutputTensor;
    private readonly Session _session;
    internal ImageClassificationModelParameters(IHostEnvironment env, Session session, int classCount,
        string imagePreprocessorTensorInput, string imagePreprocessorTensorOutput, string graphInputTensor,
        string graphOutputTensor) : base(env, LoaderSignature)
    {
        Host.AssertValue(session);
        Host.Assert(classCount > 1);
        Host.AssertNonEmpty(imagePreprocessorTensorInput);
        Host.AssertNonEmpty(imagePreprocessorTensorOutput);
        Host.AssertNonEmpty(graphInputTensor);
        Host.AssertNonEmpty(graphOutputTensor);
        _inputType = new VectorDataViewType(NumberDataViewType.Byte);
        _outputType = new VectorDataViewType(NumberDataViewType.Single, classCount);
        _classCount = classCount;
        _session = session;
        _imagePreprocessorTensorInput = imagePreprocessorTensorInput;
        _imagePreprocessorTensorOutput = imagePreprocessorTensorOutput;
        _graphInputTensor = graphInputTensor;
        _graphOutputTensor = graphOutputTensor;
    }
    /// <summary> Return the type of prediction task.</summary>
    private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;
    DataViewType IValueMapper.InputType => _inputType;
    DataViewType IValueMapper.OutputType => _outputType;
    // Deserialization constructor; must mirror SaveCore's binary layout exactly.
    private ImageClassificationModelParameters(IHostEnvironment env, ModelLoadContext ctx)
        : base(env, LoaderSignature, ctx)
    {
        // *** Binary format ***
        // int: _classCount
        // string: _imagePreprocessorTensorInput
        // string: _imagePreprocessorTensorOutput
        // string: _graphInputTensor
        // string: _graphOutputTensor
        // Graph.
        _classCount = ctx.Reader.ReadInt32();
        _imagePreprocessorTensorInput = ctx.Reader.ReadString();
        _imagePreprocessorTensorOutput = ctx.Reader.ReadString();
        _graphInputTensor = ctx.Reader.ReadString();
        _graphOutputTensor = ctx.Reader.ReadString();
        byte[] modelBytes = null;
        // The frozen TF graph is stored as a named binary sub-stream.
        if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
            throw env.ExceptDecode();
        _session = LoadTFSession(env, modelBytes);
        _inputType = new VectorDataViewType(NumberDataViewType.Byte);
        _outputType = new VectorDataViewType(NumberDataViewType.Single, _classCount);
    }
    // Factory entry point used by the model loader (registered via LoadableClass).
    internal static ImageClassificationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel(GetVersionInfo());
        return new ImageClassificationModelParameters(env, ctx);
    }
    private protected override void SaveCore(ModelSaveContext ctx)
    {
        base.SaveCore(ctx);
        ctx.SetVersionInfo(GetVersionInfo());
        // *** Binary format ***
        // int: _classCount
        // string: _imagePreprocessorTensorInput
        // string: _imagePreprocessorTensorOutput
        // string: _graphInputTensor
        // string: _graphOutputTensor
        // Graph.
        ctx.Writer.Write(_classCount);
        ctx.Writer.Write(_imagePreprocessorTensorInput);
        ctx.Writer.Write(_imagePreprocessorTensorOutput);
        ctx.Writer.Write(_graphInputTensor);
        ctx.Writer.Write(_graphOutputTensor);
        using (var status = new Status())
        using (var buffer = _session.graph.ToGraphDef(status))
        {
            // Serialize the session's graph definition into the "TFModel" sub-stream.
            ctx.SaveBinaryStream("TFModel", w =>
            {
                w.WriteByteArray(buffer.DangerousMemoryBlock.ToArray());
            });
            status.Check(true);
        }
    }
    // Per-mapper scoring helper pairing the preprocessing runner with the
    // classification runner of the owning model.
    private class Classifier
    {
        private readonly Runner _runner;
        private readonly ImageClassificationTrainer.ImageProcessor _imageProcessor;
        public Classifier(ImageClassificationModelParameters model)
        {
            _runner = new Runner(model._session, new[] { model._graphInputTensor }, new[] { model._graphOutputTensor });
            _imageProcessor = new ImageClassificationTrainer.ImageProcessor(model._session,
                model._imagePreprocessorTensorInput, model._imagePreprocessorTensorOutput);
        }
        public void Score(in VBuffer<byte> image, Span<float> classProbabilities)
        {
            // ProcessImage returns null for undecodable images; in that case
            // classProbabilities is left untouched.
            var processedTensor = _imageProcessor.ProcessImage(image);
            if (processedTensor != null)
            {
                var outputTensor = _runner.AddInput(processedTensor, 0).Run();
                outputTensor[0].CopyTo(classProbabilities);
                outputTensor[0].Dispose();
                processedTensor.Dispose();
            }
        }
    }
    ValueMapper<TSrc, TDst> IValueMapper.GetMapper<TSrc, TDst>()
    {
        Host.Check(typeof(TSrc) == typeof(VBuffer<byte>));
        Host.Check(typeof(TDst) == typeof(VBuffer<float>));
        _session.graph.as_default();
        Classifier classifier = new Classifier(this);
        ValueMapper<VBuffer<byte>, VBuffer<float>> del = (in VBuffer<byte> src, ref VBuffer<float> dst) =>
        {
            var editor = VBufferEditor.Create(ref dst, _classCount);
            classifier.Score(src, editor.Values);
            dst = editor.Commit();
        };
        return (ValueMapper<TSrc, TDst>)(Delegate)del;
    }
    public void Dispose()
    {
        if (_isDisposed)
            return;
        // NOTE(review): these IntPtr comparisons rely on the TF.NET implicit
        // handle conversions; if _session were null, `_session?.graph` would
        // compare non-equal to IntPtr.Zero and the Dispose call below would
        // throw — presumably _session is never null here; verify.
        if (_session?.graph != IntPtr.Zero)
        {
            _session.graph.Dispose();
        }
        if (_session != null && _session != IntPtr.Zero)
        {
            _session.close();
        }
        _isDisposed = true;
    }
}
}
|