// File: Dynamic\Trainers\MulticlassClassification\ImageClassification\ResnetV2101TransferLearningTrainTestSplit.cs
// Web Access
// Project: src\docs\samples\Microsoft.ML.Samples\Microsoft.ML.Samples.csproj (Microsoft.ML.Samples)

using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Vision;
using static Microsoft.ML.DataOperationsCatalog;
 
namespace Samples.Dynamic
{
    public class ResnetV2101TransferLearningTrainTestSplit
    {
        // End-to-end walkthrough: download the flower image set, train an
        // image classifier by transfer learning on a 90:10 train/test split,
        // save and reload the model, evaluate it, and score a single image.
        public static void Example()
        {
            // Resolve the folder the image archive is downloaded into.
            string assetsFolder = GetAbsolutePath(@"../../../assets");
            string downloadFolder = Path.Combine(assetsFolder, "inputs",
                "images");

            // Download and unzip the image set; the returned name is the
            // sub-folder (inside the download folder) holding the images.
            string imageSetFolder = Path.Combine(downloadFolder,
                DownloadImageSet(downloadFolder));

            var mlContext = new MLContext(seed: 1);

            // Describe every image on disk, labelled by its parent folder.
            var labelledImages = LoadImagesFromDirectory(
                folder: imageSetFolder, useFolderNameAsLabel: true);

            // Wrap the descriptions in a data view and shuffle the rows.
            IDataView loadedData = mlContext.Data.LoadFromEnumerable(
                labelledImages);
            IDataView preparedData = mlContext.Data.ShuffleRows(loadedData);

            // Pre-processing applied to the whole dataset:
            //   MapValueToKey     : maps the string "Label" column to keys
            //   LoadRawImageBytes : loads each "ImagePath" file into "Image"
            var labelToKey = mlContext.Transforms.Conversion.MapValueToKey(
                "Label",
                keyOrdinality: Microsoft.ML.Transforms
                    .ValueToKeyMappingEstimator.KeyOrdinality.ByValue);
            var readImageBytes = mlContext.Transforms.LoadRawImageBytes(
                "Image", imageSetFolder, "ImagePath");
            var preprocessing = labelToKey.Append(readImageBytes);

            preparedData = preprocessing
                .Fit(preparedData)
                .Transform(preparedData);

            // Keep 10% of the rows for testing, train on the remaining 90%.
            TrainTestData split = mlContext.Data.TrainTestSplit(
                preparedData, testFraction: 0.1, seed: 1);
            IDataView trainDataset = split.TrainSet;
            IDataView testDataset = split.TestSet;

            // Configure the ImageClassification trainer.
            var options = new ImageClassificationTrainer.Options();
            options.FeatureColumnName = "Image";
            options.LabelColumnName = "Label";
            // Selecting InceptionV3/MobilenetV2/ResnetV250 here instead of
            // ResnetV2101 switches to a different architecture /
            // pre-trained model.
            options.Arch = ImageClassificationTrainer.Architecture.ResnetV2101;
            options.Epoch = 50;
            options.BatchSize = 10;
            options.LearningRate = 0.01f;
            // Print every metrics object reported by the trainer.
            options.MetricsCallback = metrics => Console.WriteLine(metrics);
            // The held-out test split is also passed as the validation set.
            options.ValidationSet = testDataset;
            // No early stopping: run for the configured number of epochs.
            options.EarlyStoppingCriteria = null;

            // Build the pipeline: the trainer followed by a conversion of
            // the predicted key back to its original label value.
            var trainer = mlContext.MulticlassClassification.Trainers
                .ImageClassification(options);
            var keyToLabel = mlContext.Transforms.Conversion.MapKeyToValue(
                outputColumnName: "PredictedLabel",
                inputColumnName: "PredictedLabel");
            var pipeline = trainer.Append(keyToLabel);

            Console.WriteLine("*** Training the image classification model " +
                "with DNN Transfer Learning on top of the selected " +
                "pre-trained model/architecture ***");

            // Train the model.
            // Training first computes the bottleneck values and then trains
            // the final layer. Sample output:
            // Phase: Bottleneck Computation, Dataset used: Train, Image Index:   1
            // Phase: Bottleneck Computation, Dataset used: Train, Image Index:   2
            // ...
            // Phase: Training, Dataset used:      Train, Batch Processed Count:  18, Learning Rate:       0.01 Epoch:   0, Accuracy:  0.9166667, Cross-Entropy:  0.4787029
            // ...
            // Phase: Training, Dataset used:      Train, Batch Processed Count:  18, Learning Rate: 0.002265001 Epoch:  49, Accuracy:          1, Cross-Entropy: 0.03529362
            // Phase: Training, Dataset used: Validation, Batch Processed Count:   3, Epoch:  49, Accuracy:  0.8238096
            var trainedModel = pipeline.Fit(trainDataset);

            Console.WriteLine("Training with transfer learning finished.");

            // Persist the trained model together with the input schema.
            const string modelPath = "model.zip";
            mlContext.Model.Save(trainedModel, preparedData.Schema, modelPath);

            // Read the saved model back; predictions below use this copy.
            ITransformer loadedModel;
            DataViewSchema loadedSchema;
            using (var modelStream = File.OpenRead(modelPath))
            {
                loadedModel = mlContext.Model.Load(modelStream,
                    out loadedSchema);
            }

            // Evaluate the reloaded model on the test split.
            // Sample output:
            // Making bulk predictions and evaluating model's quality...
            // Micro - accuracy: 0.851851851851852,macro - accuracy = 0.85
            EvaluateModel(mlContext, testDataset, loadedModel);

            // Score one in-memory image.
            // Sample output:
            // Scores : [0.9305387,0.005769793,0.0001719091,0.06005093,0.003468738], Predicted Label : daisy
            TrySinglePrediction(imageSetFolder, mlContext, loadedModel);

            Console.WriteLine("Prediction on a single image finished.");

            Console.WriteLine("Press any key to finish");
            Console.ReadKey();
        }
 
