// File: ColumnInference\ColumnGroupingInference.cs
// Web Access
// Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using static Microsoft.ML.Data.TextLoader;
 
using Range = Microsoft.ML.Data.TextLoader.Range;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// This class encapsulates logic for grouping together the inferred columns of the text file based on their type
    /// and purpose, and generating column names.
    /// </summary>
    internal static class ColumnGroupingInference
    {
        /// <summary>
        /// Describes one grouped output column. It combines the item kind taken from a
        /// <see cref="ColumnTypeInference.Column"/> with the purpose taken from a
        /// <see cref="PurposeInference.Column"/>, and can cover several source columns
        /// (a vector-valued column) through <see cref="Ranges"/>.
        /// </summary>
        public class GroupingColumn
        {
            /// <summary>Name proposed for the grouped column.</summary>
            public string SuggestedName;

            /// <summary>Data kind of the items in the column.</summary>
            public DataKind ItemKind;

            /// <summary>Purpose assigned to the column.</summary>
            public ColumnPurpose Purpose;

            /// <summary>Source column index ranges that make up this column.</summary>
            public Range[] Ranges;

            /// <summary>
            /// Creates a grouped column description from its four components.
            /// </summary>
            /// <param name="name">Value stored in <see cref="SuggestedName"/>.</param>
            /// <param name="kind">Value stored in <see cref="ItemKind"/>.</param>
            /// <param name="purpose">Value stored in <see cref="Purpose"/>.</param>
            /// <param name="ranges">Value stored in <see cref="Ranges"/>.</param>
            public GroupingColumn(string name, DataKind kind, ColumnPurpose purpose, Range[] ranges)
            {
                Ranges = ranges;
                Purpose = purpose;
                ItemKind = kind;
                SuggestedName = name;
            }

            /// <summary>
            /// Builds the <see cref="TextLoader.Column"/> with this column's name, item kind and ranges.
            /// </summary>
            /// <returns>A new <see cref="TextLoader.Column"/> instance.</returns>
            public TextLoader.Column GenerateTextLoaderColumn()
                => new TextLoader.Column(SuggestedName, ItemKind, Ranges);
        }
 
        /// <summary>
        /// Group together the single-valued columns with the same type and purpose and generate column names.
        /// </summary>
        /// <param name="env">The host environment to use. It is not read by this method.</param>
        /// <param name="hasHeader">Whether the original file had a header.
        /// If yes, a group that consists of exactly one column keeps that column's
        /// <see cref="ColumnTypeInference.Column.SuggestedName"/>; in every other case the name is generated
        /// from the group's purpose and item kind.</param>
        /// <param name="types">The (detected) column types.</param>
        /// <param name="purposes">The (detected) column purposes. Must be parallel to <paramref name="types"/>.</param>
        /// <returns>An array of grouped column specifications, one entry per group, in order of first appearance.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="types"/> or <paramref name="purposes"/> is null.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="types"/> and <paramref name="purposes"/> have different lengths.
        /// </exception>
        public static GroupingColumn[] InferGroupingAndNames(MLContext env, bool hasHeader, ColumnTypeInference.Column[] types, PurposeInference.Column[] purposes)
        {
            if (types == null)
            {
                throw new ArgumentNullException(nameof(types));
            }

            if (purposes == null)
            {
                throw new ArgumentNullException(nameof(purposes));
            }

            // Zip stops at the shorter sequence, so a length mismatch would silently drop columns.
            if (types.Length != purposes.Length)
            {
                throw new ArgumentException(
                    string.Format("Expected {0} column purposes (one per column type) but got {1}.", types.Length, purposes.Length),
                    nameof(purposes));
            }

            var result = new List<GroupingColumn>();
            var tuples = types.Zip(purposes, Tuple.Create).ToList();

            // purposeGroupId is part of the key so that purposes which must not be merged
            // end up in one group per column (see GetPurposeGroupId).
            var grouped = tuples.GroupBy(t =>
                new
                {
                    t.Item1.ItemType,
                    t.Item2.Purpose,
                    purposeGroupId = GetPurposeGroupId(t.Item1.ColumnIndex, t.Item2.Purpose)
                });

            foreach (var g in grouped)
            {
                // Convert once; the kind is needed both for naming and for the resulting column.
                var itemKind = g.Key.ItemType.GetRawKind().ToDataKind();

                // NOTE(review): a header-provided name is used as-is and is not checked against names
                // already present in result; confirm duplicate names cannot occur here.
                string name = (hasHeader && g.Count() == 1)
                    ? g.First().Item1.SuggestedName
                    : GetName(itemKind, g.Key.Purpose, result);

                var ranges = GetRanges(g.Select(t => t.Item1.ColumnIndex).ToArray());
                result.Add(new GroupingColumn(name, itemKind, g.Key.Purpose, ranges));
            }

            return result.ToArray();
        }
 
        /// <summary>
        /// Returns the extra grouping id for a column. Categorical, text and ignored columns get their own
        /// column index as the id; every other purpose gets the shared id 0.
        /// </summary>
        /// <param name="columnIndex">Index of the column.</param>
        /// <param name="purpose">Purpose of the column.</param>
        /// <returns><paramref name="columnIndex"/> for the three per-column purposes, otherwise 0.</returns>
        private static int GetPurposeGroupId(int columnIndex, ColumnPurpose purpose)
        {
            switch (purpose)
            {
                case ColumnPurpose.CategoricalFeature:
                case ColumnPurpose.TextFeature:
                case ColumnPurpose.Ignore:
                    return columnIndex;
                default:
                    return 0;
            }
        }
 
        /// <summary>
        /// Produces a column name that no column in <paramref name="previousColumns"/> already uses.
        /// The purpose-based name is tried first; on a clash a counter starting at 1, zero-padded to at
        /// least two digits, is appended and increased until the name is free.
        /// </summary>
        /// <param name="itemKind">Item kind of the column being named.</param>
        /// <param name="purpose">Purpose of the column being named.</param>
        /// <param name="previousColumns">Columns whose names are already taken.</param>
        /// <returns>The first candidate name not found among <paramref name="previousColumns"/>.</returns>
        private static string GetName(DataKind itemKind, ColumnPurpose purpose, List<GroupingColumn> previousColumns)
        {
            string baseName = GetPurposeName(purpose, itemKind);
            string candidate = baseName;

            for (int suffix = 1; previousColumns.Any(column => column.SuggestedName == candidate); suffix++)
            {
                candidate = string.Format("{0}{1:00}", baseName, suffix);
            }

            return candidate;
        }
 
        /// <summary>
        /// Maps a purpose (and, for numeric features, the item kind) to the base column name.
        /// </summary>
        /// <param name="purpose">Purpose of the column.</param>
        /// <param name="itemKind">Item kind of the column; only consulted for numeric features.</param>
        /// <returns>
        /// "BooleanFeatures" or "Features" for numeric features, "Cat" for categorical features,
        /// otherwise the enum member name of <paramref name="purpose"/>.
        /// </returns>
        private static string GetPurposeName(ColumnPurpose purpose, DataKind itemKind)
        {
            if (purpose == ColumnPurpose.NumericFeature)
            {
                return itemKind == DataKind.Boolean ? "BooleanFeatures" : "Features";
            }

            if (purpose == ColumnPurpose.CategoricalFeature)
            {
                return "Cat";
            }

            return Enum.GetName(typeof(ColumnPurpose), purpose);
        }
 
        /// <summary>
        /// Generates a collection of Ranges from indices, merging each run of consecutive indices
        /// into a single range.
        /// </summary>
        /// <param name="indices">Column indices, in any order. The array itself is left unmodified.</param>
        /// <returns>
        /// The ranges in ascending order of their start index; an empty array when
        /// <paramref name="indices"/> is null or empty.
        /// </returns>
        private static Range[] GetRanges(int[] indices)
        {
            // Nothing to merge; this also guards the indices[0] access below.
            if (indices == null || indices.Length == 0)
            {
                return Array.Empty<Range>();
            }

            // Sort a copy so the caller's array keeps its original order.
            var sorted = (int[])indices.Clone();
            Array.Sort(sorted);

            var allRanges = new List<Range>();
            var currRange = new Range(sorted[0]);
            for (int i = 1; i < sorted.Length; i++)
            {
                if (sorted[i] == currRange.Max + 1)
                {
                    // The index directly follows the current run: extend it by one.
                    currRange.Max++;
                }
                else
                {
                    // Gap found: close the current run and start a new one at this index.
                    allRanges.Add(currRange);
                    currRange = new Range(sorted[i]);
                }
            }

            // The last run is still open when the loop ends.
            allRanges.Add(currRange);
            return allRanges.ToArray();
        }
    }
}