// File: DnnRetrainTransform.cs
// Web Access
// Project: src\src\Microsoft.ML.Vision\Microsoft.ML.Vision.csproj (Microsoft.ML.Vision)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.TensorFlow;
using Microsoft.ML.Transforms;
using NumSharp;
using Tensorflow;
using static Microsoft.ML.TensorFlow.TensorFlowUtils;
using static Tensorflow.Binding;
using Utils = Microsoft.ML.Internal.Utilities.Utils;
 
// Registers DnnRetrainTransformer as an IDataTransform that is created from DnnRetrainEstimator.Options.
[assembly: LoadableClass(DnnRetrainTransformer.Summary, typeof(IDataTransform), typeof(DnnRetrainTransformer),
    typeof(DnnRetrainEstimator.Options), typeof(SignatureDataTransform), DnnRetrainTransformer.UserName, DnnRetrainTransformer.ShortName)]

// Registers DnnRetrainTransformer as an IDataTransform that is loaded under the SignatureLoadDataTransform signature.
[assembly: LoadableClass(DnnRetrainTransformer.Summary, typeof(IDataTransform), typeof(DnnRetrainTransformer), null, typeof(SignatureLoadDataTransform),
    DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)]

// Registers DnnRetrainTransformer itself as loadable under the SignatureLoadModel signature.
[assembly: LoadableClass(typeof(DnnRetrainTransformer), null, typeof(SignatureLoadModel),
    DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)]

