using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Vision;
using static Microsoft.ML.DataOperationsCatalog;

namespace Samples.Dynamic
{
    /// <summary>
    /// Sample that downloads a small flower image set, trains an image
    /// classifier with the ResnetV2101 architecture on a 90:10 train/test
    /// split, saves and reloads the model, evaluates it and finally runs a
    /// single in-memory prediction.
    /// </summary>
    public class ResnetV2101TransferLearningTrainTestSplit
    {
        /// <summary>
        /// Runs the complete sample: download, load, train, save, reload,
        /// evaluate and predict. Blocks on a key press before returning.
        /// </summary>
        public static void Example()
        {
            // File the trained model is saved to and loaded back from.
            const string modelPath = "model.zip";

            // Set the path for input images.
            string assetsRelativePath = @"../../../assets";
            string assetsPath = GetAbsolutePath(assetsRelativePath);

            string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs",
                "images");

            // Download the image set and unzip, set the path to image folder.
            string finalImagesFolderName = DownloadImageSet(
                imagesDownloadFolderPath);

            string fullImagesetFolderPath = Path.Combine(
                imagesDownloadFolderPath, finalImagesFolderName);

            MLContext mlContext = new MLContext(seed: 1);

            // Load all the original images info.
            IEnumerable<ImageData> images = LoadImagesFromDirectory(
                folder: fullImagesetFolderPath, useFolderNameAsLabel: true);

            // Shuffle images.
            IDataView shuffledFullImagesDataset = mlContext.Data.ShuffleRows(
                mlContext.Data.LoadFromEnumerable(images));

            // Apply transforms to the input dataset:
            // MapValueToKey : map 'string' type labels to keys
            // LoadImages    : load raw images to "Image" column
            shuffledFullImagesDataset = mlContext.Transforms.Conversion
                .MapValueToKey("Label", keyOrdinality: Microsoft.ML.Transforms
                    .ValueToKeyMappingEstimator.KeyOrdinality.ByValue)
                .Append(mlContext.Transforms.LoadRawImageBytes("Image",
                    fullImagesetFolderPath, "ImagePath"))
                .Fit(shuffledFullImagesDataset)
                .Transform(shuffledFullImagesDataset);

            // Split the data 90:10 into train and test sets.
            TrainTestData trainTestData = mlContext.Data.TrainTestSplit(
                shuffledFullImagesDataset, testFraction: 0.1, seed: 1);

            IDataView trainDataset = trainTestData.TrainSet;
            IDataView testDataset = trainTestData.TestSet;

            // Set the options for ImageClassification.
            var options = new ImageClassificationTrainer.Options()
            {
                FeatureColumnName = "Image",
                LabelColumnName = "Label",
                // Just by changing/selecting InceptionV3/MobilenetV2/ResnetV250
                // here instead of ResnetV2101 you can try a different
                // architecture/ pre-trained model.
                Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
                Epoch = 50,
                BatchSize = 10,
                LearningRate = 0.01f,
                MetricsCallback = (metrics) => Console.WriteLine(metrics),
                // NOTE(review): the test split is also used as the validation
                // set, and EvaluateModel below scores the same split - confirm
                // this is intended for the sample.
                ValidationSet = testDataset,
                // Disable EarlyStopping to run to specified number of epochs.
                EarlyStoppingCriteria = null
            };

            // Create the ImageClassification pipeline.
            var pipeline = mlContext.MulticlassClassification.Trainers.
                ImageClassification(options)
                .Append(mlContext.Transforms.Conversion.MapKeyToValue(
                    outputColumnName: "PredictedLabel",
                    inputColumnName: "PredictedLabel"));

            Console.WriteLine("*** Training the image classification model " +
                "with DNN Transfer Learning on top of the selected " +
                "pre-trained model/architecture ***");

