File: Utils\UserInputValidationUtil.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.AutoML.Utils;
using Microsoft.ML.Data;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// Argument validation helpers for the AutoML entry points (experiment execution and column inference).
    /// Every method returns normally when its input is valid and throws an <see cref="ArgumentException"/>
    /// (or a subclass such as <see cref="ArgumentNullException"/>) otherwise.
    /// </summary>
    internal static class UserInputValidationUtil
    {
        // column purpose names; these are only interpolated into error messages
        private const string LabelColumnPurposeName = "label";
        private const string WeightColumnPurposeName = "weight";
        private const string NumericColumnPurposeName = "numeric";
        private const string CategoricalColumnPurposeName = "categorical";
        private const string TextColumnPurposeName = "text";
        private const string IgnoredColumnPurposeName = "ignored";
        private const string SamplingKeyColumnPurposeName = "sampling key";
        private const string UserIdColumnPurposeName = "user ID";
        private const string ItemIdColumnPurposeName = "item ID";
        private const string GroupIdColumnPurposeName = "group ID";

        // Largest Levenshtein distance at which an existing schema column is offered as a
        // "Did you mean ...?" suggestion for a user-specified column that was not found.
        private const int MaxColumnNameSuggestionEditDistance = 7;

        /// <summary>
        /// Validates the arguments of an experiment execution: the training data, the column information
        /// against the training schema (including label types allowed for <paramref name="task"/>), and the
        /// optional validation data against the training schema.
        /// </summary>
        /// <param name="trainData">Training data; must be non-null and non-empty.</param>
        /// <param name="columnInformation">Column roles; must be non-null.</param>
        /// <param name="validationData">Optional validation data; skipped when null.</param>
        /// <param name="task">Task kind, used to select the allowed label column types.</param>
        /// <exception cref="ArgumentException">Any of the arguments is invalid.</exception>
        public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
            IDataView validationData, TaskKind task)
        {
            ValidateTrainData(trainData, columnInformation);
            ValidateColumnInformation(trainData, columnInformation, task);
            ValidateValidationData(trainData, validationData);
        }

        /// <summary>
        /// Validates column-inference arguments given a full <see cref="ColumnInformation"/>.
        /// </summary>
        /// <exception cref="ArgumentException">The column information is invalid or the file is missing/empty.</exception>
        public static void ValidateInferColumnsArgs(string path, ColumnInformation columnInformation)
        {
            ValidateColumnInformation(columnInformation);
            ValidatePath(path);
        }

        /// <summary>
        /// Validates column-inference arguments given only a label column name.
        /// </summary>
        /// <exception cref="ArgumentException">The label column is null or the file is missing/empty.</exception>
        public static void ValidateInferColumnsArgs(string path, string labelColumn)
        {
            ValidateLabelColumn(labelColumn);
            ValidatePath(path);
        }

        /// <summary>
        /// Validates column-inference arguments consisting of a file path only.
        /// </summary>
        /// <exception cref="ArgumentException">The path is null, or the file is missing or empty.</exception>
        public static void ValidateInferColumnsArgs(string path)
        {
            ValidatePath(path);
        }

        /// <summary>
        /// Ensures that cross-validation uses at least two folds.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="numberOfCVFolds"/> is 0 or 1.</exception>
        public static void ValidateNumberOfCVFoldsArg(uint numberOfCVFolds)
        {
            if (numberOfCVFolds <= 1)
            {
                throw new ArgumentException($"{nameof(numberOfCVFolds)} must be at least 2", nameof(numberOfCVFolds));
            }
        }

        /// <summary>
        /// For ranking tasks, a provided sampling key column must be the group ID column.
        /// Other task kinds, and a null sampling key, are not constrained.
        /// </summary>
        /// <exception cref="ArgumentException">Ranking task with a sampling key different from the group ID column.</exception>
        public static void ValidateSamplingKey(string samplingKeyColumnName, string groupIdColumnName, TaskKind task)
        {
            if (task == TaskKind.Ranking && samplingKeyColumnName != null && samplingKeyColumnName != groupIdColumnName)
            {
                // paramName must be the parameter's *name*, not the column name it holds.
                throw new ArgumentException($"If provided, {nameof(samplingKeyColumnName)} must be the same as {nameof(groupIdColumnName)} for Ranking Experiments", nameof(samplingKeyColumnName));
            }
        }

        /// <summary>
        /// Checks that the training data exists, has at least one row, and that every visible column that is
        /// not a label/user ID/item ID/group ID column has a supported feature item type
        /// (Boolean, Single or Text). A column named <see cref="DefaultColumnNames.Features"/> must be Single-typed.
        /// </summary>
        private static void ValidateTrainData(IDataView trainData, ColumnInformation columnInformation)
        {
            if (trainData == null)
            {
                throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
            }

            if (DatasetDimensionsUtil.IsDataViewEmpty(trainData))
            {
                throw new ArgumentException("Training data has 0 rows", nameof(trainData));
            }

            // Checked after the data so that data errors keep taking precedence, but before the loop below
            // dereferences it (previously this surfaced as a NullReferenceException).
            if (columnInformation == null)
            {
                throw new ArgumentNullException(nameof(columnInformation), "Column information cannot be null");
            }

            foreach (var column in trainData.Schema)
            {
                // A hidden column is shadowed by a later column with the same name and cannot be referenced by
                // the user (ValidateValidationData skips them for the same reason). Without this, an in-place
                // type conversion would be rejected because of the original, now-hidden column.
                if (column.IsHidden)
                {
                    continue;
                }

                var itemType = column.Type.GetItemType();

                if (column.Name == DefaultColumnNames.Features && itemType != NumberDataViewType.Single)
                {
                    throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type {NumberDataViewType.Single}", nameof(trainData));
                }

                // Label and ID columns have their own type rules (see ValidateColumnInformation).
                var isLabelOrIdColumn =
                    column.Name == columnInformation.LabelColumnName ||
                    column.Name == columnInformation.UserIdColumnName ||
                    column.Name == columnInformation.ItemIdColumnName ||
                    column.Name == columnInformation.GroupIdColumnName;

                if (!isLabelOrIdColumn &&
                    itemType != BooleanDataViewType.Instance &&
                    itemType != NumberDataViewType.Single &&
                    itemType != TextDataViewType.Instance)
                {
                    throw new ArgumentException($"Only supported feature column types are " +
                        $"{BooleanDataViewType.Instance}, {NumberDataViewType.Single}, and {TextDataViewType.Instance}. " +
                        $"Please change the feature column {column.Name} of type {column.Type} to one of " +
                        $"the supported types.", nameof(trainData));
                }
            }
        }

        /// <summary>
        /// Validates <paramref name="columnInformation"/> on its own, then checks that every column it names
        /// exists in <paramref name="trainData"/> with an item type allowed for that column's purpose.
        /// </summary>
        private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation, TaskKind task)
        {
            ValidateColumnInformation(columnInformation);
            ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName, GetAllowedLabelTypes(task));
            ValidateTrainDataColumn(trainData, columnInformation.ExampleWeightColumnName, WeightColumnPurposeName);
            ValidateTrainDataColumn(trainData, columnInformation.SamplingKeyColumnName, SamplingKeyColumnPurposeName);
            ValidateTrainDataColumn(trainData, columnInformation.UserIdColumnName, UserIdColumnPurposeName);
            ValidateTrainDataColumn(trainData, columnInformation.ItemIdColumnName, ItemIdColumnPurposeName);
            ValidateTrainDataColumn(trainData, columnInformation.GroupIdColumnName, GroupIdColumnPurposeName);
            ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumnNames, CategoricalColumnPurposeName,
                new DataViewType[] { NumberDataViewType.Single, TextDataViewType.Instance });
            ValidateTrainDataColumns(trainData, columnInformation.NumericColumnNames, NumericColumnPurposeName,
                new DataViewType[] { NumberDataViewType.Single, BooleanDataViewType.Instance });
            ValidateTrainDataColumns(trainData, columnInformation.TextColumnNames, TextColumnPurposeName,
                new DataViewType[] { TextDataViewType.Instance });
            ValidateTrainDataColumns(trainData, columnInformation.IgnoredColumnNames, IgnoredColumnPurposeName);
        }

        /// <summary>
        /// Schema-independent checks on <paramref name="columnInformation"/>: the label column is set, no
        /// column collection contains null, and no column name appears more than once across the label,
        /// weight, categorical, numeric, text and ignored properties.
        /// </summary>
        private static void ValidateColumnInformation(ColumnInformation columnInformation)
        {
            if (columnInformation == null)
            {
                throw new ArgumentNullException(nameof(columnInformation), "Column information cannot be null");
            }

            ValidateLabelColumn(columnInformation.LabelColumnName);

            ValidateColumnInfoEnumerationProperty(columnInformation.CategoricalColumnNames, CategoricalColumnPurposeName);
            ValidateColumnInfoEnumerationProperty(columnInformation.NumericColumnNames, NumericColumnPurposeName);
            ValidateColumnInfoEnumerationProperty(columnInformation.TextColumnNames, TextColumnPurposeName);
            ValidateColumnInfoEnumerationProperty(columnInformation.IgnoredColumnNames, IgnoredColumnPurposeName);

            // Collect every named column to detect duplicates. The ID / sampling-key columns are deliberately
            // not included: e.g. for ranking the sampling key is required to equal the group ID column.
            var allColumns = new List<string> { columnInformation.LabelColumnName };
            if (columnInformation.ExampleWeightColumnName != null) { allColumns.Add(columnInformation.ExampleWeightColumnName); }
            if (columnInformation.CategoricalColumnNames != null) { allColumns.AddRange(columnInformation.CategoricalColumnNames); }
            if (columnInformation.NumericColumnNames != null) { allColumns.AddRange(columnInformation.NumericColumnNames); }
            if (columnInformation.TextColumnNames != null) { allColumns.AddRange(columnInformation.TextColumnNames); }
            if (columnInformation.IgnoredColumnNames != null) { allColumns.AddRange(columnInformation.IgnoredColumnNames); }

            var duplicateColName = FindFirstDuplicate(allColumns);
            if (duplicateColName != null)
            {
                throw new ArgumentException($"Duplicate column name {duplicateColName} is present in two or more distinct properties of provided column information", nameof(columnInformation));
            }
        }

        /// <summary>
        /// Throws if the (possibly null) column collection contains a null entry.
        /// </summary>
        private static void ValidateColumnInfoEnumerationProperty(IEnumerable<string> columns, string columnPurpose)
        {
            if (columns?.Contains(null) == true)
            {
                throw new ArgumentException($"Null column string was specified as {columnPurpose} in column information");
            }
        }

        /// <summary>
        /// Throws if the label column name is null.
        /// </summary>
        private static void ValidateLabelColumn(string labelColumn)
        {
            if (labelColumn == null)
            {
                throw new ArgumentException("Provided label column cannot be null");
            }
        }

        /// <summary>
        /// Ensures <paramref name="path"/> is non-null and refers to an existing, non-empty file.
        /// </summary>
        private static void ValidatePath(string path)
        {
            if (path == null)
            {
                throw new ArgumentNullException(nameof(path), "Provided path cannot be null");
            }

            var fileInfo = new FileInfo(path);

            if (!fileInfo.Exists)
            {
                throw new ArgumentException($"File '{path}' does not exist", nameof(path));
            }

            if (fileInfo.Length == 0)
            {
                throw new ArgumentException($"File at path '{path}' cannot be empty", nameof(path));
            }
        }

        /// <summary>
        /// When validation data is supplied, checks that it has rows and that its visible (non-hidden) columns
        /// match the training data's visible columns one-to-one by name and type.
        /// </summary>
        private static void ValidateValidationData(IDataView trainData, IDataView validationData)
        {
            if (validationData == null)
            {
                return;
            }

            if (DatasetDimensionsUtil.IsDataViewEmpty(validationData))
            {
                throw new ArgumentException("Validation data has 0 rows", nameof(validationData));
            }

            const string schemaMismatchError = "Training data and validation data schemas do not match.";

            // Compare and report the same (visible-only) counts; Schema.Count would include hidden columns.
            var trainColumnCount = trainData.Schema.Count(c => !c.IsHidden);
            var validationColumnCount = validationData.Schema.Count(c => !c.IsHidden);
            if (trainColumnCount != validationColumnCount)
            {
                throw new ArgumentException($"{schemaMismatchError} Train data has '{trainColumnCount}' columns, " +
                    $"and validation data has '{validationColumnCount}' columns.", nameof(validationData));
            }

            // Validate that every active column in the train data corresponds to an active column in the validation data.
            // (Indirectly, since we asserted above that the train and validation data have the same number of active columns, this also
            // ensures the reverse -- that every active column in the validation data corresponds to an active column in the train data.)
            foreach (var trainCol in trainData.Schema)
            {
                if (trainCol.IsHidden)
                {
                    continue;
                }

                var validCol = validationData.Schema.GetColumnOrNull(trainCol.Name);
                if (validCol == null)
                {
                    throw new ArgumentException($"{schemaMismatchError} Column '{trainCol.Name}' exists in train data, but not in validation data.", nameof(validationData));
                }

                // Same instance short-circuits; otherwise fall back to the type's own equality.
                if (trainCol.Type != validCol.Value.Type && !trainCol.Type.Equals(validCol.Value.Type))
                {
                    throw new ArgumentException($"{schemaMismatchError} Column '{trainCol.Name}' is of type {trainCol.Type} in train data, and type " +
                        $"{validCol.Value.Type} in validation data.", nameof(validationData));
                }
            }
        }

        /// <summary>
        /// Applies <see cref="ValidateTrainDataColumn"/> to each name in <paramref name="columnNames"/>;
        /// a null collection is accepted and skipped.
        /// </summary>
        private static void ValidateTrainDataColumns(IDataView trainData, IEnumerable<string> columnNames, string columnPurpose,
            IEnumerable<DataViewType> allowedTypes = null)
        {
            if (columnNames == null)
            {
                return;
            }

            foreach (var columnName in columnNames)
            {
                ValidateTrainDataColumn(trainData, columnName, columnPurpose, allowedTypes);
            }
        }

        /// <summary>
        /// Checks that <paramref name="columnName"/> (when non-null) exists in the training schema and, when
        /// <paramref name="allowedTypes"/> is given, that its item type is one of them. A "Did you mean" hint
        /// is added to the not-found error when a sufficiently similar column name exists.
        /// </summary>
        private static void ValidateTrainDataColumn(IDataView trainData, string columnName, string columnPurpose, IEnumerable<DataViewType> allowedTypes = null)
        {
            if (columnName == null)
            {
                return;
            }

            var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
            if (nullableColumn == null)
            {
                var closestNamed = ClosestNamed(trainData, columnName, MaxColumnNameSuggestionEditDistance);

                var exceptionMessage = $"Provided {columnPurpose} column '{columnName}' not found in training data.";
                if (closestNamed != string.Empty)
                {
                    exceptionMessage += $" Did you mean '{closestNamed}'.";
                }

                throw new ArgumentException(exceptionMessage);
            }

            if (allowedTypes == null)
            {
                return;
            }

            var itemType = nullableColumn.Value.Type.GetItemType();
            if (!allowedTypes.Contains(itemType))
            {
                if (allowedTypes.Count() == 1)
                {
                    throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " +
                        $"but only type {allowedTypes.First()} is allowed.");
                }

                throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " +
                    $"but only types {string.Join(", ", allowedTypes)} are allowed.");
            }
        }

        /// <summary>
        /// Returns the schema column name with the smallest Levenshtein distance to <paramref name="columnName"/>
        /// (first one wins on ties), or <see cref="string.Empty"/> if that distance exceeds
        /// <paramref name="maxAllowableEditDistance"/> or the schema has no columns.
        /// </summary>
        private static string ClosestNamed(IDataView trainData, string columnName, int maxAllowableEditDistance = int.MaxValue)
        {
            var minEditDistance = int.MaxValue;
            var closestNamed = string.Empty;
            foreach (var column in trainData.Schema)
            {
                var editDistance = StringEditDistance.GetLevenshteinDistance(column.Name, columnName);
                if (editDistance < minEditDistance)
                {
                    minEditDistance = editDistance;
                    closestNamed = column.Name;
                }
            }

            return minEditDistance <= maxAllowableEditDistance ? closestNamed : string.Empty;
        }

        /// <summary>
        /// Returns the first value (in order of first appearance) that occurs more than once, or null if all are distinct.
        /// </summary>
        private static string FindFirstDuplicate(IEnumerable<string> values)
        {
            return values.GroupBy(v => v).FirstOrDefault(g => g.Count() > 1)?.Key;
        }

        /// <summary>
        /// Returns the label item types allowed for <paramref name="task"/>, or null when any type is accepted.
        /// </summary>
        /// <exception cref="NotSupportedException"><paramref name="task"/> is not a recognized task kind.</exception>
        private static IEnumerable<DataViewType> GetAllowedLabelTypes(TaskKind task)
        {
            switch (task)
            {
                case TaskKind.BinaryClassification:
                    return new DataViewType[] { BooleanDataViewType.Instance };
                // Multiclass label types are flexible, as we convert the label to a key type
                // (if input label is not already a key) before invoking the trainer.
                case TaskKind.MulticlassClassification:
                    return null;
                case TaskKind.Regression:
                case TaskKind.Recommendation:
                case TaskKind.Ranking:
                    return new DataViewType[] { NumberDataViewType.Single };
                default:
                    throw new NotSupportedException($"Unsupported task type: {task}");
            }
        }
    }
}