File: ColumnInferenceTests.cs
Web Access
Project: src\test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj (Microsoft.ML.AutoML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using ApprovalTests;
using ApprovalTests.Namers;
using ApprovalTests.Reporters;
using FluentAssertions;
using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Newtonsoft.Json;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.AutoML.Test
{
 
    /// <summary>
    /// Tests for AutoML column inference: label resolution by name and by index,
    /// data kind detection, column grouping, handling of quoted multi-line fields,
    /// and JSON round-tripping of <see cref="ColumnInferenceResults"/>.
    /// </summary>
    public class ColumnInferenceTests : BaseTestClass
    {
        public ColumnInferenceTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// Without grouping every inferred column maps to exactly one source slot, and
        /// grouping yields strictly fewer loader columns than not grouping.
        /// </summary>
        [Fact]
        public void UnGroupReturnsMoreColumnsThanGroup()
        {
            var dataPath = DatasetUtil.GetUciAdultDataset();
            var context = new MLContext(1);

            var ungrouped = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel, groupColumns: false);
            foreach (var column in ungrouped.TextLoaderOptions.Columns)
            {
                // An ungrouped column has a single source range whose Min and Max coincide.
                Assert.False(column.Source.Length > 1 || column.Source[0].Min != column.Source[0].Max);
            }

            var grouped = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel, groupColumns: true);
            Assert.True(grouped.TextLoaderOptions.Columns.Count() < ungrouped.TextLoaderOptions.Columns.Count());
        }

        /// <summary>
        /// A label column name that is not present in the dataset is rejected with
        /// <see cref="ArgumentException"/>.
        /// </summary>
        [Fact]
        public void IncorrectLabelColumnThrows()
        {
            var dataPath = DatasetUtil.GetUciAdultDataset();
            var context = new MLContext(1);

            Assert.Throws<ArgumentException>(() => context.Auto().InferColumns(dataPath, "Junk", groupColumns: false));
        }

        /// <summary>
        /// A label column index beyond the dataset's columns is rejected with
        /// <see cref="ArgumentOutOfRangeException"/>.
        /// </summary>
        [Fact]
        public void LabelIndexOutOfBoundsThrows()
        {
            var dataPath = DatasetUtil.GetUciAdultDataset();
            var context = new MLContext(1);

            Assert.Throws<ArgumentOutOfRangeException>(() => context.Auto().InferColumns(dataPath, 100));
        }

        /// <summary>
        /// When the label is given by index and the file has a header, the label column
        /// takes its name from the header.
        /// </summary>
        [Fact]
        public void IdentifyLabelColumnThroughIndexWithHeader()
        {
            const int labelIndex = 14;
            const string expectedLabelName = "hours-per-week";

            var result = new MLContext(1).Auto().InferColumns(DatasetUtil.GetUciAdultDataset(), labelIndex, hasHeader: true);

            Assert.True(result.TextLoaderOptions.HasHeader);
            var labelCol = result.TextLoaderOptions.Columns.First(c => c.Source[0].Min == labelIndex && c.Source[0].Max == labelIndex);
            Assert.Equal(expectedLabelName, labelCol.Name);
            Assert.Equal(expectedLabelName, result.ColumnInformation.LabelColumnName);
        }

        /// <summary>
        /// When the label is given by index and the file has no header, the label column
        /// is named <see cref="DefaultColumnNames.Label"/>.
        /// </summary>
        [Fact]
        public void IdentifyLabelColumnThroughIndexWithoutHeader()
        {
            var result = new MLContext(1).Auto().InferColumns(DatasetUtil.GetIrisDataset(), DatasetUtil.IrisDatasetLabelColIndex);

            Assert.False(result.TextLoaderOptions.HasHeader);
            var labelCol = result.TextLoaderOptions.Columns.First(c =>
                c.Source[0].Min == DatasetUtil.IrisDatasetLabelColIndex &&
                c.Source[0].Max == DatasetUtil.IrisDatasetLabelColIndex);
            Assert.Equal(DefaultColumnNames.Label, labelCol.Name);
            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
        }

        /// <summary>
        /// The column named "Empty" in the test dataset is inferred as <see cref="DataKind.Single"/>.
        /// </summary>
        [Fact]
        public void DatasetWithEmptyColumn()
        {
            var dataPath = Path.Combine("TestData", "DatasetWithEmptyColumn.txt");

            var result = new MLContext(1).Auto().InferColumns(dataPath, DefaultColumnNames.Label, groupColumns: false);

            var emptyColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Empty");
            Assert.Equal(DataKind.Single, emptyColumn.DataKind);
        }

        /// <summary>
        /// A Boolean feature column is loaded as <see cref="DataKind.Single"/> and reported as
        /// numeric, while a Boolean label column keeps <see cref="DataKind.Boolean"/>.
        /// </summary>
        [Fact]
        public void DatasetWithBoolColumn()
        {
            var dataPath = Path.Combine("TestData", "BinaryDatasetWithBoolColumn.txt");

            var result = new MLContext(1).Auto().InferColumns(dataPath, DefaultColumnNames.Label);
            Assert.Equal(2, result.TextLoaderOptions.Columns.Count());

            var boolColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Bool");
            var labelColumn = result.TextLoaderOptions.Columns.First(c => c.Name == DefaultColumnNames.Label);

            // The non-label Boolean column is detected as R4 (Single); the label stays Boolean.
            Assert.Equal(DataKind.Single, boolColumn.DataKind);
            Assert.Equal(DataKind.Boolean, labelColumn.DataKind);

            // The column information lists the non-label Boolean column as the only numeric column.
            Assert.Single(result.ColumnInformation.NumericColumnNames);
            Assert.Equal("Bool", result.ColumnInformation.NumericColumnNames.First());
            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
        }

        /// <summary>
        /// A header-less file gets generated "colN" names; columns referenced through the
        /// supplied <see cref="ColumnInformation"/> (label, user id, item id, ignored) are
        /// still present in the loader options with the expected data kinds.
        /// </summary>
        [Fact]
        public void InferDatasetWithoutHeader()
        {
            var context = new MLContext(1);
            var filePath = Path.Combine("TestData", "DatasetWithoutHeader.txt");
            var columnInfo = new ColumnInformation()
            {
                LabelColumnName = "col0",
                UserIdColumnName = "col1",
                ItemIdColumnName = "col2",
            };
            columnInfo.IgnoredColumnNames.Add("col4");

            // NOTE(review): the trailing positional arguments presumably are separator, allowQuotedStrings,
            // supportSparse, trimWhitespace, groupColumns, hasHeader - confirm against ColumnInferenceApi.
            var result = ColumnInferenceApi.InferColumns(context, filePath, columnInfo, ',', null, null, false, false, false);
            Assert.Equal(6, result.TextLoaderOptions.Columns.Count());

            var labelColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "col0");
            var userColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "col1");
            var itemColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "col2");
            var ignoreColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "col4");

            Assert.Equal(DataKind.String, labelColumn.DataKind);
            Assert.Equal(DataKind.Single, userColumn.DataKind);
            Assert.Equal(DataKind.Single, itemColumn.DataKind);
            Assert.Equal(DataKind.Single, ignoreColumn.DataKind);

            Assert.Single(result.ColumnInformation.CategoricalColumnNames);
            Assert.Empty(result.ColumnInformation.TextColumnNames);
        }

        /// <summary>
        /// When the only feature is the "Username" column, it is inferred as a String text column
        /// and the label as Boolean.
        /// </summary>
        [Fact]
        public void WhereNameColumnIsOnlyFeature()
        {
            var dataPath = Path.Combine("TestData", "NameColumnIsOnlyFeatureDataset.txt");

            var result = new MLContext(1).Auto().InferColumns(dataPath, DefaultColumnNames.Label);
            Assert.Equal(2, result.TextLoaderOptions.Columns.Count());

            var nameColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Username");
            var labelColumn = result.TextLoaderOptions.Columns.First(c => c.Name == DefaultColumnNames.Label);
            Assert.Equal(DataKind.String, nameColumn.DataKind);
            Assert.Equal(DataKind.Boolean, labelColumn.DataKind);

            Assert.Single(result.ColumnInformation.TextColumnNames);
            Assert.Equal("Username", result.ColumnInformation.TextColumnNames.First());
            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
        }

        /// <summary>
        /// With <c>groupColumns: false</c>, the label, weight, user and item column names supplied by
        /// the caller are echoed back and three numeric columns are reported.
        /// </summary>
        [Fact]
        public void DefaultColumnNamesInferredCorrectly()
        {
            var dataPath = Path.Combine("TestData", "DatasetWithDefaultColumnNames.txt");
            var columnInfo = new ColumnInformation()
            {
                LabelColumnName = DefaultColumnNames.Label,
                ExampleWeightColumnName = DefaultColumnNames.Weight,
                UserIdColumnName = DefaultColumnNames.User,
                ItemIdColumnName = DefaultColumnNames.Item,
            };

            var result = new MLContext(1).Auto().InferColumns(dataPath, columnInfo, groupColumns: false);

            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
            Assert.Equal(DefaultColumnNames.Weight, result.ColumnInformation.ExampleWeightColumnName);
            Assert.Equal(DefaultColumnNames.User, result.ColumnInformation.UserIdColumnName);
            Assert.Equal(DefaultColumnNames.Item, result.ColumnInformation.ItemIdColumnName);
            Assert.Equal(3, result.ColumnInformation.NumericColumnNames.Count());
        }

        /// <summary>
        /// Using the default value of <c>groupColumns</c> (the argument is omitted here), the numeric
        /// columns are reported as a single column named <see cref="DefaultColumnNames.Features"/>.
        /// </summary>
        [Fact]
        public void DefaultColumnNamesNoGrouping()
        {
            var dataPath = Path.Combine("TestData", "DatasetWithDefaultColumnNames.txt");
            var columnInfo = new ColumnInformation()
            {
                LabelColumnName = DefaultColumnNames.Label,
                ExampleWeightColumnName = DefaultColumnNames.Weight,
            };

            var result = new MLContext(1).Auto().InferColumns(dataPath, columnInfo);

            Assert.Equal(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName);
            Assert.Equal(DefaultColumnNames.Weight, result.ColumnInformation.ExampleWeightColumnName);
            Assert.Single(result.ColumnInformation.NumericColumnNames);
            Assert.Equal(DefaultColumnNames.Features, result.ColumnInformation.NumericColumnNames.First());
        }

        /// <summary>
        /// The <see cref="ColumnInformation"/> overload of InferColumns honours the supplied label name,
        /// infers the label as Single, and leaves the example weight column unset.
        /// </summary>
        [Fact]
        public void InferColumnsColumnInfoParam()
        {
            var columnInfo = new ColumnInformation() { LabelColumnName = DatasetUtil.MlNetGeneratedRegressionLabel };

            var result = new MLContext(1).Auto().InferColumns(DatasetUtil.GetMlNetGeneratedRegressionDataset(), columnInfo);

            var labelCol = result.TextLoaderOptions.Columns.First(c => c.Name == DatasetUtil.MlNetGeneratedRegressionLabel);
            Assert.Equal(DataKind.Single, labelCol.DataKind);
            Assert.Equal(DatasetUtil.MlNetGeneratedRegressionLabel, result.ColumnInformation.LabelColumnName);
            Assert.Single(result.ColumnInformation.NumericColumnNames);
            Assert.Equal(DefaultColumnNames.Features, result.ColumnInformation.NumericColumnNames.First());
            Assert.Null(result.ColumnInformation.ExampleWeightColumnName);
        }

        /// <summary>
        /// Splitting succeeds on a dataset containing a newline inside double quotes, detecting
        /// four comma-separated columns.
        /// </summary>
        [Fact]
        public void TrySplitColumns_should_split_on_dataset_with_newline_between_double_quotes()
        {
            // Seeded like the other tests in this class so runs are reproducible.
            var context = new MLContext(1);
            var dataset = Path.Combine("TestData", "DatasetWithNewlineBetweenQuotes.txt");
            var sample = TextFileSample.CreateFromFullFile(dataset);

            var result = TextFileContents.TrySplitColumns(context, sample, TextFileContents.DefaultSeparators);

            // Check success first so a failed split is reported as such rather than as a wrong count.
            result.IsSuccess.Should().BeTrue();
            result.ColumnCount.Should().Be(4);
            result.Separator.Should().Be(',');
        }

        /// <summary>
        /// Column information can be inferred from an input file that has
        /// escaped newlines inside quotes.
        /// </summary>
        [Fact]
        public void InferColumnsFromMultilineInputFile()
        {
            var dataPath = GetDataPath("multiline.csv");
            // Seeded like the other tests in this class so runs are reproducible.
            var mlContext = new MLContext(1);
            var inputColumnInformation = new ColumnInformation() { LabelColumnName = "id" };

            var result = mlContext.Auto().InferColumns(dataPath, inputColumnInformation);

            // File has 3 columns: "id", "description" and "animal"
            Assert.NotNull(result.ColumnInformation.LabelColumnName);
            Assert.Single(result.ColumnInformation.TextColumnNames);
            Assert.Single(result.ColumnInformation.CategoricalColumnNames);

            Assert.Equal("id", result.ColumnInformation.LabelColumnName);
            Assert.Equal("description", result.ColumnInformation.TextColumnNames.First());
            Assert.Equal("animal", result.ColumnInformation.CategoricalColumnNames.First());
        }

        /// <summary>
        /// Deserializing a stored <see cref="ColumnInferenceResults"/> JSON document and serializing it
        /// again produces output matching the approved file.
        /// </summary>
        [Fact]
        [UseReporter(typeof(DiffReporter))]
        [UseApprovalSubdirectory("ApprovalTests")]
        public void Wiki_column_inference_result_should_be_serializable()
        {
            // DiffEngine can't check for Helix, so the environment variable checks for helix.
            if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null)
            {
                Approvals.UseAssemblyLocationForApprovedFiles();
            }

            var wiki = Path.Combine("TestData", "wiki-column-inference.json");
            var json = File.ReadAllText(wiki);

            var columnInferenceResults = JsonConvert.DeserializeObject<ColumnInferenceResults>(json);
            Approvals.Verify(JsonConvert.SerializeObject(columnInferenceResults, Formatting.Indented));
        }
    }
}