        // Predict on a single image and print its scores and predicted label.
        //   imagesForPredictions : folder searched for .jpg images; the first
        //                          image yielded by the loader is scored.
        //   mlContext            : context used to create the prediction engine.
        //   trainedModel         : model applied to the in-memory image.
        // Throws InvalidOperationException when no image is found.
        private static void TrySinglePrediction(string imagesForPredictions,
            MLContext mlContext, ITransformer trainedModel)
        {
            // The prediction engine is disposable, so scope it with 'using'
            // instead of leaving its resources to the finalizer.
            using (var predictionEngine = mlContext.Model
                .CreatePredictionEngine<InMemoryImageData,
                ImagePrediction>(trainedModel))
            {
                // The loader yields images one at a time, so asking for the
                // first element reads only a single file into memory.
                InMemoryImageData firstImage = LoadInMemoryImagesFromDirectory(
                    imagesForPredictions, false).FirstOrDefault();

                if (firstImage == null)
                {
                    // Same exception type First() would raise, but with a
                    // message that names the folder that was searched.
                    throw new InvalidOperationException(
                        $"No .jpg images were found under '{imagesForPredictions}'.");
                }

                // Copy only the image bytes; the label is left unset.
                InMemoryImageData imageToPredict = new InMemoryImageData
                {
                    Image = firstImage.Image
                };

                // Predict on the single image.
                var prediction = predictionEngine.Predict(imageToPredict);

                Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
                    $"Predicted Label : {prediction.PredictedLabel}");
            }
        }
 
        // Scores the test dataset with the given model and prints the
        // micro- and macro-accuracy of the resulting predictions.
        private static void EvaluateModel(MLContext mlContext,
            IDataView testDataset, ITransformer trainedModel)
        {
            Console.WriteLine("Making bulk predictions and evaluating model's " +
                "quality...");

            // Run the model over the test rows, then compute the
            // multiclass evaluation metrics from its output.
            var metrics = mlContext.MulticlassClassification.Evaluate(
                trainedModel.Transform(testDataset));

            var microAccuracy = metrics.MicroAccuracy;
            var macroAccuracy = metrics.MacroAccuracy;

            Console.WriteLine($"Micro-accuracy: {microAccuracy}," +
                              $"macro-accuracy = {macroAccuracy}");

            Console.WriteLine("Predicting and Evaluation complete.");
        }
 
