File: ColumnInference\ColumnInferenceApi.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
 
namespace Microsoft.ML.AutoML
{
    internal static class ColumnInferenceApi
    {
        public static ColumnInferenceResults InferColumns(MLContext context, string path, uint labelColumnIndex,
            bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
        {
            var sample = TextFileSample.CreateFromFullFile(path);
            var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
            var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader, labelColumnIndex, null);
 
            // If no headers, suggest label column name as 'Label'
            if (!hasHeader)
            {
                typeInference.Columns[labelColumnIndex].SuggestedName = DefaultColumnNames.Label;
            }
 
            var columnInfo = new ColumnInformation() { LabelColumnName = typeInference.Columns[labelColumnIndex].SuggestedName };
 
            return InferColumns(context, path, columnInfo, hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
        }
 
        public static ColumnInferenceResults InferColumns(MLContext context, string path, string labelColumn,
            char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
        {
            var columnInfo = new ColumnInformation() { LabelColumnName = labelColumn };
            return InferColumns(context, path, columnInfo, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
        }
 
        public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo,
            char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns, bool hasHeader = true)
        {
            var sample = TextFileSample.CreateFromFullFile(path);
            var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
            var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader, null, columnInfo.LabelColumnName);
            return InferColumns(context, path, columnInfo, hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
        }
 
        public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo, bool hasHeader,
            TextFileContents.ColumnSplitResult splitInference, ColumnTypeInference.InferenceResult typeInference,
            bool trimWhitespace, bool groupColumns)
        {
            var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
            var typedLoaderOptions = new TextLoader.Options
            {
                Columns = loaderColumns,
                Separators = new[] { splitInference.Separator.Value },
                AllowSparse = splitInference.AllowSparse,
                AllowQuoting = splitInference.AllowQuote,
                ReadMultilines = splitInference.ReadMultilines,
                HasHeader = hasHeader,
                TrimWhitespace = trimWhitespace
            };
            var textLoader = context.Data.CreateTextLoader(typedLoaderOptions);
            var dataView = textLoader.Load(path);
 
            // Validate all columns specified in column info exist in inferred data view
            ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, dataView);
 
            var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, columnInfo);
 
            // start building result objects
            IEnumerable<TextLoader.Column> columnResults = null;
            IEnumerable<(string, ColumnPurpose)> purposeResults = null;
 
            // infer column grouping and generate column names
            if (groupColumns)
            {
                var groupingResult = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader,
                    typeInference.Columns, purposeInferenceResult);
 
                columnResults = groupingResult.Select(c => c.GenerateTextLoaderColumn());
                purposeResults = groupingResult.Select(c => (c.SuggestedName, c.Purpose));
            }
            else
            {
                columnResults = loaderColumns;
                purposeResults = purposeInferenceResult.Select(p => (dataView.Schema[p.ColumnIndex].Name, p.Purpose));
            }
 
            var textLoaderOptions = new TextLoader.Options()
            {
                Columns = columnResults.ToArray(),
                AllowQuoting = splitInference.AllowQuote,
                AllowSparse = splitInference.AllowSparse,
                Separators = new char[] { splitInference.Separator.Value },
                ReadMultilines = splitInference.ReadMultilines,
                HasHeader = hasHeader,
                TrimWhitespace = trimWhitespace
            };
 
            return new ColumnInferenceResults()
            {
                TextLoaderOptions = textLoaderOptions,
                ColumnInformation = ColumnInformationUtil.BuildColumnInfo(purposeResults)
            };
        }
 
        private static TextFileContents.ColumnSplitResult InferSplit(MLContext context, TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)
        {
            var separatorCandidates = separatorChar == null ? TextFileContents.DefaultSeparators : new char[] { separatorChar.Value };
            var splitInference = TextFileContents.TrySplitColumns(context, sample, separatorCandidates);
 
            // respect passed-in overrides
            if (allowQuotedStrings != null)
            {
                splitInference.AllowQuote = allowQuotedStrings.Value;
            }
            if (supportSparse != null)
            {
                splitInference.AllowSparse = supportSparse.Value;
            }
 
            if (!splitInference.IsSuccess)
            {
                throw new InferenceException(InferenceExceptionType.ColumnSplit,
                    "Unable to split the file provided into multiple, consistent columns. " +
                    "Readable formats include delimited files such as CSV/TSV. " +
                    "Check for a consistent number of columns and proper escaping and quoting.");
            }
 
            return splitInference;
        }
 
        private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext context, TextFileSample sample,
            TextFileContents.ColumnSplitResult splitInference, bool hasHeader, uint? labelColumnIndex, string label)
        {
            // infer column types
            var typeInferenceResult = ColumnTypeInference.InferTextFileColumnTypes(context, sample,
                new ColumnTypeInference.Arguments
                {
                    ColumnCount = splitInference.ColumnCount,
                    Separator = splitInference.Separator.Value,
                    AllowSparse = splitInference.AllowSparse,
                    AllowQuote = splitInference.AllowQuote,
                    ReadMultilines = splitInference.ReadMultilines,
                    HasHeader = hasHeader,
                    LabelColumnIndex = labelColumnIndex,
                    Label = label
                });
 
            if (!typeInferenceResult.IsSuccess)
            {
                throw new InferenceException(InferenceExceptionType.ColumnDataType, "Unable to infer column types of the file provided.");
            }
 
            return typeInferenceResult;
        }
    }
}