            // Train the model
            // This involves calculating the bottleneck values, and then
            // training the final layer. Sample output is:
            // Phase: Bottleneck Computation, Dataset used: Train, Image Index:   1
            // Phase: Bottleneck Computation, Dataset used: Train, Image Index:   2
            // ...
            // Phase: Training, Dataset used: Train, Batch Processed Count:  18, Learning Rate: 0.01 Epoch:   0, Accuracy: 0.9166667, Cross-Entropy: 0.4787029
            // ...
            // Phase: Training, Dataset used: Train, Batch Processed Count:  18, Learning Rate: 0.002265001 Epoch:  49, Accuracy: 1, Cross-Entropy: 0.03529362
            // Phase: Training, Dataset used: Validation, Batch Processed Count:   3, Epoch:  49, Accuracy: 0.8238096
            var trainedModel = pipeline.Fit(trainDataset);

            Console.WriteLine("Training with transfer learning finished.");

            // Save the trained model.
            mlContext.Model.Save(trainedModel, shuffledFullImagesDataset.Schema,
                modelPath);

            // Load the trained and saved model for prediction.
            ITransformer loadedModel;
            DataViewSchema schema;
            using (var file = File.OpenRead(modelPath))
            {
                loadedModel = mlContext.Model.Load(file, out schema);
            }

            // Evaluate the model on the test dataset.
            // Sample output:
            // Making bulk predictions and evaluating model's quality...
            // Micro - accuracy: 0.851851851851852,macro - accuracy = 0.85
            EvaluateModel(mlContext, testDataset, loadedModel);

            // Predict on a single image class using an in-memory image.
            // Sample output:
            // Scores : [0.9305387,0.005769793,0.0001719091,0.06005093,0.003468738], Predicted Label : daisy
            TrySinglePrediction(fullImagesetFolderPath, mlContext, loadedModel);

            Console.WriteLine("Prediction on a single image finished.");

            Console.WriteLine("Press any key to finish");
            Console.ReadKey();
        }