        // Load the image data (path + label) from the input directory.
        // Lazily yields one ImageData per .jpg file found under 'folder',
        // searching all sub-directories.
        //   folder               : root directory to search.
        //   useFolderNameAsLabel : when true the label is the name of the
        //                          file's parent directory; otherwise it is
        //                          the leading run of letters of the file name.
        public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
            bool useFolderNameAsLabel = true)
        {
            // EnumerateFiles streams the paths instead of building the whole
            // array up front, which suits this lazy iterator.
            var files = Directory.EnumerateFiles(folder, "*",
                searchOption: SearchOption.AllDirectories);
            foreach (var file in files)
            {
                // Ordinal, case-insensitive comparison so that files named
                // "*.JPG" are not silently skipped.
                if (!string.Equals(Path.GetExtension(file), ".jpg",
                        StringComparison.OrdinalIgnoreCase))
                    continue;

                string label;
                if (useFolderNameAsLabel)
                {
                    label = Directory.GetParent(file).Name;
                }
                else
                {
                    // Keep the file name up to (not including) its first
                    // non-letter character.
                    string fileName = Path.GetFileName(file);
                    int prefixLength = 0;
                    while (prefixLength < fileName.Length
                        && char.IsLetter(fileName[prefixLength]))
                    {
                        prefixLength++;
                    }

                    label = fileName.Substring(0, prefixLength);
                }

                yield return new ImageData()
                {
                    ImagePath = file,
                    Label = label
                };
            }
        }
 
        // Load in-memory raw images from a directory.
        // Lazily yields one InMemoryImageData per .jpg file found under
        // 'folder' (all sub-directories included); each file's bytes are read
        // only when its element is produced.
        //   folder               : root directory to search.
        //   useFolderNameAsLabel : when true the label is the name of the
        //                          file's parent directory; otherwise it is
        //                          the leading run of letters of the file name.
        public static IEnumerable<InMemoryImageData>
            LoadInMemoryImagesFromDirectory(string folder,
                bool useFolderNameAsLabel = true)
        {
            // EnumerateFiles streams the paths, so a caller that only takes
            // the first image does not pay for listing the whole tree.
            var files = Directory.EnumerateFiles(folder, "*",
                searchOption: SearchOption.AllDirectories);
            foreach (var file in files)
            {
                // Ordinal, case-insensitive comparison so that files named
                // "*.JPG" are not silently skipped.
                if (!string.Equals(Path.GetExtension(file), ".jpg",
                        StringComparison.OrdinalIgnoreCase))
                    continue;

                string label;
                if (useFolderNameAsLabel)
                {
                    label = Directory.GetParent(file).Name;
                }
                else
                {
                    // Keep the file name up to (not including) its first
                    // non-letter character.
                    string fileName = Path.GetFileName(file);
                    int prefixLength = 0;
                    while (prefixLength < fileName.Length
                        && char.IsLetter(fileName[prefixLength]))
                    {
                        prefixLength++;
                    }

                    label = fileName.Substring(0, prefixLength);
                }

                yield return new InMemoryImageData()
                {
                    Image = File.ReadAllBytes(file),
                    Label = label
                };
            }
        }
 
        // Fetches the flower image archive into the given folder, extracts it
        // there, and returns the archive name without its extension.
        public static string DownloadImageSet(string imagesDownloadFolder)
        {
            // Images used to teach the network about the new classes:
            // the single small flowers image set (200 files).
            const string archiveName = "flower_photos_small_set.zip";
            const string archiveUrl = "https://aka.ms/mlnet-resources/datasets/flower_photos_small_set.zip";

            // Block until the download task has finished.
            Task<bool> downloadTask = Download(archiveUrl, imagesDownloadFolder,
                archiveName);
            downloadTask.Wait();

            string archivePath = Path.Combine(imagesDownloadFolder, archiveName);
            UnZip(archivePath, imagesDownloadFolder);

            return Path.GetFileNameWithoutExtension(archiveName);
        }
 
        // Download a file from the input URL into the destination directory.
        //   url          : address of the file to download.
        //   destDir      : directory to save into (created when missing).
        //   destFileName : name of the saved file; when null, the last path
        //                  segment of the URL is used.
        // Returns false (without downloading) when the destination file
        // already exists, true after a completed download.
        public static async Task<bool> Download(string url, string destDir, string destFileName)
        {
            if (destFileName == null)
            {
                // URLs are always '/'-separated, so derive the name from the
                // URL path instead of splitting on the OS directory separator
                // (which is '\' on Windows and would return the whole URL).
                destFileName = Path.GetFileName(new Uri(url).LocalPath);
            }

            Directory.CreateDirectory(destDir);

            string relativeFilePath = Path.Combine(destDir, destFileName);

            if (File.Exists(relativeFilePath))
            {
                Console.WriteLine($"{relativeFilePath} already exists.");
                return false;
            }

            Console.WriteLine($"Downloading {relativeFilePath}");

            // Download into a side file first: an interrupted transfer must
            // not leave a truncated file under the final name, because the
            // existence check above would then treat it as a finished download.
            string partialFilePath = relativeFilePath + ".partial";

            using (HttpClient client = new HttpClient())
            using (var response = await client.GetStreamAsync(new Uri(url)).ConfigureAwait(false))
            {
                // FileMode.Create overwrites leftovers of an earlier failed attempt.
                using (var fs = new FileStream(partialFilePath, FileMode.Create))
                {
                    // ConfigureAwait(false) on every await: the caller blocks
                    // on this task, so continuations must not need its context.
                    await response.CopyToAsync(fs).ConfigureAwait(false);
                }
            }

            // Publish the completed file under its final name.
            File.Move(partialFilePath, relativeFilePath);

            Console.WriteLine("");
            Console.WriteLine($"Downloaded {relativeFilePath}");

            return true;
        }
 
        // Unzip the archive into the destination folder.
        // A marker file "<archive name up to its first '.'>.bin" is written to
        // destFolder after a successful extraction; when that marker already
        // exists the method returns without extracting again.
        //   gzArchiveName : path of the zip archive to extract.
        //   destFolder    : directory the archive is extracted into.
        // Any exception raised by the extraction is rethrown to the caller.
        public static void UnZip(String gzArchiveName, String destFolder)
        {
            // Path.GetFileName understands the platform's separators, unlike
            // a manual split on Path.DirectorySeparatorChar.
            var flag = Path.GetFileName(gzArchiveName)
                .Split('.')
                .First() + ".bin";
            var flagPath = Path.Combine(destFolder, flag);

            if (File.Exists(flagPath)) return;

            Console.WriteLine("Extracting.");
            var task = Task.Run(() =>
            {
                ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
            });

            // Print progress dots while the extraction runs in the background.
            while (!task.IsCompleted)
            {
                Thread.Sleep(200);
                Console.Write(".");
            }

            // Observe the task result: if extraction failed, rethrow here so
            // the marker is not written and success is not reported.
            task.GetAwaiter().GetResult();

            // File.Create returns an open stream; dispose it right away so the
            // marker file's handle is not leaked.
            File.Create(flagPath).Dispose();
            Console.WriteLine("");
            Console.WriteLine("Extracting is completed.");
        }
 
        // Combines the folder containing this sample's assembly with the
        // given relative path and returns the result.
        public static string GetAbsolutePath(string relativePath)
        {
            string assemblyLocation = typeof(
                ResnetV2101TransferLearningTrainTestSplit).Assembly.Location;

            // Directory that holds the assembly file.
            DirectoryInfo assemblyFolder = new FileInfo(assemblyLocation).Directory;

            return Path.Combine(assemblyFolder.FullName, relativePath);
        }
 
        // InMemoryImageData class holding the raw image byte array and label.
        // Yielded by LoadInMemoryImagesFromDirectory and used as the input
        // type of the prediction engine in TrySinglePrediction.
        public class InMemoryImageData
        {
            // Raw contents of the image file (read with File.ReadAllBytes).
            // NOTE(review): the LoadColumn indices presumably only matter when
            // loading from a text file; this sample builds instances in code - confirm.
            [LoadColumn(0)]
            public byte[] Image;

            // Label of the image: the parent folder name, or the leading
            // letters of the file name, depending on how the loader was called.
            [LoadColumn(1)]
            public string Label;
        }
 
        // ImageData class holding the image path and label.
        // Yielded by LoadImagesFromDirectory and turned into a data view in
        // Example, where "ImagePath" and "Label" are referenced by name.
        public class ImageData
        {
            // Path of the image file as returned by the directory search.
            // NOTE(review): the LoadColumn indices presumably only matter when
            // loading from a text file; this sample builds instances in code - confirm.
            [LoadColumn(0)]
            public string ImagePath;

            // Label of the image: the parent folder name, or the leading
            // letters of the file name, depending on how the loader was called.
            [LoadColumn(1)]
            public string Label;
        }
 
        // ImagePrediction class holding the score and predicted label metrics.
        // Output type of the prediction engine created in TrySinglePrediction.
        public class ImagePrediction
        {
            // Bound to the "Score" column; printed as a comma-separated list.
            // Presumably one score per class - confirm against the trainer docs.
            [ColumnName("Score")]
            public float[] Score;

            // Bound to the "PredictedLabel" column, which the pipeline in
            // Example converts from a key back to its label value.
            [ColumnName("PredictedLabel")]
            public string PredictedLabel;
        }
    }
}