File: UserInputValidationTests.cs
Web Access
Project: src\test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj (Microsoft.ML.AutoML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.AutoML.Test
{
 
    public class UserInputValidationTests : BaseTestClass
    {
        private static readonly IDataView _data = DatasetUtil.GetUciAdultDataView();
 
        public UserInputValidationTests(ITestOutputHelper output) : base(output)
        {
        }
 
        [Fact]
        public void ValidateExperimentExecuteNullTrainData()
        {
            var ex = Assert.Throws<ArgumentNullException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null, TaskKind.Regression));
            Assert.StartsWith("Training data cannot be null", ex.Message);
        }
 
        [Fact]
        public void ValidateExperimentExecuteNullLabel()
        {
            var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
                new ColumnInformation() { LabelColumnName = null }, null, TaskKind.Regression));
 
            Assert.Equal("Provided label column cannot be null", ex.Message);
        }
 
        [Fact]
        public void ValidateExperimentExecuteLabelNotInTrain()
        {
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                const string columnName = "ReallyLongNonExistingColumnName";
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
                new ColumnInformation() { LabelColumnName = columnName }, null, task));
 
                Assert.Equal($"Provided label column '{columnName}' not found in training data.", ex.Message);
            }
        }
 
        [Fact]
        public void ValidateExperimentExecuteLabelNotInTrainMistyped()
        {
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var originalColumnName = _data.Schema.First().Name;
                var mistypedColumnName = originalColumnName + "a";
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
                    new ColumnInformation() { LabelColumnName = mistypedColumnName }, null, task));
 
                Assert.Equal($"Provided label column '{mistypedColumnName}' not found in training data. Did you mean '{originalColumnName}'.",
                    ex.Message);
            }
        }
 
        [Fact]
        public void ValidateExperimentExecuteNumericColNotInTrain()
        {
            var columnInfo = new ColumnInformation();
            columnInfo.NumericColumnNames.Add("N");
 
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data, columnInfo, null, task));
                Assert.Equal("Provided label column 'Label' was of type Boolean, but only type Single is allowed.", ex.Message);
            }
        }
 
        [Fact]
        public void ValidateExperimentExecuteNullNumericCol()
        {
            var columnInfo = new ColumnInformation();
            columnInfo.NumericColumnNames.Add(null);
 
            var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data, columnInfo, null, TaskKind.Regression));
            Assert.Equal("Null column string was specified as numeric in column information", ex.Message);
        }
 
        [Fact]
        public void ValidateExperimentExecuteDuplicateCol()
        {
            var columnInfo = new ColumnInformation();
            columnInfo.NumericColumnNames.Add(DefaultColumnNames.Label);
 
            var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data, columnInfo, null, TaskKind.Regression));
        }
 
        [Fact]
        public void ValidateExperimentExecuteArgsTrainValidColCountMismatch()
        {
            var context = new MLContext(1);
 
            var trainDataBuilder = new ArrayDataViewBuilder(context);
            trainDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
            trainDataBuilder.AddColumn("1", new string[] { "1" });
            var trainData = trainDataBuilder.GetDataView();
 
            var validDataBuilder = new ArrayDataViewBuilder(context);
            validDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
            var validData = validDataBuilder.GetDataView();
 
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
                    new ColumnInformation() { LabelColumnName = "0" }, validData, task));
                Assert.StartsWith("Training data and validation data schemas do not match. Train data has '2' columns,and validation data has '1' columns.", ex.Message);
            }
        }
 
        [Fact]
        public void ValidateExperimentExecuteArgsTrainValidColNamesMismatch()
        {
            var context = new MLContext(1);
 
            var trainDataBuilder = new ArrayDataViewBuilder(context);
            trainDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
            trainDataBuilder.AddColumn("1", new string[] { "1" });
            var trainData = trainDataBuilder.GetDataView();
 
            var validDataBuilder = new ArrayDataViewBuilder(context);
            validDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
            validDataBuilder.AddColumn("2", new string[] { "2" });
            var validData = validDataBuilder.GetDataView();
 
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
                    new ColumnInformation() { LabelColumnName = "0" }, validData, task));
                Assert.StartsWith("Training data and validation data schemas do not match. Column '1' exists in train data, but not in validation data.", ex.Message);
            }
        }
 
        [Fact]
        public void ValidateExperimentExecuteArgsTrainValidColTypeMismatch()
        {
            var context = new MLContext(1);
 
            var trainDataBuilder = new ArrayDataViewBuilder(context);
            trainDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
            trainDataBuilder.AddColumn("1", new string[] { "1" });
            var trainData = trainDataBuilder.GetDataView();
 
            var validDataBuilder = new ArrayDataViewBuilder(context);
            validDataBuilder.AddColumn("0", NumberDataViewType.Single, new float[] { 1 });
            validDataBuilder.AddColumn("1", NumberDataViewType.Single, new float[] { 1 });
            var validData = validDataBuilder.GetDataView();
 
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
                    new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression));
                Assert.StartsWith("Training data and validation data schemas do not match. Column '1' is of type String in train data, and type Single in validation data.", ex.Message);
            }
        }
 
        [Fact]
        public void ValidateInferColumnsArgsNullPath()
        {
            var ex = Assert.Throws<ArgumentNullException>(() => UserInputValidationUtil.ValidateInferColumnsArgs(null, "Label"));
            Assert.StartsWith("Provided path cannot be null", ex.Message);
        }
 
        [Fact]
        public void ValidateInferColumnsArgsPathNotExist()
        {
            var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateInferColumnsArgs("idontexist", "Label"));
            Assert.StartsWith("File 'idontexist' does not exist", ex.Message);
        }
 
        [Fact]
        public void ValidateInferColumnsArgsEmptyFile()
        {
            const string emptyFilePath = "empty";
            File.Create(emptyFilePath).Dispose();
            var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateInferColumnsArgs(emptyFilePath, "Label"));
            Assert.StartsWith("File at path 'empty' cannot be empty", ex.Message);
        }
 
        [Fact]
        public void ValidateInferColsPath()
        {
            UserInputValidationUtil.ValidateInferColumnsArgs(DatasetUtil.GetUciAdultDataset());
        }
 
        [Fact]
        public void ValidateFeaturesColInvalidType()
        {
            var schemaBuilder = new DataViewSchema.Builder();
            schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Double);
            schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
            var schema = schemaBuilder.ToSchema();
            var dataView = DataViewTestFixture.BuildDummyDataView(schema);
 
            var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, TaskKind.Regression));
            Assert.StartsWith("Features column must be of data type Single", ex.Message);
        }
 
        [Fact]
        public void ValidateTextColumnNotText()
        {
            const string textPurposeColName = "TextColumn";
            var schemaBuilder = new DataViewSchema.Builder();
            schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
            schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
            schemaBuilder.AddColumn(textPurposeColName, NumberDataViewType.Single);
            var schema = schemaBuilder.ToSchema();
            var dataView = DataViewTestFixture.BuildDummyDataView(schema);
 
            var columnInfo = new ColumnInformation();
            columnInfo.TextColumnNames.Add(textPurposeColName);
 
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null, task));
                Assert.Equal("Provided text column 'TextColumn' was of type Single, but only type String is allowed.", ex.Message);
            }
        }
 
        [Fact]
        public void ValidateRegressionLabelTypes()
        {
            ValidateLabelTypeTestCore<float>(TaskKind.Regression, NumberDataViewType.Single, true);
            ValidateLabelTypeTestCore<bool>(TaskKind.Regression, BooleanDataViewType.Instance, false);
            ValidateLabelTypeTestCore<double>(TaskKind.Regression, NumberDataViewType.Double, false);
            ValidateLabelTypeTestCore<string>(TaskKind.Regression, TextDataViewType.Instance, false);
        }
 
        [Fact]
        public void ValidateRecommendationLabelTypes()
        {
            ValidateLabelTypeTestCore<float>(TaskKind.Recommendation, NumberDataViewType.Single, true);
            ValidateLabelTypeTestCore<bool>(TaskKind.Recommendation, BooleanDataViewType.Instance, false);
            ValidateLabelTypeTestCore<double>(TaskKind.Recommendation, NumberDataViewType.Double, false);
            ValidateLabelTypeTestCore<string>(TaskKind.Recommendation, TextDataViewType.Instance, false);
        }
 
        [Fact]
        public void ValidateBinaryClassificationLabelTypes()
        {
            ValidateLabelTypeTestCore<float>(TaskKind.BinaryClassification, NumberDataViewType.Single, false);
            ValidateLabelTypeTestCore<bool>(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true);
        }
 
        [Fact]
        public void ValidateMulticlassLabelTypes()
        {
            ValidateLabelTypeTestCore<float>(TaskKind.MulticlassClassification, NumberDataViewType.Single, true);
            ValidateLabelTypeTestCore<bool>(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true);
            ValidateLabelTypeTestCore<double>(TaskKind.MulticlassClassification, NumberDataViewType.Double, true);
            ValidateLabelTypeTestCore<string>(TaskKind.MulticlassClassification, TextDataViewType.Instance, true);
        }
 
        [Fact]
        public void ValidateRankingLabelTypes()
        {
            ValidateLabelTypeTestCore<float>(TaskKind.Ranking, NumberDataViewType.Single, true);
            ValidateLabelTypeTestCore<bool>(TaskKind.Ranking, BooleanDataViewType.Instance, false);
            ValidateLabelTypeTestCore<double>(TaskKind.Ranking, NumberDataViewType.Double, false);
            ValidateLabelTypeTestCore<string>(TaskKind.Ranking, TextDataViewType.Instance, false);
        }
 
        [Fact]
        public void ValidateAllowedFeatureColumnTypes()
        {
            var dataViewBuilder = new ArrayDataViewBuilder(new MLContext(1));
            dataViewBuilder.AddColumn("Boolean", BooleanDataViewType.Instance, false);
            dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
            dataViewBuilder.AddColumn("Text", "a");
            dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
            var dataView = dataViewBuilder.GetDataView();
 
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
                    null, task);
            }
        }
 
        [Fact]
        public void ValidateProhibitedFeatureColumnType()
        {
            var schemaBuilder = new DataViewSchema.Builder();
            schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64);
            schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
            var schema = schemaBuilder.ToSchema();
            var dataView = DataViewTestFixture.BuildDummyDataView(schema);
 
            var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
                null, TaskKind.Regression));
            Assert.StartsWith("Only supported feature column types are Boolean, Single, and String. Please change the feature column UInt64 of type UInt64 to one of the supported types.", ex.Message);
        }
 
        [Fact]
        public void ValidateEmptyTrainingDataThrows()
        {
            var schemaBuilder = new DataViewSchema.Builder();
            schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
            schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
            var schema = schemaBuilder.ToSchema();
            var dataView = DataViewTestFixture.BuildDummyDataView(schema, createDummyRow: false);
            var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
                null, TaskKind.Regression));
            Assert.StartsWith("Training data has 0 rows", ex.Message);
        }
 
        [Fact]
        public void ValidateEmptyValidationDataThrows()
        {
            // Training data
            var dataViewBuilder = new ArrayDataViewBuilder(new MLContext(1));
            dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
            dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
            var trainingData = dataViewBuilder.GetDataView();
 
            // Validation data
            var schemaBuilder = new DataViewSchema.Builder();
            schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
            schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
            var schema = schemaBuilder.ToSchema();
            var validationData = DataViewTestFixture.BuildDummyDataView(schema, createDummyRow: false);
 
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(trainingData, new ColumnInformation(),
                    validationData, task));
                Assert.StartsWith("Validation data has 0 rows", ex.Message);
            }
        }
 
 
        [Fact]
        public void TestValidationDataSchemaChecksIgnoreHiddenColumns()
        {
            var mlContext = new MLContext(1);
 
            // Build training data where label column is a float.
            var trainDataBuilder = new ArrayDataViewBuilder(mlContext);
            trainDataBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
            trainDataBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
            var trainingData = trainDataBuilder.GetDataView();
 
            // In the training data, transform the label column from a float to a Boolean. This has the effect of
            // creating a hidden column named 'Label' of type float and an additional column named 'Label' of type Boolean.
            var convertLabelToBoolEstimator = mlContext.Transforms.Conversion.MapValue(DefaultColumnNames.Label,
                new List<KeyValuePair<float, bool>>() { new KeyValuePair<float, bool>(1, true) });
            trainingData = convertLabelToBoolEstimator.Fit(trainingData).Transform(trainingData);
 
            // Build validation data where label column is a Boolean.
            var validationDataBuilder = new ArrayDataViewBuilder(mlContext);
            validationDataBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
            validationDataBuilder.AddColumn(DefaultColumnNames.Label, BooleanDataViewType.Instance, false);
            var validationData = validationDataBuilder.GetDataView();
 
            UserInputValidationUtil.ValidateExperimentExecuteArgs(trainingData, new ColumnInformation(), validationData, TaskKind.BinaryClassification);
        }
 
 
        private static void ValidateLabelTypeTestCore<TLabelRawType>(TaskKind task, PrimitiveDataViewType labelType, bool labelTypeShouldBeValid)
        {
            var dataViewBuilder = new ArrayDataViewBuilder(new MLContext(1));
            dataViewBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, 0f);
            if (labelType == TextDataViewType.Instance)
            {
                dataViewBuilder.AddColumn(DefaultColumnNames.Label, string.Empty);
            }
            else
            {
                dataViewBuilder.AddColumn(DefaultColumnNames.Label, labelType, Activator.CreateInstance<TLabelRawType>());
            }
            var dataView = dataViewBuilder.GetDataView();
            var validationExceptionThrown = false;
            try
            {
                UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, task);
            }
            catch
            {
                validationExceptionThrown = true;
            }
            Assert.Equal(labelTypeShouldBeValid, !validationExceptionThrown);
        }
        [Fact]
        public void ValidateTrainDataColumnTestMultipleMismatchLessThan5()
        {
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var originalColumnName = _data.Schema.First().Name;
                var mistypedColumnName = "a" + originalColumnName + "b";
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
                    new ColumnInformation() { LabelColumnName = mistypedColumnName }, null, task));
 
                Assert.Equal($"Provided label column '{mistypedColumnName}' not found in training data. Did you mean '{originalColumnName}'.",
                    ex.Message);
            }
        }
 
        [Fact]
        public void ValidateTrainDataColumnTestMultipleMismatchMoreThan5()
        {
            foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking })
            {
                var originalColumnName = _data.Schema.First().Name;
                var mistypedColumnName = "a" + originalColumnName + "bcdvfnfmsm";
                var ex = Assert.Throws<ArgumentException>(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data,
                    new ColumnInformation() { LabelColumnName = mistypedColumnName }, null, task));
 
                Assert.Equal($"Provided label column '{mistypedColumnName}' not found in training data.",
                    ex.Message);
            }
        }
 
    }
}