        /// <summary>
        /// Predicts the label of the first .jpg image found under
        /// <paramref name="imagesForPredictions"/> and prints the scores and
        /// the predicted label to the console.
        /// </summary>
        /// <param name="imagesForPredictions">Folder searched recursively for
        /// images.</param>
        /// <param name="mlContext">Context used to create the prediction
        /// engine.</param>
        /// <param name="trainedModel">Model used for the prediction.</param>
        private static void TrySinglePrediction(string imagesForPredictions,
            MLContext mlContext, ITransformer trainedModel)
        {
            // Create prediction function to try one prediction.
            var predictionEngine = mlContext.Model
                .CreatePredictionEngine<InMemoryImageData,
                ImagePrediction>(trainedModel);

            // Load test images. The sequence is lazy, so only the first image
            // file is actually read from disk by First() below.
            IEnumerable<InMemoryImageData> testImages =
                LoadInMemoryImagesFromDirectory(imagesForPredictions, false);

            // Create an in-memory image object from the first image in the
            // test data.
            InMemoryImageData imageToPredict = new InMemoryImageData
            {
                Image = testImages.First().Image
            };

            // Predict on the single image.
            var prediction = predictionEngine.Predict(imageToPredict);

            Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
                $"Predicted Label : {prediction.PredictedLabel}");
        }

        /// <summary>
        /// Transforms <paramref name="testDataset"/> with
        /// <paramref name="trainedModel"/> and prints the micro and macro
        /// accuracy of the resulting predictions.
        /// </summary>
        /// <param name="mlContext">Context providing the multiclass
        /// evaluator.</param>
        /// <param name="testDataset">Data to score.</param>
        /// <param name="trainedModel">Model to evaluate.</param>
        private static void EvaluateModel(MLContext mlContext,
            IDataView testDataset, ITransformer trainedModel)
        {
            Console.WriteLine("Making bulk predictions and evaluating model's " +
                "quality...");

            // Evaluate the model on the test data and get the evaluation
            // metrics.
            IDataView predictions = trainedModel.Transform(testDataset);
            var metrics = mlContext.MulticlassClassification.Evaluate(predictions);

            Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," +
                              $"macro-accuracy = {metrics.MacroAccuracy}");

            Console.WriteLine("Predicting and Evaluation complete.");
        }

        /// <summary>
        /// Lazily enumerates every .jpg file under <paramref name="folder"/>
        /// (recursively) as an <see cref="ImageData"/> holding the file path
        /// and a label.
        /// </summary>
        /// <param name="folder">Root folder to search.</param>
        /// <param name="useFolderNameAsLabel">When true the label is the name
        /// of the file's parent folder; otherwise it is the leading run of
        /// letters of the file name.</param>
        /// <returns>One <see cref="ImageData"/> per matching file.</returns>
        public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
            bool useFolderNameAsLabel = true)
        {
            var files = Directory.GetFiles(folder, "*",
                searchOption: SearchOption.AllDirectories);

            foreach (var file in files)
            {
                if (!IsJpgFile(file))
                {
                    continue;
                }

                yield return new ImageData()
                {
                    ImagePath = file,
                    Label = GetLabel(file, useFolderNameAsLabel)
                };
            }
        }

        /// <summary>
        /// Lazily enumerates every .jpg file under <paramref name="folder"/>
        /// (recursively) as an <see cref="InMemoryImageData"/> holding the raw
        /// file bytes and a label. Each file is read only when its element is
        /// reached during enumeration.
        /// </summary>
        /// <param name="folder">Root folder to search.</param>
        /// <param name="useFolderNameAsLabel">When true the label is the name
        /// of the file's parent folder; otherwise it is the leading run of
        /// letters of the file name.</param>
        /// <returns>One <see cref="InMemoryImageData"/> per matching
        /// file.</returns>
        public static IEnumerable<InMemoryImageData>
            LoadInMemoryImagesFromDirectory(string folder,
                bool useFolderNameAsLabel = true)
        {
            var files = Directory.GetFiles(folder, "*",
                searchOption: SearchOption.AllDirectories);

            foreach (var file in files)
            {
                if (!IsJpgFile(file))
                {
                    continue;
                }

                string label = GetLabel(file, useFolderNameAsLabel);

                yield return new InMemoryImageData()
                {
                    Image = File.ReadAllBytes(file),
                    Label = label
                };
            }
        }

        // Returns true when the file has a ".jpg" extension. The comparison
        // ignores case so that files such as "photo.JPG" are not skipped.
        private static bool IsJpgFile(string file)
        {
            return string.Equals(Path.GetExtension(file), ".jpg",
                StringComparison.OrdinalIgnoreCase);
        }

        // Derives the label for an image file: either the name of its parent
        // folder, or the file name truncated at its first non-letter character.
        private static string GetLabel(string file, bool useFolderNameAsLabel)
        {
            if (useFolderNameAsLabel)
            {
                return Directory.GetParent(file).Name;
            }

            string fileName = Path.GetFileName(file);
            int letterCount = 0;
            while (letterCount < fileName.Length
                && char.IsLetter(fileName[letterCount]))
            {
                letterCount++;
            }

            return fileName.Substring(0, letterCount);
        }

        /// <summary>
        /// Downloads the small flowers image set into
        /// <paramref name="imagesDownloadFolder"/> and extracts it there.
        /// </summary>
        /// <param name="imagesDownloadFolder">Folder receiving the archive and
        /// its extracted content.</param>
        /// <returns>The archive file name without its extension, which the
        /// caller combines with the download folder to locate the
        /// images.</returns>
        public static string DownloadImageSet(string imagesDownloadFolder)
        {
            // get a set of images to teach the network about the new classes

            // SINGLE SMALL FLOWERS IMAGESET (200 files)
            string fileName = "flower_photos_small_set.zip";
            string url = "https://aka.ms/mlnet-resources/datasets/flower_photos_small_set.zip";

            // Blocking wait: this sample runs synchronously from Example().
            Download(url, imagesDownloadFolder, fileName).Wait();
            UnZip(Path.Combine(imagesDownloadFolder, fileName),
                imagesDownloadFolder);

            return Path.GetFileNameWithoutExtension(fileName);
        }

        /// <summary>
        /// Downloads <paramref name="url"/> into <paramref name="destDir"/>
        /// unless the destination file already exists.
        /// </summary>
        /// <param name="url">Absolute URL of the file to download.</param>
        /// <param name="destDir">Destination directory; created when
        /// missing.</param>
        /// <param name="destFileName">Destination file name, or null to use
        /// the last segment of the URL path.</param>
        /// <returns>True when the file was downloaded, false when it already
        /// existed.</returns>
        public static async Task<bool> Download(string url, string destDir,
            string destFileName)
        {
            if (destFileName == null)
            {
                // URLs are always '/'-separated, so take the name from the URI
                // path instead of splitting on the OS directory separator
                // (which is '\' on Windows and never occurs in the URL).
                destFileName = Path.GetFileName(new Uri(url).LocalPath);
            }

            Directory.CreateDirectory(destDir);

            string destFilePath = Path.Combine(destDir, destFileName);

            if (File.Exists(destFilePath))
            {
                Console.WriteLine($"{destFilePath} already exists.");
                return false;
            }

            Console.WriteLine($"Downloading {destFilePath}");

            // Download into a side file first so that an interrupted transfer
            // never leaves a truncated file at the final path, where the next
            // run would report it as "already exists".
            string partialFilePath = destFilePath + ".partial";

            using (HttpClient client = new HttpClient())
            using (Stream response = await client.GetStreamAsync(new Uri(url))
                .ConfigureAwait(false))
            {
                using (var fs = new FileStream(partialFilePath, FileMode.Create))
                {
                    await response.CopyToAsync(fs).ConfigureAwait(false);
                }
            }

            // Throws IOException when the destination appeared in the
            // meantime, like the FileMode.CreateNew it replaces.
            File.Move(partialFilePath, destFilePath);

            Console.WriteLine("");
            Console.WriteLine($"Downloaded {destFilePath}");

            return true;
        }

        /// <summary>
        /// Extracts the zip archive <paramref name="gzArchiveName"/> into
        /// <paramref name="destFolder"/>, printing progress dots while
        /// waiting. A marker file named after the archive with a ".bin"
        /// extension is written on success and makes later calls return
        /// immediately.
        /// </summary>
        /// <param name="gzArchiveName">Path of the zip archive.</param>
        /// <param name="destFolder">Folder receiving the extracted files and
        /// the marker file.</param>
        public static void UnZip(String gzArchiveName, String destFolder)
        {
            var flag = Path.GetFileName(gzArchiveName)
                .Split('.')
                .First() + ".bin";

            string flagPath = Path.Combine(destFolder, flag);

            if (File.Exists(flagPath))
            {
                return;
            }

            Console.WriteLine($"Extracting.");
            var task = Task.Run(() =>
            {
                ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
            });

            while (!task.IsCompleted)
            {
                Thread.Sleep(200);
                Console.Write(".");
            }

            // Surface any extraction failure. Without this a faulted task
            // would be ignored and the marker below would record a failed
            // extraction as complete.
            task.GetAwaiter().GetResult();

            // File.Create returns an open stream; dispose it right away so the
            // handle on the marker file is released.
            File.Create(flagPath).Dispose();

            Console.WriteLine("");
            Console.WriteLine("Extracting is completed.");
        }

        /// <summary>
        /// Combines <paramref name="relativePath"/> with the folder containing
        /// this sample's assembly.
        /// </summary>
        /// <param name="relativePath">Path relative to the assembly
        /// folder.</param>
        /// <returns>The combined path.</returns>
        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(
                ResnetV2101TransferLearningTrainTestSplit).Assembly.Location);

            string assemblyFolderPath = _dataRoot.Directory.FullName;

            string fullPath = Path.Combine(assemblyFolderPath, relativePath);

            return fullPath;
        }

        /// <summary>
        /// InMemoryImageData class holding the raw image byte array and label.
        /// </summary>
        public class InMemoryImageData
        {
            [LoadColumn(0)]
            public byte[] Image;

            [LoadColumn(1)]
            public string Label;
        }

        /// <summary>
        /// ImageData class holding the image path and label.
        /// </summary>
        public class ImageData
        {
            [LoadColumn(0)]
            public string ImagePath;

            [LoadColumn(1)]
            public string Label;
        }

        /// <summary>
        /// ImagePrediction class holding the score and predicted label
        /// metrics.
        /// </summary>
        public class ImagePrediction
        {
            [ColumnName("Score")]
            public float[] Score;

            [ColumnName("PredictedLabel")]
            public string PredictedLabel;
        }
    }
}