// Registers DnnRetrainTransformer as an IRowMapper that is loaded under the SignatureLoadRowMapper signature.
[assembly: LoadableClass(typeof(IRowMapper), typeof(DnnRetrainTransformer), null, typeof(SignatureLoadRowMapper),
    DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// <see cref="ITransformer" /> for the <see cref="DnnRetrainEstimator"/>.
    /// </summary>
    internal sealed class DnnRetrainTransformer : RowToRowTransformerBase, IDisposable
    {
        // Disposal flag. NOTE(review): the Dispose implementation is outside this excerpt; presumably it sets this.
        private bool _isDisposed;

        // Host environment handed to the constructor; used when saving to probe the model location.
        private readonly IHostEnvironment _env;
        // Model location as resolved by the constructor (relative paths are combined with the current directory).
        private readonly string _modelLocation;
        // Constructor flag; the model-loading factory passes true when the model was unpacked into a temporary folder.
        private readonly bool _isTemporarySavedModel;
        // Constructor flag that is written to / read from the serialized model.
        private readonly bool _addBatchDimensionInput;
        // TensorFlow session wrapping the graph that is trained and queried.
        private readonly Session _session;
        // Per-output information produced by GetOutputInfo in the constructor.
        private readonly DataViewType[] _outputTypes;
        private readonly TF_DataType[] _tfOutputTypes;
        // Per-input information produced by GetInputInfo in the constructor.
        private readonly TF_DataType[] _tfInputTypes;
        private readonly TensorShape[] _tfInputShapes;
        // (operation, index) pairs resolved from the input/output names.
        private readonly (Operation, int)[] _tfInputOperations;
        private readonly (Operation, int)[] _tfOutputOperations;
        // TF_Output handles built from the (operation, index) pairs above.
        private readonly TF_Output[] _tfInputNodes;
        private readonly TF_Output[] _tfOutputNodes;
        // Convenience accessor for the session's graph.
        private Graph Graph => _session.graph;
        // Maps column names to TensorFlow node names; the constructor fills it with an identity mapping.
        private readonly Dictionary<string, string> _idvToTfMapping;
        // Input and output column names as passed to the constructor.
        private readonly string[] _inputs;
        private readonly string[] _outputs;

        // Strings used by the LoadableClass registrations and by the version info.
        internal const string Summary = "Re-Trains Dnn models.";
        internal const string UserName = "DnnRtTransform";
        internal const string ShortName = "DnnRtTransform";
        internal const string LoaderSignature = "DnnRtTransform";
 
        /// <summary>
        /// File and folder names used when locating and rewriting the pieces of a model directory on disk.
        /// </summary>
        internal static class DefaultModelFileNames
        {
            // Sub-folder of the model directory that holds the variable files.
            public const string VariablesFolder = "variables";
            // File name given to the ".index" file when it is moved into the variables folder.
            public const string Index = "variables.index";
            // File name given to the ".data-00000-of-00001" file when it is moved into the variables folder.
            public const string Data = "variables.data-00000-of-00001";
            // NOTE(review): not referenced in the visible part of the file; presumably the graph file of the model directory.
            public const string Graph = "saved_model.pb";
            // Name prefix used for the temporary files/folder produced by the save operation.
            public const string TmpMlnetModel = "mlnet_model";
        }
 
        /// <summary>
        /// Builds the version descriptor that is written with the model and checked when loading it.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            // The loader lives in the same assembly as this transformer.
            string loaderAssembly = typeof(DnnRetrainTransformer).Assembly.FullName;

            // An alternative "initial" written version of 0x00010001 was left commented out by the original author.
            var info = new VersionInfo(
                modelSignature: "DNNTRANS",
                verWrittenCur: 0x00000001,
                verReadableCur: 0x00000001,
                verWeCanReadBack: 0x00000001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: loaderAssembly);
            return info;
        }
 
        // Factory method for SignatureLoadModel.
        /// <summary>
        /// Recreates a <see cref="DnnRetrainTransformer"/> from a serialized model.
        /// A frozen model is loaded straight from the "TFModel" byte stream; otherwise the files stored in the
        /// "TFSavedModel" stream are unpacked into a fresh temporary folder and a session is opened on that folder.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="ctx">Load context positioned at this model.</param>
        /// <returns>The loaded transformer.</returns>
        /// <remarks>
        /// Throws a decode exception when the expected binary stream is missing, when a stored file is shorter than
        /// its recorded length, or when a stored file name would resolve outside of the temporary folder.
        /// The temporary folder is deleted again if loading fails.
        /// </remarks>
        private static DnnRetrainTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // byte: indicator for frozen models
            // byte: indicator for adding batch dimension in input
            // int: number of input columns
            // for each input column
            //   int: id of input column name
            // int: number of output columns
            // for each output column
            //   int: id of output column name
            // stream: tensorFlow model.

            GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput);

            if (isFrozen)
            {
                byte[] modelBytes = null;
                if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
                    throw env.ExceptDecode();

                return new DnnRetrainTransformer(env, TensorFlowUtils.LoadTFSession(env, modelBytes), outputs, inputs,
                    null, false, addBatchDimensionInput, 1);
            }

            var tempDirPath = Path.GetFullPath(Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid()));
            CreateFolderWithAclIfNotExists(env, tempDirPath);
            try
            {
                // Every unpacked file must resolve to a path below this prefix.
                string separator = Path.DirectorySeparatorChar.ToString();
                string tempDirPrefix = tempDirPath.EndsWith(separator, StringComparison.Ordinal) ? tempDirPath : tempDirPath + separator;

                bool loaded = ctx.TryLoadBinaryStream("TFSavedModel", br =>
                {
                    int count = br.ReadInt32();
                    for (int n = 0; n < count; n++)
                    {
                        string relativeFile = br.ReadString();
                        long fileLength = br.ReadInt64();

                        // The file name comes from the model stream, so normalize it and reject anything
                        // (rooted paths, ".." segments) that would escape the temporary folder.
                        string fullFilePath = Path.GetFullPath(Path.Combine(tempDirPath, relativeFile));
                        if (!fullFilePath.StartsWith(tempDirPrefix, StringComparison.Ordinal))
                            throw env.ExceptDecode();

                        string fullFileDir = Path.GetDirectoryName(fullFilePath);
                        if (fullFileDir != tempDirPath)
                        {
                            CreateFolderWithAclIfNotExists(env, fullFileDir);
                        }
                        using (var fs = new FileStream(fullFilePath, FileMode.Create, FileAccess.Write))
                        {
                            long actualRead = br.BaseStream.CopyRange(fs, fileLength);
                            // A short read means the stream is truncated; report it instead of only asserting.
                            env.CheckDecode(actualRead == fileLength);
                        }
                    }
                });

                // Without the stream there is nothing to open a session on.
                if (!loaded)
                    throw env.ExceptDecode();

                return new DnnRetrainTransformer(env, GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true,
                    addBatchDimensionInput, 1);
            }
            catch (Exception)
            {
                DeleteFolderWithRetries(env, tempDirPath);
                throw;
            }
        }
 
        // Factory method for SignatureDataTransform.
        /// <summary>
        /// Validates the arguments, builds (and thereby retrains) a transformer from <paramref name="options"/>
        /// and wraps it as a data transform over <paramref name="input"/>.
        /// </summary>
        internal static IDataTransform Create(IHostEnvironment env, DnnRetrainEstimator.Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(input, nameof(input));
            env.CheckValue(options.InputColumns, nameof(options.InputColumns));
            env.CheckValue(options.OutputColumns, nameof(options.OutputColumns));

            var transformer = new DnnRetrainTransformer(env, options, input);
            IDataTransform dataTransform = transformer.MakeDataTransform(input);
            return dataTransform;
        }
 
        /// <summary>
        /// Loads the model found at <c>options.ModelLocation</c> via <c>LoadDnnModel</c> and forwards to the
        /// constructor that takes the loaded model. Note that <paramref name="options"/> is dereferenced in the
        /// constructor initializer, before any argument checks run.
        /// </summary>
        internal DnnRetrainTransformer(IHostEnvironment env, DnnRetrainEstimator.Options options, IDataView input)
            : this(env, options, LoadDnnModel(env, options.ModelLocation), input)
        {
        }
 
        /// <summary>
        /// Builds the transformer on top of an already loaded TensorFlow model and immediately retrains it on
        /// <paramref name="input"/>.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="options">Training options; column names, model location and batch size are read from it.</param>
        /// <param name="tensorFlowModel">Wrapper whose session is used for training.</param>
        /// <param name="input">Training data.</param>
        /// <param name="validationSet">Optional data that is passed through to <c>TrainCore</c>.</param>
        internal DnnRetrainTransformer(IHostEnvironment env, DnnRetrainEstimator.Options options, ML.TensorFlow.TensorFlowSessionWrapper tensorFlowModel, IDataView input, IDataView validationSet = null)
            // The initializer dereferences options and tensorFlowModel before the checks in the body run.
            : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns,
                  options.ModelLocation, false, options.AddBatchDimensionInputs, options.BatchSize)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(input, nameof(input));
            // Make sure every operation named in the options exists in the graph.
            CheckTrainingParameters(options);

            // Retraining is refused unless the model location passes the IsSavedModel check.
            if (!IsSavedModel(env, options.ModelLocation))
                throw env.ExceptNotSupp("TensorFlowTransform: Re-Training of TensorFlow model is only supported for un-frozen model.");

            TrainCore(options, input, validationSet);
        }
 
        /// <summary>
        /// Verifies that the label column name is set and that every operation named in
        /// <paramref name="options"/> is non-blank and present in the session's graph.
        /// The optimization, label, save-location and save operations are mandatory; the loss, metric and
        /// learning-rate operations are only checked when they are not null.
        /// </summary>
        private void CheckTrainingParameters(DnnRetrainEstimator.Options options)
        {
            // Checks one operation name: it must be non-blank and resolvable in the graph.
            void RequireOperation(string operationName, string parameterName, string messagePrefix)
            {
                Host.CheckNonWhiteSpace(operationName, parameterName);
                if (_session.graph.OperationByName(operationName) == null)
                    throw Host.ExceptParam(parameterName, messagePrefix + "'" + operationName + "' does not exist in the model");
            }

            Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn));

            // Mandatory operations.
            RequireOperation(options.OptimizationOperation, nameof(options.OptimizationOperation), "Optimization operation ");
            RequireOperation(options.TensorFlowLabel, nameof(options.TensorFlowLabel), "");
            RequireOperation(options.SaveLocationOperation, nameof(options.SaveLocationOperation), "");
            RequireOperation(options.SaveOperation, nameof(options.SaveOperation), "");

            // Optional operations: validated only when supplied.
            if (options.LossOperation != null)
                RequireOperation(options.LossOperation, nameof(options.LossOperation), "");

            if (options.MetricOperation != null)
                RequireOperation(options.MetricOperation, nameof(options.MetricOperation), "");

            if (options.LearningRateOperation != null)
                RequireOperation(options.LearningRateOperation, nameof(options.LearningRateOperation), "");
        }
 
        /// <summary>
        /// Collects what is needed to feed one data view column into one TensorFlow node during training.
        /// </summary>
        /// <param name="inputSchema">Schema of the training data.</param>
        /// <param name="columnName">Name of the column to feed.</param>
        /// <param name="tfNodeName">Name of the TensorFlow node (optionally "name:index") that receives the column.</param>
        /// <param name="batchSize">Value substituted for a first dimension of 0 or -1.</param>
        /// <returns>
        /// The column index, whether the column is a vector, the TensorFlow element type of the node and the
        /// tensor shape to use for a batch.
        /// </returns>
        /// <remarks>
        /// Throws when the column or the node does not exist, or when the column's item type does not match the
        /// type expected by the node (key columns backed by UInt32 are treated as Int64).
        /// </remarks>
        private (int, bool, TF_DataType, TensorShape) GetTrainingInputInfo(DataViewSchema inputSchema, string columnName, string tfNodeName, int batchSize)
        {
            if (!inputSchema.TryGetColumnIndex(columnName, out int inputColIndex))
                throw Host.Except($"Column {columnName} doesn't exist");

            var type = inputSchema[inputColIndex].Type;
            var isInputVector = type is VectorDataViewType;

            (Operation inputTensor, int index) = GetOperationFromName(tfNodeName, _session);
            // Fail with a clear message instead of a null dereference when the node cannot be resolved.
            if (inputTensor == null)
                throw Host.ExceptParam(nameof(tfNodeName), $"'{tfNodeName}' does not exist in the model");

            var tfInputType = inputTensor.OpType == "Placeholder" ? inputTensor.OutputType(index) :
                inputTensor.InputType(index);
            var tfInputShape = ((Tensor)inputTensor).TensorShape;

            var numInputDims = tfInputShape != null ? tfInputShape.ndim : -1;
            if (isInputVector && (tfInputShape == null || (numInputDims == 0)))
            {
                // The node carries no usable shape: derive one from the column, with an unknown (-1) batch dimension.
                var vecType = (VectorDataViewType)type;
                var colTypeDims = new int[vecType.Dimensions.Length + 1];
                colTypeDims[0] = -1;
                for (int indexLocal = 0; indexLocal < vecType.Dimensions.Length; indexLocal += 1)
                    colTypeDims[indexLocal + 1] = vecType.Dimensions[indexLocal];

                tfInputShape = new TensorShape(colTypeDims);
            }
            // Only shapes with at least one dimension have a batch dimension to patch; a rank of 0 used to
            // index into an empty array here.
            if (numInputDims > 0)
            {
                var newShape = new int[numInputDims];
                var dims = tfInputShape.dims;
                newShape[0] = dims[0] == 0 || dims[0] == -1 ? batchSize : dims[0];

                for (int j = 1; j < numInputDims; j++)
                    newShape[j] = dims[j];
                tfInputShape = new TensorShape(newShape);
            }

            var expectedType = Tf2MlNetType(tfInputType);
            var actualType = type.GetItemType().RawType;
            // Key columns stored as UInt32 are compared as Int64.
            if (type is KeyDataViewType && actualType == typeof(UInt32))
                actualType = typeof(Int64);

            if (actualType != expectedType.RawType)
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", columnName, expectedType.ToString(), type.ToString());

            return (inputColIndex, isInputVector, tfInputType, tfInputShape);
        }
 
        /// <summary>
        /// Runs the training loop: for every epoch the input is scanned, rows are buffered per input column and
        /// the optimization operation is executed once per full batch. When all epochs are done the model on
        /// disk is refreshed from the session.
        /// </summary>
        /// <param name="options">Training options (operation names, batch size, epochs, learning rate, ...).</param>
        /// <param name="input">Training data.</param>
        /// <param name="validationSet">Currently not used by this method.</param>
        private void TrainCore(DnnRetrainEstimator.Options options, IDataView input, IDataView validationSet)
        {
            // One slot per model input plus a trailing slot for the label.
            var inputsForTraining = new string[_inputs.Length + 1];
            var inputColIndices = new int[inputsForTraining.Length];
            var isInputVector = new bool[inputsForTraining.Length];
            var tfInputTypes = new TF_DataType[inputsForTraining.Length];
            var tfInputShapes = new TensorShape[inputsForTraining.Length];

            for (int i = 0; i < _inputs.Length; i++)
                inputsForTraining[i] = _idvToTfMapping[_inputs[i]];

            var inputSchema = input.Schema;
            for (int i = 0; i < inputsForTraining.Length - 1; i++)
                (inputColIndices[i], isInputVector[i], tfInputTypes[i], tfInputShapes[i]) =
                    GetTrainingInputInfo(inputSchema, _inputs[i], inputsForTraining[i], options.BatchSize);

            var index = inputsForTraining.Length - 1;
            inputsForTraining[index] = options.TensorFlowLabel;

            (inputColIndices[index], isInputVector[index], tfInputTypes[index], tfInputShapes[index] ) =
                    GetTrainingInputInfo(inputSchema, options.LabelColumn, inputsForTraining[index], options.BatchSize);

            // Create graph inputs.
            // NOTE(review): tfInputs and ops below are not consumed by the Runner created further down.
            Operation labelOp;
            int labelOpIdx;
            (labelOp, labelOpIdx) = GetOperationFromName(options.TensorFlowLabel, _session);
            bool hasLearningRateOperation = !string.IsNullOrEmpty(options.LearningRateOperation);
            TF_Output[] tfInputs;
            if (hasLearningRateOperation)
                tfInputs = new TF_Output[_tfInputNodes.Length + 2]; //Inputs + Label + Learning Rate.
            else
                tfInputs = new TF_Output[_tfInputNodes.Length + 1]; //Inputs + Label.

            Array.Copy(_tfInputNodes, tfInputs, _tfInputNodes.Length);

            tfInputs[_tfInputNodes.Length] = new TF_Output(labelOp, labelOpIdx);
            // The learning-rate slot only exists when a learning-rate operation is configured; resolving it
            // unconditionally dereferenced a null name and wrote past the end of the array.
            if (hasLearningRateOperation)
            {
                var lr = GetOperationFromName(options.LearningRateOperation, _session);
                tfInputs[_tfInputNodes.Length + 1] = new TF_Output(lr.Item1, lr.Item2);
            }

            // Create graph operations.
            IntPtr[] ops = null;
            if (options.OptimizationOperation != null)
                ops = new[] { c_api.TF_GraphOperationByName(Graph, options.OptimizationOperation) };

            // Instantiate the graph.
            string[] outputs = null;
            if (options.LossOperation != null && options.MetricOperation != null)
                outputs = new[] { options.LossOperation, options.MetricOperation };
            else if (options.LossOperation != null)
                outputs = new[] { options.LossOperation };
            else if (options.MetricOperation != null)
                outputs = new[] { options.MetricOperation };

            // Runner input 0 is the learning rate; inputs 1..n are the model inputs followed by the label.
            // NOTE(review): the learning-rate operation is passed to the Runner even when it is not set - confirm
            // that the Runner tolerates that.
            Runner runner = new Runner(_session, new[] { options.LearningRateOperation }.Concat(inputsForTraining).ToArray(),
                outputs, new[] { options.OptimizationOperation }).AddInput(new Tensor(options.LearningRate), 0);

            var cols = input.Schema.Where(c => inputColIndices.Contains(c.Index));

            for (int epoch = 0; epoch < options.Epoch; epoch++)
            {
                using (var cursor = input.GetRowCursor(cols))
                {
                    var srcTensorGetters = GetTensorValueGetters(cursor, inputColIndices, isInputVector, tfInputTypes, tfInputShapes);
                    bool isDataLeft = false;
                    using (var ch = Host.Start("Training TensorFlow model..."))
                    using (var pch = Host.StartProgressChannel("TensorFlow training progress..."))
                    {
                        float loss = 0;
                        float metric = 0;
                        pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));

                        while (cursor.MoveNext())
                        {
                            for (int i = 0; i < inputsForTraining.Length; i++)
                            {
                                isDataLeft = true;
                                srcTensorGetters[i].BufferTrainingData();
                            }

                            // A batch is complete whenever the number of rows seen so far is a multiple of the batch size.
                            if (((cursor.Position + 1) % options.BatchSize) == 0)
                            {
                                isDataLeft = false;
                                var (l, m) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, runner);
                                loss += l;
                                metric += m;
                            }
                        }
                        // Rows buffered after the last full batch are not trained on.
                        if (isDataLeft)
                        {
                            isDataLeft = false;
                            ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
                        }
                        pch.Checkpoint(new double?[] { loss, metric });
                    }
                }
            }

            UpdateModelOnDisk(options.ModelLocation, options);
        }
 
        /// <summary>
        /// Feeds the buffered batch of every training input to <paramref name="runner"/>, executes it and reads
        /// the first two fetched tensors as scalar floats.
        /// </summary>
        /// <param name="inputs">Names of the training inputs; only the count is used here.</param>
        /// <param name="srcTensorGetters">Getters that hold the buffered batch for each input.</param>
        /// <param name="runner">Runner to execute; its input slots 1..n are overwritten.</param>
        /// <returns>
        /// The value of the first fetched tensor as <c>loss</c> and of the second as <c>metric</c>; 0 for a
        /// tensor that was not returned.
        /// </returns>
        private (float loss, float metric) ExecuteGraphAndRetrieveMetrics(
            string[] inputs,
            ITensorValueGetter[] srcTensorGetters,
            Runner runner)
        {
            // Reads the tensor at the given position as a float and releases it; yields 0 when it is absent.
            float ReadAndRelease(Tensor[] fetched, int position)
            {
                float value = 0.0f;
                if (fetched.Length > position && fetched[position] != IntPtr.Zero)
                {
                    fetched[position].ToScalar<float>(ref value);
                    fetched[position].Dispose();
                }
                return value;
            }

            // Slots are shifted by one relative to the getter positions.
            for (int slot = 1; slot <= inputs.Length; slot++)
                runner.AddInput(srcTensorGetters[slot - 1].GetBufferedBatchTensor(), slot);

            Tensor[] results = runner.Run();
            float lossValue = ReadAndRelease(results, 0);
            float metricValue = ReadAndRelease(results, 1);

            return (lossValue, metricValue);
        }
 
        /// <summary>
        /// Brings the model files on disk in line with the retrained session.
        /// After retraining, the session and graph are up to date but the files in the model directory, which
        /// are what gets serialized into the ML.NET stream, are not.
        /// </summary>
        /// <param name="modelDir">Directory of the model.</param>
        /// <param name="options">Options naming the save-location and save operations of the graph.</param>
        private void UpdateModelOnDisk(string modelDir, DnnRetrainEstimator.Options options)
        {
            // Moves a freshly saved file to the given destination, replacing whatever is there.
            void Replace(string source, string destination)
            {
                if (File.Exists(destination))
                    File.Delete(destination);
                Directory.Move(source, destination);
            }

            try
            {
                // Ask the graph to save its parameters under the temporary prefix.
                string savePrefix = Path.Combine(modelDir, DefaultModelFileNames.TmpMlnetModel);
                var saveRunner = new Runner(_session, new[] { options.SaveLocationOperation },
                    null, new[] { options.SaveOperation }).AddInput(new Tensor(savePrefix), 0);

                saveRunner.Run();

                // Keep a copy of the current variable files in a sibling folder with a unique suffix.
                string variablesDir = Path.Combine(modelDir, DefaultModelFileNames.VariablesFolder);
                string backupDir = Path.Combine(variablesDir + "-" + Guid.NewGuid().ToString());
                Directory.CreateDirectory(backupDir);
                foreach (var existing in Directory.GetFiles(variablesDir))
                    File.Copy(existing, Path.Combine(backupDir, Path.GetFileName(existing)));

                // The parameters end up in one of two places, depending on whether
                // `saver_def = tf.train.Saver().as_saver_def()` was called in Python before `tf.saved_model.simple_save`:
                // if it was, the files are written directly into the model directory;
                // if not, a temporary directory whose name starts with `mlnet_model` is created and holds the files.
                string tmpPattern = DefaultModelFileNames.TmpMlnetModel + "*";
                string[] tmpDirs = Directory.GetDirectories(modelDir, tmpPattern);
                bool savedIntoTmpDir = tmpDirs != null && tmpDirs.Length > 0;

                string[] savedFiles = savedIntoTmpDir
                    ? Directory.GetFiles(tmpDirs[0])
                    : Directory.GetFiles(modelDir, tmpPattern);

                foreach (var saved in savedFiles)
                {
                    if (saved.EndsWith(".data-00000-of-00001"))
                        Replace(saved, Path.Combine(variablesDir, DefaultModelFileNames.Data));

                    if (saved.EndsWith(".index"))
                        Replace(saved, Path.Combine(variablesDir, DefaultModelFileNames.Index));
                }

                if (savedIntoTmpDir)
                    DeleteFolderWithRetries(Host, tmpDirs[0]);
            }
            catch (Exception e)
            {
                throw Host.ExceptIO(e, "Error serializing TensorFlow retrained model to disk.");
            }
        }
 
        /// <summary>
        /// Creates the getter for one column: a vector getter when <paramref name="isVector"/> is set,
        /// otherwise a scalar getter (which is the only one that receives <paramref name="keyType"/>).
        /// </summary>
        private static ITensorValueGetter CreateTensorValueGetter<T>(DataViewRow input, bool isVector, int colIndex, TensorShape tfShape, bool keyType = false)
        {
            if (!isVector)
                return new TensorValueGetter<T>(input, colIndex, tfShape, keyType);

            return new TensorValueGetterVec<T>(input, colIndex, tfShape);
        }
 
        /// <summary>
        /// Picks the element type for the getter from the TensorFlow type and dispatches to the generic
        /// overload. A key-typed column that feeds an Int64 tensor is read as UInt32 with the key flag set.
        /// </summary>
        private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, TF_DataType tfType, bool isVector, int colIndex, TensorShape tfShape)
        {
            Type rawType = Tf2MlNetType(tfType).RawType;
            bool isKeyFeedingInt64 = input.Schema[colIndex].Type is KeyDataViewType && rawType == typeof(Int64);
            Type getterItemType = isKeyFeedingInt64 ? typeof(UInt32) : rawType;

            return Utils.MarshalInvoke(CreateTensorValueGetter<int>, getterItemType, input, isVector, colIndex, tfShape, isKeyFeedingInt64);
        }
 
        /// <summary>
        /// Creates one tensor value getter per entry of <paramref name="inputColIndices"/>, using the matching
        /// entries of the other arrays.
        /// </summary>
        private static ITensorValueGetter[] GetTensorValueGetters(
            DataViewRow input,
            int[] inputColIndices,
            bool[] isInputVector,
            TF_DataType[] tfInputTypes,
            TensorShape[] tfInputShapes)
        {
            return Enumerable.Range(0, inputColIndices.Length)
                .Select(slot => CreateTensorValueGetter(input, tfInputTypes[slot], isInputVector[slot], inputColIndices[slot], tfInputShapes[slot]))
                .ToArray();
        }
 
        // Factory method for SignatureLoadDataTransform.
        // Loads the transformer from the context and wraps it as a data transform over the given input.
        private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            DnnRetrainTransformer loaded = Create(env, ctx);
            return loaded.MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadRowMapper.
        // Loads the transformer from the context and creates its row mapper for the given schema.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
        {
            DnnRetrainTransformer loaded = Create(env, ctx);
            return loaded.MakeRowMapper(inputSchema);
        }
 
        /// <summary>
        /// Reads the header of the serialized transformer: the frozen flag, the batch-dimension flag, and the
        /// input and output column names (each list is a positive count followed by that many non-empty strings).
        /// </summary>
        private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs,
            out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput)
        {
            // Reads one counted list of column names from the context.
            string[] ReadNameList()
            {
                int count = ctx.Reader.ReadInt32();
                env.CheckDecode(count > 0);

                var names = new string[count];
                for (int k = 0; k < names.Length; k++)
                    names[k] = ctx.LoadNonEmptyString();
                return names;
            }

            isFrozen = ctx.Reader.ReadBoolByte();
            addBatchDimensionInput = ctx.Reader.ReadBoolByte();

            inputs = ReadNameList();
            outputs = ReadNameList();
        }
 
        /// <summary>
        /// Core constructor: stores the session and column names and resolves the TensorFlow types, shapes and
        /// operations of every input and output.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="session">TensorFlow session to use; must not be null.</param>
        /// <param name="outputColumnNames">Names of the outputs; must not be empty.</param>
        /// <param name="inputColumnNames">Names of the inputs; must not be empty.</param>
        /// <param name="modelLocation">
        /// Location of the model on disk, or null when there is none (the frozen-model loader passes null).
        /// A relative path is resolved against the current directory.
        /// </param>
        /// <param name="isTemporarySavedModel">Whether the model location is a temporary folder.</param>
        /// <param name="addBatchDimensionInput">Flag stored and later serialized with the model.</param>
        /// <param name="batchSize">Batch size forwarded to <c>GetInputInfo</c>.</param>
        internal DnnRetrainTransformer(IHostEnvironment env, Session session, string[] outputColumnNames,
            string[] inputColumnNames, string modelLocation, bool isTemporarySavedModel,
            bool addBatchDimensionInput, int batchSize)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnRetrainTransformer)))

        {
            Host.CheckValue(session, nameof(session));
            Host.CheckNonEmpty(inputColumnNames, nameof(inputColumnNames));
            Host.CheckNonEmpty(outputColumnNames, nameof(outputColumnNames));

            _env = env;
            _session = session;
            // Path.Combine throws on a null segment, so a missing location is kept as null instead of being resolved.
            _modelLocation = modelLocation == null || Path.IsPathRooted(modelLocation)
                ? modelLocation
                : Path.Combine(Directory.GetCurrentDirectory(), modelLocation);
            _isTemporarySavedModel = isTemporarySavedModel;
            _addBatchDimensionInput = addBatchDimensionInput;
            _inputs = inputColumnNames;
            _outputs = outputColumnNames;
            _idvToTfMapping = new Dictionary<string, string>();

            // Column names and TensorFlow node names are the same here.
            foreach (var x in _inputs)
                _idvToTfMapping[x] = x;

            foreach (var x in _outputs)
                _idvToTfMapping[x] = x;

            (_tfOutputTypes, _outputTypes, _tfOutputOperations) = GetOutputInfo(Host, _session, _outputs);

            (_tfInputTypes, _tfInputShapes, _tfInputOperations) = GetInputInfo(Host, _session, _inputs.Select(x => _idvToTfMapping[x]).ToArray(), batchSize);

            _tfInputNodes = new TF_Output[_inputs.Length];
            _tfOutputNodes = new TF_Output[_outputs.Length];

            for (int index = 0; index < _tfInputOperations.Length; index += 1)
                _tfInputNodes[index] = new TF_Output(_tfInputOperations[index].Item1, _tfInputOperations[index].Item2);

            for (int index = 0; index < _tfOutputOperations.Length; index += 1)
                _tfOutputNodes[index] = new TF_Output(_tfOutputOperations[index].Item1, _tfOutputOperations[index].Item2);
        }
 
        /// <summary>
        /// Resolves a node name of the form "name" or "name:index" to the graph operation and the index.
        /// When there is no colon, the colon is the last character, or the text after it is not an integer,
        /// the whole string is looked up as the operation name and the index is 0.
        /// </summary>
        private static (Operation, int) GetOperationFromName(string operation, Session session)
        {
            int colonAt = operation.IndexOf(':');
            bool hasSuffix = colonAt != -1 && colonAt != operation.Length - 1;

            if (hasSuffix && int.TryParse(operation.Substring(colonAt + 1), out var outputIndex))
            {
                string operationName = operation.Substring(0, colonAt);
                return (session.graph.OperationByName(operationName), outputIndex);
            }

            return (session.graph.OperationByName(operation), 0);
        }
 
        /// <summary>
        /// Resolves every input name to its graph operation and collects the TensorFlow element type and the
        /// tensor shape of each.
        /// </summary>
        /// <param name="host">Host used for argument checks and exceptions.</param>
        /// <param name="session">Session whose graph is searched.</param>
        /// <param name="inputs">Input node names, optionally in "name:index" form.</param>
        /// <param name="batchSize">Currently not used by this method.</param>
        /// <returns>Parallel arrays of element types, shapes and (operation, index) pairs, one entry per input.</returns>
        /// <remarks>
        /// Throws when a name is blank, when the operation does not exist in the graph, or when its type is not
        /// supported.
        /// </remarks>
        internal static (TF_DataType[] tfInputTypes, TensorShape[] tfInputShapes, (Operation, int)[]) GetInputInfo(IHost host, Session session, string[] inputs, int batchSize = 1)
        {
            var tfInputTypes = new TF_DataType[inputs.Length];
            var tfInputShapes = new TensorShape[inputs.Length];
            var tfInputOperations = new (Operation, int)[inputs.Length];

            for (int index = 0; index < inputs.Length; index++)
            {
                string input = inputs[index];
                host.CheckNonWhiteSpace(input, nameof(inputs));
                (Operation inputTensor, int inputTensorIndex) = GetOperationFromName(input, session);

                if (inputTensor == null)
                    throw host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model");

                // The type is queried with the index parsed from the node name, not with the position of the
                // column in the inputs array (the two are unrelated).
                bool isPlaceholder = string.Equals(inputTensor.OpType, "PlaceHolder", StringComparison.OrdinalIgnoreCase);
                TF_DataType tfInputType = isPlaceholder ? inputTensor.OutputType(inputTensorIndex) : inputTensor.InputType(inputTensorIndex);
                if (!IsTypeSupported(tfInputType))
                    throw host.ExceptParam(nameof(session), $"Input type '{tfInputType}' of input column '{input}' is not supported in TensorFlow");

                tfInputTypes[index] = tfInputType;
                tfInputShapes[index] = ((Tensor)inputTensor).TensorShape;
                tfInputOperations[index] = (inputTensor, inputTensorIndex);
            }

            return (tfInputTypes, tfInputShapes, tfInputOperations);
        }
 
        /// <summary>
        /// Queries the graph for the shape of <paramref name="output"/>.
        /// </summary>
        /// <param name="output">Output whose shape is requested.</param>
        /// <param name="graph">Graph that owns the output.</param>
        /// <param name="status">
        /// Optional status object to report into. When omitted, a status is created for the duration of the
        /// call and disposed before returning.
        /// </param>
        /// <returns>The shape of the output; an empty shape when the number of dimensions is reported as -1.</returns>
        /// <remarks>
        /// Throws <see cref="ObjectDisposedException"/> when the graph handle is zero; errors reported through
        /// the status are raised by its <c>Check</c> method.
        /// </remarks>
        internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status status = null)
        {
            if (graph == IntPtr.Zero)
                throw new ObjectDisposedException(nameof(graph));

            // Only a status created here is owned (and therefore released) by this method.
            bool ownsStatus = status == null;
            var cstatus = ownsStatus ? new Status() : status;
            try
            {
                var n = c_api.TF_GraphGetTensorNumDims(graph, output, cstatus.Handle);

                cstatus.Check();

                if (n == -1)
                    return new TensorShape(new int[0]);

                var dims = new long[n];
                c_api.TF_GraphGetTensorShape(graph, output, dims, dims.Length, cstatus.Handle);
                cstatus.Check();
                return new TensorShape(dims.Select(x => (int)x).ToArray());
            }
            finally
            {
                if (ownsStatus)
                    cstatus.Dispose();
            }
        }
 
        /// <summary>
        /// Resolves every output name to its graph operation and derives the TensorFlow element type and the
        /// data view type of each output column.
        /// </summary>
        /// <param name="host">Host used for argument checks and exceptions.</param>
        /// <param name="session">Session whose graph is searched.</param>
        /// <param name="outputs">Output node names, optionally in "name:index" form; duplicates are rejected.</param>
        /// <returns>Parallel arrays of TensorFlow types, data view types and (operation, index) pairs.</returns>
        internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs)
        {
            int outputCount = outputs.Length;
            var tfTypes = new TF_DataType[outputCount];
            var columnTypes = new DataViewType[outputCount];
            var operations = new (Operation, int)[outputCount];
            var seenNames = new HashSet<string>();

            for (int slot = 0; slot < outputCount; slot++)
            {
                string outputName = outputs[slot];
                host.CheckNonWhiteSpace(outputName, nameof(outputs));
                if (!seenNames.Add(outputName))
                    throw host.ExceptParam(nameof(outputs), $"Output column '{outputName}' specified multiple times");

                (Tensor outputTensor, int outputIndex) = GetOperationFromName(outputName, session);
                if (outputTensor == null)
                    throw host.ExceptParam(nameof(outputs), $"Output column '{outputName}' does not exist in the model");

                var tfOutputType = ((Operation)outputTensor).OutputType(outputIndex);
                var shape = GetTensorShape(new TF_Output((Operation)outputTensor, outputIndex), session.graph);

                // The transformer can only hand back a fixed-length vector for shapes of the kind [-1, d1, d2, d3, ...],
                // i.e. an unknown first dimension is taken to be the batch dimension and dropped.
                // Any other unknown dimension is turned into 0, which yields a variable-length vector.
                // This works around the absence of a reshape transformer.
                int[] dims;
                if (shape.ndim > 0)
                {
                    int leadingToSkip = shape.dims[0] == -1 ? 1 : 0;
                    dims = shape.dims.Skip(leadingToSkip).Select(d => d == -1 ? 0 : d).ToArray();
                }
                else
                {
                    dims = new[] { 0 };
                }

                // Nothing left after dropping the batch dimension means the output is a scalar per row.
                var itemType = Tf2MlNetType(tfOutputType);
                if (dims.Length == 0)
                    columnTypes[slot] = itemType;
                else
                    columnTypes[slot] = new VectorDataViewType(itemType, dims);

                tfTypes[slot] = tfOutputType;
                operations[slot] = (outputTensor, outputIndex);
            }

            return (tfTypes, columnTypes, operations);
        }
 
        /// <summary>
        /// Creates the row mapper bound to <paramref name="inputSchema"/>.
        /// </summary>
        private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema)
        {
            return new Mapper(this, inputSchema);
        }
 
        /// <summary>
        /// Serializes the transformer: flags, input and output column names, followed by the model files
        /// found under the model location.
        /// </summary>
        /// <param name="ctx">The model save context to write to.</param>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.AssertValue(ctx);
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // byte: indicator for frozen models
            // byte: indicator for adding batch dimension in input
            // int: number of input columns
            // for each input column
            //   int: id of int column name
            // int: number of output columns
            // for each output column
            //   int: id of output column name
            // stream: tensorFlow model.
            var isFrozen = !IsSavedModel(_env, _modelLocation);
            ctx.Writer.WriteBoolByte(isFrozen);
            ctx.Writer.WriteBoolByte(_addBatchDimensionInput);

            Host.AssertNonEmpty(_inputs);
            ctx.Writer.Write(_inputs.Length);
            foreach (var colName in _inputs)
                ctx.SaveNonEmptyString(colName);

            Host.AssertNonEmpty(_outputs);
            ctx.Writer.Write(_outputs.Length);
            foreach (var colName in _outputs)
                ctx.SaveNonEmptyString(colName);

            ctx.SaveBinaryStream("TFSavedModel", w =>
            {
                // Only these files need to be saved. They are expressed relative to the model location so the
                // stored name never depends on whether _modelLocation ends with a directory separator
                // (deriving it with Substring(_modelLocation.Length + 1) would drop a character in that case).
                string[] relativePaths =
                {
                    DefaultModelFileNames.Graph,
                    Path.Combine(DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Data),
                    Path.Combine(DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index),
                };

                w.Write(relativePaths.Length);

                foreach (var relativePath in relativePaths)
                {
                    var fullPath = Path.Combine(_modelLocation, relativePath);
                    w.Write(relativePath);

                    // The file is only read, so request read access explicitly. The two-argument FileStream
                    // constructor asks for read/write access and fails on read-only model files.
                    using (var fs = new FileStream(fullPath, FileMode.Open, FileAccess.Read, FileShare.Read))
                    {
                        long fileLength = fs.Length;
                        w.Write(fileLength);
                        long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
                        Host.Assert(actualWritten == fileLength);
                    }
                }
            });
        }
 
        /// <summary>
        /// Releases the session's graph, closes the session and, when the model lives in a temporary
        /// saved-model folder, deletes that folder. Calls after the first one are no-ops.
        /// </summary>
        public void Dispose()
        {
            if (!_isDisposed)
            {
                // The session may already have been disposed/finalized, so both the reference and its handle
                // are checked before touching it. The session has to be closed before the temporary
                // directory is deleted, which is why the cleanup below sits in the finally block.
                try
                {
                    bool hasLiveSession = _session != null && _session != IntPtr.Zero;
                    if (hasLiveSession)
                    {
                        var graph = _session.graph;
                        if (graph != null)
                            graph.Dispose();
                        _session.close();
                    }
                }
                finally
                {
                    bool removeTemporaryModel = IsSavedModel(_env, _modelLocation) && _isTemporarySavedModel;
                    if (removeTemporaryModel)
                        DeleteFolderWithRetries(Host, _modelLocation);

                    _isDisposed = true;
                }
            }
        }
 
        /// <summary>
        /// Row mapper that feeds the input columns of a row to the parent's TensorFlow session and exposes
        /// the requested outputs as columns.
        /// </summary>
        private sealed class Mapper : MapperBase
        {
            private readonly DnnRetrainTransformer _parent;
            private readonly int[] _inputColIndices;
            private readonly bool[] _isInputVector;
            private readonly TensorShape[] _fullySpecifiedShapes;

            /// <summary>
            /// Binds the mapper to <paramref name="inputSchema"/>: locates every input column, validates its
            /// type against the model's input type and computes a fully specified tensor shape for it.
            /// </summary>
            /// <param name="parent">The owning transformer.</param>
            /// <param name="inputSchema">Schema of the data the mapper will be applied to.</param>
            public Mapper(DnnRetrainTransformer parent, DataViewSchema inputSchema) :
                   base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
            {
                Host.CheckValue(parent, nameof(parent));
                _parent = parent;
                _inputColIndices = new int[_parent._inputs.Length];
                _isInputVector = new bool[_parent._inputs.Length];
                _fullySpecifiedShapes = new TensorShape[_parent._inputs.Length];
                for (int i = 0; i < _parent._inputs.Length; i++)
                {
                    if (!inputSchema.TryGetColumnIndex(_parent._inputs[i], out _inputColIndices[i]))
                        throw Host.ExceptSchemaMismatch(nameof(InputSchema), "source", _parent._inputs[i]);

                    var type = inputSchema[_inputColIndices[i]].Type;
                    if (type is VectorDataViewType vecType && vecType.Size == 0)
                        throw Host.Except("Variable length input columns not supported");

                    _isInputVector[i] = type is VectorDataViewType;
                    if (!_isInputVector[i])
                        throw Host.Except("Non-vector columns are not supported and should be loaded as vector columns of size 1");
                    vecType = (VectorDataViewType)type;
                    var expectedType = Tf2MlNetType(_parent._tfInputTypes[i]);
                    if (type.GetItemType() != expectedType)
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent._inputs[i], expectedType.ToString(), type.ToString());
                    var originalShape = _parent._tfInputShapes[i];
                    var shape = originalShape.dims;

                    var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
                    if (shape == null || (shape.Length == 0))
                        _fullySpecifiedShapes[i] = new TensorShape(colTypeDims);
                    else
                    {
                        // If the column is one dimension we make sure that the total size of the TF shape matches.
                        // Compute the total size of the known dimensions of the shape.
                        int valCount = 1;
                        int numOfUnkDim = 0;
                        foreach (var s in shape)
                        {
                            if (s > 0)
                                valCount *= s;
                            else
                                numOfUnkDim++;
                        }
                        // The column length should be divisible by this, so that the other dimensions can be integral.
                        int typeValueCount = type.GetValueCount();
                        if (typeValueCount % valCount != 0)
                            throw Contracts.Except($"Input shape mismatch: Input '{_parent._inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");

                        // If the shape is multi-dimensional, we should be able to create the length of the vector by plugging
                        // in a single value for the unknown shapes. For example, if the shape is [?,?,3], then there should exist a value
                        // d such that d*d*3 is equal to the length of the input column.
                        // The root is verified with integer arithmetic: comparing the floating-point root with its
                        // truncation rejects valid sizes whenever Math.Pow is off by one ulp (e.g. cube roots).
                        int d = 0;
                        if (numOfUnkDim > 0 && !TryGetIntegralRoot(typeValueCount / valCount, numOfUnkDim, out d))
                            throw Contracts.Except($"Input shape mismatch: Input '{_parent._inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");

                        // Fill in the unknown dimensions.
                        var originalShapeDims = originalShape.dims;
                        var originalShapeNdim = originalShape.ndim;
                        var l = new int[originalShapeNdim];
                        for (int ishape = 0; ishape < originalShapeNdim; ishape++)
                            l[ishape] = originalShapeDims[ishape] == -1 ? d : originalShapeDims[ishape];
                        _fullySpecifiedShapes[i] = new TensorShape(l);
                    }

                    if (_parent._addBatchDimensionInput)
                    {
                        // Prepend a batch dimension of size 1.
                        var l = new int[_fullySpecifiedShapes[i].ndim + 1];
                        l[0] = 1;
                        for (int ishape = 1; ishape < l.Length; ishape++)
                            l[ishape] = _fullySpecifiedShapes[i].dims[ishape - 1];
                        _fullySpecifiedShapes[i] = new TensorShape(l);
                    }
                }
            }

            /// <summary>
            /// Finds the integer <paramref name="root"/> such that root^<paramref name="degree"/> equals
            /// <paramref name="value"/>, if one exists.
            /// </summary>
            /// <returns><see langword="true"/> when <paramref name="value"/> is a perfect power of the given degree.</returns>
            private static bool TryGetIntegralRoot(int value, int degree, out int root)
            {
                root = (int)Math.Round(Math.Pow(value, 1.0 / degree));

                long power = 1;
                for (int k = 0; k < degree; k++)
                    power *= root;
                return power == value;
            }

            /// <summary>
            /// Number of values held by <paramref name="tensor"/>: the product of its positive dimensions.
            /// The product is seeded with 1 so that a tensor without any positive dimension yields 1
            /// instead of making <c>Aggregate</c> throw on an empty sequence.
            /// </summary>
            private static int GetTensorValueCount(Tensor tensor)
            {
                return tensor.TensorShape.dims.Where(x => x > 0).Aggregate(1, (x, y) => x * y);
            }

            private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

            /// <summary>
            /// Holds the output tensors computed for one row position so that repeated reads of the same
            /// row through one getter do not re-run the graph.
            /// </summary>
            private class OutputCache
            {
                public long Position;
                public Dictionary<string, Tensor> Outputs;
                public OutputCache()
                {
                    Position = -1;
                    Outputs = new Dictionary<string, Tensor>();
                }
            }

            /// <summary>
            /// Creates the getter for output column <paramref name="iinfo"/>. Each getter gets its own
            /// <see cref="OutputCache"/> and its own set of input tensor getters.
            /// </summary>
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                disposer = null;
                Host.AssertValue(input);

                var outputCache = new OutputCache();
                var activeOutputColNames = _parent._outputs.Where((x, i) => activeOutput(i)).ToArray();

                var type = Tf2MlNetType(_parent._tfOutputTypes[iinfo]).RawType;
                Host.Assert(type == _parent._outputTypes[iinfo].GetItemType().RawType);
                var srcTensorGetters = GetTensorValueGetters(input, _inputColIndices, _isInputVector, _parent._tfInputTypes, _fullySpecifiedShapes);
                return Utils.MarshalInvoke(MakeGetter<int>, type, input, iinfo, srcTensorGetters, activeOutputColNames, outputCache);
            }

            /// <summary>
            /// Builds the typed getter for output column <paramref name="iinfo"/>: a scalar getter for standard
            /// scalar output types, otherwise a <see cref="VBuffer{T}"/> getter (with a dedicated path for
            /// string tensors).
            /// </summary>
            private Delegate MakeGetter<T>(DataViewRow input, int iinfo, ITensorValueGetter[] srcTensorGetters, string[] activeOutputColNames, OutputCache outputCache) where T : unmanaged
            {
                Host.AssertValue(input);

                if (_parent._outputTypes[iinfo].IsStandardScalar())
                {
                    ValueGetter<T> valuegetter = (ref T dst) =>
                    {
                        UpdateCacheIfNeeded(input.Position, srcTensorGetters, activeOutputColNames, outputCache);

                        var tensor = outputCache.Outputs[_parent._outputs[iinfo]];
                        tensor.ToScalar<T>(ref dst);
                    };
                    return valuegetter;
                }
                else
                {
                    if (_parent._tfOutputTypes[iinfo] == TF_DataType.TF_STRING)
                    {
                        ValueGetter<VBuffer<T>> valuegetter = (ref VBuffer<T> dst) =>
                        {
                            UpdateCacheIfNeeded(input.Position, srcTensorGetters, activeOutputColNames, outputCache);

                            var tensor = outputCache.Outputs[_parent._outputs[iinfo]];
                            var tensorSize = GetTensorValueCount(tensor);

                            var editor = VBufferEditor.Create(ref dst, tensorSize);
                            FetchStringData(tensor, editor.Values);
                            dst = editor.Commit();
                        };
                        return valuegetter;
                    }
                    else
                    {
                        ValueGetter<VBuffer<T>> valuegetter = (ref VBuffer<T> dst) =>
                        {
                            UpdateCacheIfNeeded(input.Position, srcTensorGetters, activeOutputColNames, outputCache);

                            var tensor = outputCache.Outputs[_parent._outputs[iinfo]];
                            var tensorSize = GetTensorValueCount(tensor);

                            var editor = VBufferEditor.Create(ref dst, tensorSize);

                            tensor.CopyTo<T>(editor.Values);
                            dst = editor.Commit();
                        };
                        return valuegetter;
                    }
                }
            }

            /// <summary>
            /// Runs the graph for the row at <paramref name="position"/> unless <paramref name="outputCache"/>
            /// already holds the results for that position.
            /// </summary>
            private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGetters, string[] activeOutputColNames, OutputCache outputCache)
            {
                if (outputCache.Position != position)
                {
                    if (_parent.Graph.graph_key != tf.get_default_graph().graph_key)
                        _parent._session.graph.as_default();

                    // Fetch exactly the active outputs, in the order of activeOutputColNames. Fetching every
                    // output and then storing the results positionally under the active names associates the
                    // wrong tensor with a column whenever the active outputs are not a prefix of all outputs.
                    Runner runner = new Runner(_parent._session,
                        _parent._inputs.Select(x => _parent._idvToTfMapping[x]).ToArray(),
                        activeOutputColNames.Select(x => _parent._idvToTfMapping[x]).ToArray());

                    // Feed the inputs.
                    for (int i = 0; i < _parent._inputs.Length; i++)
                        runner.AddInput(srcTensorGetters[i].GetTensor(), 0);

                    // Execute the graph.
                    var tensors = runner.Run();
                    Contracts.Assert(tensors.Length > 0);
                    Contracts.Assert(tensors.Length >= activeOutputColNames.Length);

                    for (int j = 0; j < activeOutputColNames.Length; j++)
                        outputCache.Outputs[activeOutputColNames[j]] = tensors[j];

                    outputCache.Position = position;
                }
            }

            /// <summary>
            /// An input column is needed when it is one of the model's input columns and at least one
            /// output column is active.
            /// </summary>
            private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
            {
                return col => Enumerable.Range(0, _parent._outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col);
            }

            /// <summary>
            /// Describes the output columns: one per requested output, with the type resolved by the parent.
            /// </summary>
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                var info = new DataViewSchema.DetachedColumn[_parent._outputs.Length];
                for (int i = 0; i < _parent._outputs.Length; i++)
                    info[i] = new DataViewSchema.DetachedColumn(_parent._outputs[i], _parent._outputTypes[i], null);
                return info;
            }
        }
 
        /// <summary>
        /// Converts the values of one input column into tensors, either one row at a time or
        /// accumulated over several rows.
        /// </summary>
        private interface ITensorValueGetter
        {
            /// <summary>
            /// Reads the current value of the source column and returns it as a tensor.
            /// </summary>
            Tensor GetTensor();

            /// <summary>
            /// Reads the current value of the source column and appends it to the internal buffer.
            /// </summary>
            void BufferTrainingData();

            /// <summary>
            /// Returns a tensor built from the buffered values and resets the buffer position to the start.
            /// </summary>
            Tensor GetBufferedBatchTensor();
        }
 
        /// <summary>
        /// <see cref="ITensorValueGetter"/> for scalar columns. When constructed with <c>keyType</c> set,
        /// values are converted to <see cref="long"/> and decremented by one before being handed to TensorFlow.
        /// </summary>
        private class TensorValueGetter<T> : ITensorValueGetter
        {
            private readonly ValueGetter<T> _srcgetter;
            private readonly T[] _bufferedData;
            private readonly Int64[] _bufferedDataLong;
            private readonly TensorShape _tfShape;
            private int _position;
            private readonly bool _keyType;
            private readonly long[] _dims;

            /// <summary>
            /// Binds to column <paramref name="colIndex"/> of <paramref name="input"/> and allocates a
            /// buffer holding as many elements as the product of the dimensions of <paramref name="tfShape"/>
            /// (zero elements when the shape has no dimensions).
            /// </summary>
            public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape, bool keyType = false)
            {
                _srcgetter = input.GetGetter<T>(input.Schema[colIndex]);
                _tfShape = tfShape;
                _keyType = keyType;
                _position = 0;

                long elementCount = 0;
                if (tfShape.dims.Length != 0)
                {
                    elementCount = tfShape.dims.Aggregate(1L, (total, dim) => total * dim);
                    _dims = _tfShape.dims.Select(x => (long)x).ToArray();
                }

                // Only the buffer matching the value representation is allocated; the other stays null.
                if (keyType)
                    _bufferedDataLong = new long[elementCount];
                else
                    _bufferedData = new T[elementCount];
            }

            /// <summary>
            /// Reads the current value and wraps it in a single-element tensor whose shape is then set
            /// to the configured shape.
            /// </summary>
            public Tensor GetTensor()
            {
                var value = default(T);
                _srcgetter(ref value);

                Tensor result;
                if (_keyType)
                    result = new Tensor(new[] { Convert.ToInt64(value) - 1 });
                else
                    result = new Tensor(new[] { value });

                result.set_shape(_tfShape);
                return result;
            }

            /// <summary>
            /// Reads the current value and stores it at the next free buffer position.
            /// </summary>
            public void BufferTrainingData()
            {
                var value = default(T);
                _srcgetter(ref value);

                if (_keyType)
                    _bufferedDataLong[_position++] = Convert.ToInt64(value) - 1;
                else
                    _bufferedData[_position++] = value;
            }

            /// <summary>
            /// Builds a tensor from the buffer and rewinds the write position to the start of the buffer.
            /// </summary>
            public Tensor GetBufferedBatchTensor()
            {
                Tensor batch;
                if (_keyType)
                    batch = new Tensor(_bufferedDataLong, _dims, TF_DataType.TF_INT64);
                else
                    batch = TensorFlowUtils.CastDataAndReturnAsTensor(_bufferedData, _tfShape);

                _position = 0;
                return batch;
            }
        }
 
        /// <summary>
        /// <see cref="ITensorValueGetter"/> for vector columns.
        /// </summary>
        private class TensorValueGetterVec<T> : ITensorValueGetter
        {
            private readonly ValueGetter<VBuffer<T>> _srcgetter;
            private readonly TensorShape _tfShape;
            private VBuffer<T> _vBuffer;
            private T[] _denseData;
            private T[] _bufferedData;
            private int _position;
            private readonly long[] _dims;
            private readonly long _bufferedDataSize;

            /// <summary>
            /// Binds to column <paramref name="colIndex"/> of <paramref name="input"/> and allocates a
            /// buffer holding as many elements as the product of the dimensions of <paramref name="tfShape"/>
            /// (zero elements when the shape has no dimensions).
            /// </summary>
            public TensorValueGetterVec(DataViewRow input, int colIndex, TensorShape tfShape)
            {
                _srcgetter = input.GetGetter<VBuffer<T>>(input.Schema[colIndex]);
                _tfShape = tfShape;
                _vBuffer = default;
                _denseData = default;
                _position = 0;

                long elementCount = 0;
                if (tfShape.dims.Length != 0)
                    elementCount = tfShape.dims.Aggregate(1L, (total, dim) => total * dim);

                _bufferedData = new T[elementCount];
                _bufferedDataSize = elementCount;

                if (_tfShape.dims != null)
                    _dims = _tfShape.dims.Select(x => (long)x).ToArray();
            }

            /// <summary>
            /// Reads the current vector, copies it into a newly allocated dense array of exactly the
            /// vector's length and returns that data as a tensor of the configured shape.
            /// </summary>
            public Tensor GetTensor()
            {
                _srcgetter(ref _vBuffer);

                // A fresh array is allocated on every call, sized to the logical length of the vector.
                _denseData = new T[_vBuffer.Length];
                _vBuffer.CopyTo(_denseData);

                return TensorFlowUtils.CastDataAndReturnAsTensor(_denseData, _tfShape);
            }

            /// <summary>
            /// Reads the current vector and copies it into the buffer at the current write position,
            /// then advances the position by the vector's length.
            /// </summary>
            public void BufferTrainingData()
            {
                _srcgetter(ref _vBuffer);
                _vBuffer.CopyTo(_bufferedData, _position);
                _position += _vBuffer.Length;
            }

            /// <summary>
            /// Rewinds the write position, builds a tensor from the buffered data and replaces the buffer
            /// with a new array of the same size.
            /// </summary>
            public Tensor GetBufferedBatchTensor()
            {
                _position = 0;
                Tensor batch = TensorFlowUtils.CastDataAndReturnAsTensor(_bufferedData, _tfShape);

                // NOTE(review): presumably the buffer is swapped out so later buffering cannot touch the
                // array handed to the tensor - confirm against CastDataAndReturnAsTensor.
                _bufferedData = new T[_bufferedDataSize];
                return batch;
            }
        }
    }
 
    /// <summary>
    /// <see cref="IEstimator{TTransformer}"/> that produces a <see cref="DnnRetrainTransformer"/> from a
    /// TensorFlow model and the configured <see cref="Options"/>.
    /// </summary>
    internal sealed class DnnRetrainEstimator : IEstimator<DnnRetrainTransformer>
    {
        /// <summary>
        /// The options for the <see cref="DnnRetrainTransformer"/>.
        /// </summary>
        internal sealed class Options : TransformInputBase
        {
            /// <summary>
            /// Location of the TensorFlow model.
            /// </summary>
            [Argument(ArgumentType.Required, HelpText = "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", SortOrder = 0)]
            public string ModelLocation;

            /// <summary>
            /// The names of the model inputs.
            /// </summary>
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)]
            public string[] InputColumns;

            /// <summary>
            /// The names of the requested model outputs.
            /// </summary>
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 2)]
            public string[] OutputColumns;

            /// <summary>
            /// The name of the label column in <see cref="IDataView"/> that will be mapped to label node in TensorFlow model.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Training labels.", ShortName = "label", SortOrder = 4)]
            public string LabelColumn;

            /// <summary>
            /// The name of the label in TensorFlow model.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "TensorFlow label node.", ShortName = "TFLabel", SortOrder = 5)]
            public string TensorFlowLabel;

            /// <summary>
            /// Name of the operation in TensorFlow graph that is used for optimizing parameters in the graph.
            /// Usually it is the name specified in the minimize method of optimizer in python
            /// e.g. optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, name = "SGDOptimizer").
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the optimization operation in the TensorFlow graph.", ShortName = "OptimizationOp", SortOrder = 6)]
            public string OptimizationOperation;

            /// <summary>
            /// The name of the operation in the TensorFlow graph to compute training loss (Optional).
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph to compute training loss (Optional)", ShortName = "LossOp", SortOrder = 7)]
            public string LossOperation;

            /// <summary>
            /// The name of the operation in the TensorFlow graph to compute performance metric during training (Optional).
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph to compute performance metric during training (Optional)", ShortName = "MetricOp", SortOrder = 8)]
            public string MetricOperation;

            /// <summary>
            /// Number of samples to use for mini-batch training.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Number of samples to use for mini-batch training.", SortOrder = 9)]
            public int BatchSize = 64;

            /// <summary>
            /// Number of training iterations.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations.", SortOrder = 10)]
            public int Epoch = 5;

            /// <summary>
            /// The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).", SortOrder = 11)]
            public string LearningRateOperation;

            /// <summary>
            /// Learning rate to use during optimization.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)]
            public float LearningRate = 0.01f;

            /// <summary>
            /// Name of the input in TensorFlow graph that specify the location for saving/restoring models to/from disk.
            /// This parameter is set by different kinds of 'Savers' in TensorFlow and users don't have control over this.
            /// Therefore, its highly unlikely that this parameter is changed from its default value of 'save/Const'.
            /// Please change it cautiously if you need to.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the input in TensorFlow graph that specify the location for saving/restoring models from disk.", SortOrder = 13)]
            public string SaveLocationOperation = "save/Const";

            /// <summary>
            /// Name of the operation in TensorFlow graph that is used for saving/restoring models to/from disk.
            /// This parameter is set by different kinds of 'Savers' in TensorFlow and users don't have control over this.
            /// Therefore, its highly unlikely that this parameter is changed from its default value of 'save/control_dependency'.
            /// Please change it cautiously if you need to.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the input in TensorFlow graph that specify the location for saving/restoring models from disk.", SortOrder = 14)]
            public string SaveOperation = "save/control_dependency";

            /// <summary>
            /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
            /// </summary>
            /// <remarks>
            /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.
            /// In this case, there is no way to induce shape from the model's inputs or input data.
            /// </remarks>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].", SortOrder = 16)]
            public bool AddBatchDimensionInputs = false;
        }

        private readonly IHost _host;
        private readonly Options _options;
        private readonly ML.TensorFlow.TensorFlowSessionWrapper _tensorFlowModel;
        private readonly TF_DataType[] _tfInputTypes;
        private readonly DataViewType[] _outputTypes;
        private DnnRetrainTransformer _transformer;

        /// <summary>
        /// Creates the estimator and resolves the TensorFlow input types and the output column types
        /// from the model's session.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="options">The estimator options; must not be null.</param>
        /// <param name="tensorFlowModel">The loaded TensorFlow model; must not be null.</param>
        internal DnnRetrainEstimator(IHostEnvironment env, Options options, ML.TensorFlow.TensorFlowSessionWrapper tensorFlowModel)
        {
            _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnRetrainEstimator));

            // Both arguments are dereferenced below; fail with a descriptive argument exception instead of
            // a NullReferenceException.
            _host.CheckValue(options, nameof(options));
            _host.CheckValue(tensorFlowModel, nameof(tensorFlowModel));

            _options = options;
            _tensorFlowModel = tensorFlowModel;
            var inputTuple = DnnRetrainTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns);
            _tfInputTypes = inputTuple.tfInputTypes;
            _outputTypes = DnnRetrainTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns).outputTypes;
        }

        /// <summary>
        /// Builds an <see cref="Options"/> instance from a model and the input/output column names.
        /// </summary>
        private static Options CreateArguments(ML.TensorFlow.TensorFlowSessionWrapper tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput)
        {
            var options = new Options();
            options.ModelLocation = tensorFlowModel.ModelPath;
            options.InputColumns = inputColumnName;
            options.OutputColumns = outputColumnNames;
            options.AddBatchDimensionInputs = addBatchDimensionInput;
            return options;
        }

        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming schema; every configured input column must be
        /// present, be a fixed-size vector and have the item type expected by the model.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var resultDic = inputSchema.ToDictionary(x => x.Name);
            for (var i = 0; i < _options.InputColumns.Length; i++)
            {
                var input = _options.InputColumns[i];
                if (!inputSchema.TryFindColumn(input, out var col))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
                if (!(col.Kind == SchemaShape.Column.VectorKind.Vector))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
                var expectedType = Tf2MlNetType(_tfInputTypes[i]);
                if (col.ItemType != expectedType)
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
            }
            for (var i = 0; i < _options.OutputColumns.Length; i++)
            {
                var outputType = _outputTypes[i];

                // The output types may be scalar (non-vector) types. Report those as scalar columns so the
                // schema shape agrees with the column type; treating every non-fixed-size type as a
                // variable-length vector would misreport them.
                SchemaShape.Column.VectorKind kind;
                if (!(outputType is VectorDataViewType))
                    kind = SchemaShape.Column.VectorKind.Scalar;
                else if (outputType.IsKnownSizeVector())
                    kind = SchemaShape.Column.VectorKind.Vector;
                else
                    kind = SchemaShape.Column.VectorKind.VariableVector;

                resultDic[_options.OutputColumns[i]] = new SchemaShape.Column(_options.OutputColumns[i],
                    kind, outputType.GetItemType(), false);
            }
            return new SchemaShape(resultDic.Values);
        }

        /// <summary>
        /// Trains and returns a <see cref="DnnRetrainTransformer"/>.
        /// </summary>
        /// <remarks>
        /// The transformer is created on the first call and cached; later calls return the same instance
        /// after validating the schema of <paramref name="input"/> against it.
        /// </remarks>
        public DnnRetrainTransformer Fit(IDataView input)
        {
            _host.CheckValue(input, nameof(input));
            if (_transformer == null)
                _transformer = new DnnRetrainTransformer(_host, _options, _tensorFlowModel, input);

            // Validate input schema.
            _transformer.GetOutputSchema(input.Schema);
            return _transformer;
        }
    }
}