// File: VisionCatalog.cs
// Web Access
// Project: src\src\Microsoft.ML.Vision\Microsoft.ML.Vision.csproj (Microsoft.ML.Vision)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
using Microsoft.ML.Vision;
using static Microsoft.ML.TensorFlow.TensorFlowUtils;
 
namespace Microsoft.ML
{
    /// <summary>
    /// Collection of extension methods for <see cref="T:Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers" /> to create instances of ImageClassification trainer components.
    /// </summary>
    /// <remarks>
    /// This requires additional NuGet dependencies to link against TensorFlow native DLLs. See <see cref="T:Microsoft.ML.Vision.ImageClassificationTrainer"/> for more information.
    /// </remarks>
    public static class VisionCatalog
    {
        /// <summary>
        /// Retrain the DNN model on new data.
        /// Usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information.
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]
        /// ]]>
        /// </format>
        /// </summary>
        /// <param name="catalog">The model operations catalog this extension method is invoked on; must not be null.</param>
        /// <param name="outputColumnNames">The names of the requested model outputs.</param>
        /// <param name="inputColumnNames">The names of the model inputs.</param>
        /// <param name="labelColumnName">Name of the label column.</param>
        /// <param name="dnnLabel">Name of the node in the DNN graph that is used as label during training.
        /// The value of <paramref name="labelColumnName"/> from <see cref="IDataView"/> is fed to this node.</param>
        /// <param name="optimizationOperation">The name of the optimization operation in the DNN graph.</param>
        /// <param name="modelPath">Path to model file to retrain.</param>
        /// <param name="epoch">Number of training iterations.</param>
        /// <param name="batchSize">Number of samples to use for mini-batch training.</param>
        /// <param name="lossOperation">The name of the operation in the DNN graph to compute training loss (Optional).</param>
        /// <param name="metricOperation">The name of the operation in the DNN graph to compute performance metric during training (Optional).</param>
        /// <param name="learningRateOperation">The name of the operation in the DNN graph which sets optimizer learning rate (Optional).</param>
        /// <param name="learningRate">Learning rate to use during optimization (Optional).</param>
        /// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
        /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.</param>
        /// <returns>A <see cref="DnnRetrainEstimator"/> configured with the supplied settings and the model loaded from <paramref name="modelPath"/>.</returns>
        /// <remarks>
        /// The support for retraining is under preview.
        /// </remarks>
        internal static DnnRetrainEstimator RetrainDnnModel(
            this ModelOperationsCatalog catalog,
            string[] outputColumnNames,
            string[] inputColumnNames,
            string labelColumnName,
            string dnnLabel,
            string optimizationOperation,
            string modelPath,
            int epoch = 10,
            int batchSize = 20,
            string lossOperation = null,
            string metricOperation = null,
            string learningRateOperation = null,
            float learningRate = 0.01f,
            bool addBatchDimensionInput = false)
        {
            // Fail fast on a null catalog before doing any other work, in line with
            // the validation performed by the ImageClassification overloads.
            Contracts.CheckValue(catalog, nameof(catalog));
            var env = CatalogUtils.GetEnvironment(catalog);

            // Map the flat argument list onto the estimator's options object one-to-one.
            var options = new DnnRetrainEstimator.Options
            {
                ModelLocation = modelPath,
                InputColumns = inputColumnNames,
                OutputColumns = outputColumnNames,
                LabelColumn = labelColumnName,
                TensorFlowLabel = dnnLabel,
                OptimizationOperation = optimizationOperation,
                LossOperation = lossOperation,
                MetricOperation = metricOperation,
                Epoch = epoch,
                LearningRateOperation = learningRateOperation,
                LearningRate = learningRate,
                BatchSize = batchSize,
                AddBatchDimensionInputs = addBatchDimensionInput
            };

            // NOTE(review): the meaning of the trailing 'true' flag is defined by
            // TensorFlowUtils.LoadDnnModel, which is not visible here — confirm before changing it.
            return new DnnRetrainEstimator(env, options, LoadDnnModel(env, modelPath, true));
        }

        /// <summary>
        /// Create <see cref="Microsoft.ML.Vision.ImageClassificationTrainer"/> using advanced options, which trains a Deep Neural Network (DNN) to classify images.
        /// </summary>
        /// <param name="catalog">The multiclass classification trainers catalog this extension method is invoked on; must not be null.</param>
        /// <param name="options">An <see cref="ImageClassificationTrainer.Options"/> object specifying advanced
        /// options for <see cref="ImageClassificationTrainer"/>; must not be null.</param>
        /// <returns>An <see cref="ImageClassificationTrainer"/> configured with <paramref name="options"/>.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        ///  [!code-csharp[ImageClassification](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs)]
        /// ]]></format>
        /// </example>
        public static ImageClassificationTrainer ImageClassification(
            this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
            ImageClassificationTrainer.Options options)
        {
            // Validate arguments the same way the simple overload does, so both entry
            // points report a null argument by parameter name at the API boundary.
            Contracts.CheckValue(catalog, nameof(catalog));
            Contracts.CheckValue(options, nameof(options));
            return new ImageClassificationTrainer(CatalogUtils.GetEnvironment(catalog), options);
        }

        /// <summary>
        /// Create <see cref="Microsoft.ML.Vision.ImageClassificationTrainer"/>, which trains a Deep Neural Network (DNN) to classify images.
        /// </summary>
        /// <param name="catalog">The multiclass classification trainers catalog this extension method is invoked on; must not be null.</param>
        /// <param name="labelColumnName">The name of the labels column. The default for this parameter is "Label".</param>
        /// <param name="featureColumnName">The name of the input features column. The default for this parameter is "Features".</param>
        /// <param name="scoreColumnName">The name of the output score column. The default for this parameter is "Score".</param>
        /// <param name="predictedLabelColumnName">The name of the output predicted label column. The default for this parameter is "PredictedLabel".</param>
        /// <param name="validationSet">The validation set used while training to improve model quality. The default for this parameter is null.</param>
        /// <returns>An <see cref="ImageClassificationTrainer"/> configured with the given column names and validation set.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        ///  [!code-csharp[ImageClassification](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs)]
        ///  ]]></format>
        /// </example>
        public static ImageClassificationTrainer ImageClassification(
            this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
            string labelColumnName = DefaultColumnNames.Label,
            string featureColumnName = DefaultColumnNames.Features,
            string scoreColumnName = DefaultColumnNames.Score,
            string predictedLabelColumnName = DefaultColumnNames.PredictedLabel,
            IDataView validationSet = null)
        {
            Contracts.CheckValue(catalog, nameof(catalog));
            return new ImageClassificationTrainer(
                CatalogUtils.GetEnvironment(catalog),
                labelColumnName,
                featureColumnName,
                scoreColumnName,
                predictedLabelColumnName,
                validationSet);
        }
    }
}