File: DatasetUtil.cs
Web Access
Project: src\test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj (Microsoft.ML.AutoML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.ML.Data;
using Microsoft.ML.TestFrameworkCommon;
 
namespace Microsoft.ML.AutoML.Test
{
    /// <summary>
    /// Shared helpers for AutoML tests. Provides:
    /// <list type="bullet">
    /// <item>paths to files under the repository's <c>test/data</c> folder;</item>
    /// <item>commonly used datasets loaded as <see cref="IDataView"/>s via AutoML column inference
    /// (most are cached in static fields after the first call);</item>
    /// <item>download and extraction of the tiny flowers image set.</item>
    /// </list>
    /// </summary>
    /// <remarks>
    /// The cached data views are assigned without synchronization. Concurrent first calls may each
    /// load the file; the last assignment wins.
    /// </remarks>
    internal static class DatasetUtil
    {
        /// <summary>Label column name of the UCI Adult dataset.</summary>
        public const string UciAdultLabel = DefaultColumnNames.Label;
        /// <summary>Label column name of the taxi-fare train/test datasets.</summary>
        public const string TaxiFareLabel = "fare_amount";
        /// <summary>Label column name of the trivial multiclass dataset.</summary>
        public const string TrivialMulticlassDatasetLabel = "Target";
        /// <summary>Label column name of the ML.NET-generated regression dataset.</summary>
        public const string MlNetGeneratedRegressionLabel = "target";
        /// <summary>Label column name of the newspaper churn dataset.</summary>
        public const string NewspaperChurnLabel = "Subscriber";
        /// <summary>Index of the label column in the iris dataset.</summary>
        public const int IrisDatasetLabelColIndex = 0;

        /// <summary>
        /// Relative path to the trivial multiclass dataset. Unlike <see cref="GetDataPath"/>, it is
        /// not anchored to the repository root.
        /// </summary>
        public static string TrivialMulticlassDatasetPath = Path.Combine("TestData", "TrivialMulticlassDataset.txt");

        // Lazily populated caches for the data-view getters below.
        private static IDataView _uciAdultDataView;

        private static IDataView _taxiFareTrainDataView;

        private static IDataView _taxiFareTestDataView;

        private static IDataView _irisDataView;

        private static IDataView _newspaperChurnDataView;

        // File extensions (compared case-insensitively, culture-independently) treated as images.
        private static readonly HashSet<string> _imageExtensions =
            new HashSet<string>(StringComparer.OrdinalIgnoreCase) { ".jpg", ".jpeg", ".png", ".gif" };

        /// <summary>Path to the tiny UCI Adult dataset (with schema header).</summary>
        public static string GetUciAdultDataset() => GetDataPath("adult.tiny.with-schema.txt");

        /// <summary>Path to the ML.NET-generated regression dataset.</summary>
        public static string GetMlNetGeneratedRegressionDataset() => GetDataPath("generated_regression_dataset.csv");

        /// <summary>Path to the iris dataset.</summary>
        public static string GetIrisDataset() => GetDataPath("iris.txt");

        /// <summary>Path to the tiny MSLR-WEB1K ranking dataset.</summary>
        public static string GetMLSRDataset() => GetDataPath("MSLRWeb1K-tiny.tsv");

        /// <summary>
        /// Returns the path of <paramref name="fileName"/> inside the <c>test/data</c> folder of the
        /// directory returned by <c>TestCommon.GetRepoRoot()</c>.
        /// </summary>
        public static string GetDataPath(string fileName)
        {
            return Path.Combine(TestCommon.GetRepoRoot(), "test", "data", fileName);
        }

        /// <summary>Loads (once, then caches) the UCI Adult dataset with <see cref="UciAdultLabel"/> as label.</summary>
        public static IDataView GetUciAdultDataView()
        {
            if (_uciAdultDataView == null)
                _uciAdultDataView = InferColumnsAndLoad(GetUciAdultDataset(), UciAdultLabel);

            return _uciAdultDataView;
        }

        /// <summary>Loads (once, then caches) the iris dataset.</summary>
        public static IDataView GetIrisDataView()
        {
            if (_irisDataView == null)
            {
                var context = new MLContext(1);
                var dataFile = GetIrisDataset();
                // Iris is inferred by label column index rather than by name, with column grouping disabled.
                var columnInferenceResult = context.Auto().InferColumns(dataFile, IrisDatasetLabelColIndex, groupColumns: false);
                var textLoader = context.Data.CreateTextLoader(columnInferenceResult.TextLoaderOptions);
                _irisDataView = textLoader.Load(dataFile);
            }

            return _irisDataView;
        }

        /// <summary>Loads (once, then caches) <c>taxi-fare-train.csv</c> with <see cref="TaxiFareLabel"/> as label.</summary>
        public static IDataView GetTaxiFareTrainDataView()
        {
            if (_taxiFareTrainDataView == null)
                _taxiFareTrainDataView = InferColumnsAndLoad(GetDataPath("taxi-fare-train.csv"), TaxiFareLabel);

            return _taxiFareTrainDataView;
        }

        /// <summary>Loads (once, then caches) <c>newspaperchurn.csv</c> with <see cref="NewspaperChurnLabel"/> as label.</summary>
        public static IDataView GetNewspaperChurnDataView()
        {
            if (_newspaperChurnDataView == null)
                _newspaperChurnDataView = InferColumnsAndLoad(GetDataPath("newspaperchurn.csv"), NewspaperChurnLabel);

            return _newspaperChurnDataView;
        }

        /// <summary>
        /// Loads <c>creditapproval_train.csv</c> with column <c>A16</c> as label.
        /// Unlike the other getters, the result is not cached; every call re-infers and reloads.
        /// </summary>
        public static IDataView GetCreditApprovalDataView()
        {
            return InferColumnsAndLoad(GetDataPath("creditapproval_train.csv"), "A16");
        }

        /// <summary>Loads (once, then caches) <c>taxi-fare-test.csv</c> with <see cref="TaxiFareLabel"/> as label.</summary>
        public static IDataView GetTaxiFareTestDataView()
        {
            if (_taxiFareTestDataView == null)
                _taxiFareTestDataView = InferColumnsAndLoad(GetDataPath("taxi-fare-test.csv"), TaxiFareLabel);

            return _taxiFareTestDataView;
        }

        /// <summary>
        /// Infers the columns of the text file at <paramref name="dataFile"/>, treating
        /// <paramref name="labelColumnName"/> as the label, and loads it with the inferred loader options.
        /// Each call uses a fresh <see cref="MLContext"/> seeded with 1.
        /// </summary>
        private static IDataView InferColumnsAndLoad(string dataFile, string labelColumnName)
        {
            var context = new MLContext(1);
            var columnInferenceResult = context.Auto().InferColumns(dataFile, labelColumnName);
            var textLoader = context.Data.CreateTextLoader(columnInferenceResult.TextLoaderOptions);
            return textLoader.Load(dataFile);
        }

        /// <summary>
        /// Downloads and extracts the tiny flowers image set into <c>assets/inputs/images</c> next to the
        /// test assembly. It then writes a CSV index of the images to a file named <c>flowers</c>,
        /// relative to the current working directory.
        /// </summary>
        /// <returns>The (relative) path of the CSV file that was written.</returns>
        public static string GetFlowersDataset()
        {
            const string datasetName = "flowers";
            string assetsPath = GetAbsolutePath("assets");
            string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs", "images");

            // Download the image set (skipped if already present) and unzip it.
            string finalImagesFolderName = DownloadImageSet(imagesDownloadFolderPath);
            string fullImagesetFolderPath = Path.Combine(imagesDownloadFolderPath, finalImagesFolderName);

            var images = LoadImagesFromDirectory(folder: fullImagesetFolderPath);

            // NOTE(review): the header lists Label first, but ImageData declares ImagePath as
            // [LoadColumn(0)] and Label as [LoadColumn(1)]. Confirm that callers do not load this
            // file through the ImageData schema.
            using (var writer = new StreamWriter(datasetName))
            {
                writer.WriteLine("Label,ImagePath");
                foreach (var image in images)
                    writer.WriteLine($"{image.Label},{image.ImagePath}");
            }

            return datasetName;
        }

        /// <summary>
        /// Enumerates image files (.jpg, .jpeg, .png, .gif, any casing) under <paramref name="folder"/>,
        /// recursively. Each image is labeled with the name of its immediate parent directory.
        /// </summary>
        /// <remarks>Iterator method: the directory is not read until enumeration starts.</remarks>
        public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder)
        {
            var files = Directory.GetFiles(folder, "*", searchOption: SearchOption.AllDirectories);

            // Linux can produce files in a different order than other OSes. Since this is test code
            // and accuracy must stay consistent across OSes, sort to remove the discrepancy.
            // NOTE(review): Array.Sort(string[]) uses the culture-sensitive default string comparer.
            // StringComparer.Ordinal would be culture-independent but would reorder some paths
            // (and thus training order), so the comparer is intentionally left as-is.
            Array.Sort(files);

            foreach (var file in files)
            {
                // OrdinalIgnoreCase avoids culture-specific lowercasing (e.g. under tr-TR ".GIF"
                // lowercases to ".gıf", which would wrongly skip the file).
                if (!_imageExtensions.Contains(Path.GetExtension(file)))
                    continue;

                yield return new ImageData
                {
                    ImagePath = file,
                    Label = Directory.GetParent(file).Name,
                };
            }
        }

        /// <summary>
        /// Downloads the tiny flowers zip into <paramref name="imagesDownloadFolder"/> (if not already
        /// there) and extracts it into the same folder (if not already extracted).
        /// </summary>
        /// <returns>
        /// The local archive name without extension. This presumably matches the top-level folder
        /// inside the archive; confirm against the archive contents.
        /// </returns>
        public static string DownloadImageSet(string imagesDownloadFolder)
        {
            // The local file name ("..._tests.zip") differs from the remote one ("..._test.zip");
            // only the local name is used for caching and for the returned folder name.
            const string fileName = "flower_photos_tiny_set_for_unit_tests.zip";
            const string url = "https://aka.ms/mlnet-resources/datasets/flower_photos_tiny_set_for_unit_test.zip";

            // Block synchronously for this synchronous helper. GetAwaiter().GetResult() rethrows the
            // original exception instead of wrapping it in an AggregateException as Wait() does.
            DownloadAsync(url, imagesDownloadFolder, fileName).GetAwaiter().GetResult();
            UnZip(Path.Combine(imagesDownloadFolder, fileName), imagesDownloadFolder);

            return Path.GetFileNameWithoutExtension(fileName);
        }

        /// <summary>
        /// Downloads <paramref name="url"/> to <paramref name="destDir"/>/<paramref name="destFileName"/>
        /// unless that file already exists.
        /// </summary>
        /// <param name="url">Source URL.</param>
        /// <param name="destDir">Destination directory; created if missing.</param>
        /// <param name="destFileName">Local file name. When null, the last segment of the URL path is used.</param>
        /// <remarks>
        /// The payload is written to a uniquely named temporary file and moved into place only after the
        /// copy completes. An interrupted download therefore cannot leave a truncated file that later
        /// runs would treat as already downloaded.
        /// </remarks>
        private static async Task DownloadAsync(string url, string destDir, string destFileName)
        {
            if (destFileName == null)
                destFileName = Path.GetFileName(new Uri(url).AbsolutePath);

            Directory.CreateDirectory(destDir);

            string destFilePath = Path.Combine(destDir, destFileName);
            if (File.Exists(destFilePath))
                return;

            string tempFilePath = destFilePath + "." + Guid.NewGuid().ToString("N") + ".tmp";
            try
            {
                using (var client = new HttpClient())
                using (var response = await client.GetAsync(url).ConfigureAwait(false))
                {
                    response.EnsureSuccessStatusCode();
                    using (var contentStream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false))
                    using (var fileStream = File.Create(tempFilePath))
                    {
                        await contentStream.CopyToAsync(fileStream).ConfigureAwait(false);
                    }
                }

                try
                {
                    File.Move(tempFilePath, destFilePath);
                }
                catch (IOException) when (File.Exists(destFilePath))
                {
                    // A concurrent caller finished the same download first; keep its copy.
                }
            }
            finally
            {
                // Remove a partial or redundant temp file; nothing to do if it was moved into place.
                if (File.Exists(tempFilePath))
                    File.Delete(tempFilePath);
            }
        }

        /// <summary>
        /// Extracts the zip archive at <paramref name="zipArchivePath"/> into <paramref name="destFolder"/>,
        /// at most once.
        /// </summary>
        /// <remarks>
        /// Completion is recorded by an empty marker file in <paramref name="destFolder"/>. Its name is
        /// the archive file name up to the first '.' plus ".bin". If the marker exists, extraction is
        /// skipped.
        /// </remarks>
        private static void UnZip(string zipArchivePath, string destFolder)
        {
            string markerFileName = Path.GetFileName(zipArchivePath).Split('.')[0] + ".bin";
            string markerPath = Path.Combine(destFolder, markerFileName);

            if (File.Exists(markerPath))
                return;

            ZipFile.ExtractToDirectory(zipArchivePath, destFolder);

            // Dispose the FileStream returned by File.Create right away. Leaking it keeps the marker
            // open (and locked on Windows) until the finalizer runs.
            File.Create(markerPath).Dispose();
        }

        /// <summary>Resolves <paramref name="relativePath"/> against the directory containing this test assembly.</summary>
        public static string GetAbsolutePath(string relativePath) =>
            Path.Combine(new FileInfo(typeof(DatasetUtil).Assembly.Location).Directory.FullName, relativePath);

        /// <summary>An image file path and its label (the name of the image's parent directory).</summary>
        public class ImageData
        {
            [LoadColumn(0)]
            public string ImagePath;

            [LoadColumn(1)]
            public string Label;
        }
    }
}