File: ColumnInference\ColumnInformationUtil.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// Helpers for mapping between a <see cref="ColumnInformation"/> and the <see cref="ColumnPurpose"/>
    /// of individual columns.
    /// </summary>
    internal static class ColumnInformationUtil
    {
        /// <summary>
        /// Gets the purpose that <paramref name="columnInfo"/> assigns to <paramref name="columnName"/>.
        /// </summary>
        /// <remarks>
        /// Roles are tested in a fixed order (label, example weight, group id, sampling key, categorical,
        /// numeric, text, ignored, user id, item id, image path). The first match wins, so a name listed
        /// under several roles reports only the earliest of them.
        /// </remarks>
        /// <param name="columnInfo">Column information to search.</param>
        /// <param name="columnName">Name of the column to look up.</param>
        /// <returns>
        /// The matching purpose, or <c>null</c> if the column is not referenced by <paramref name="columnInfo"/>
        /// or <paramref name="columnName"/> is <c>null</c>.
        /// </returns>
        internal static ColumnPurpose? GetColumnPurpose(this ColumnInformation columnInfo, string columnName)
        {
            // Single-column roles may be unset (null). Without this guard a null name would "match" the
            // first unset role and be reported as, e.g., a weight column.
            if (columnName == null)
            {
                return null;
            }

            if (columnName == columnInfo.LabelColumnName)
            {
                return ColumnPurpose.Label;
            }

            if (columnName == columnInfo.ExampleWeightColumnName)
            {
                return ColumnPurpose.Weight;
            }

            if (columnName == columnInfo.GroupIdColumnName)
            {
                return ColumnPurpose.GroupId;
            }

            if (columnName == columnInfo.SamplingKeyColumnName)
            {
                return ColumnPurpose.SamplingKey;
            }

            if (columnInfo.CategoricalColumnNames.Contains(columnName))
            {
                return ColumnPurpose.CategoricalFeature;
            }

            if (columnInfo.NumericColumnNames.Contains(columnName))
            {
                return ColumnPurpose.NumericFeature;
            }

            if (columnInfo.TextColumnNames.Contains(columnName))
            {
                return ColumnPurpose.TextFeature;
            }

            if (columnInfo.IgnoredColumnNames.Contains(columnName))
            {
                return ColumnPurpose.Ignore;
            }

            if (columnName == columnInfo.UserIdColumnName)
            {
                return ColumnPurpose.UserId;
            }

            if (columnName == columnInfo.ItemIdColumnName)
            {
                return ColumnPurpose.ItemId;
            }

            if (columnInfo.ImagePathColumnNames.Contains(columnName))
            {
                return ColumnPurpose.ImagePath;
            }

            return null;
        }

        /// <summary>
        /// Builds a <see cref="ColumnInformation"/> from (column name, purpose) pairs.
        /// </summary>
        /// <remarks>
        /// Starts from a default-constructed <see cref="ColumnInformation"/>. Whatever its constructor
        /// initializes is kept for roles not supplied here (NOTE(review): confirm those defaults).
        /// For single-column roles (label, weight, sampling key, user id, item id, group id) the last pair
        /// with that purpose wins. Multi-column roles accumulate names in input order. Purposes not handled
        /// by the switch are ignored.
        /// </remarks>
        /// <param name="columnPurposes">Column names paired with their purposes.</param>
        /// <returns>The populated column information.</returns>
        internal static ColumnInformation BuildColumnInfo(IEnumerable<(string name, ColumnPurpose purpose)> columnPurposes)
        {
            var columnInfo = new ColumnInformation();

            foreach (var column in columnPurposes)
            {
                switch (column.purpose)
                {
                    case ColumnPurpose.Label:
                        columnInfo.LabelColumnName = column.name;
                        break;
                    case ColumnPurpose.Weight:
                        columnInfo.ExampleWeightColumnName = column.name;
                        break;
                    case ColumnPurpose.SamplingKey:
                        columnInfo.SamplingKeyColumnName = column.name;
                        break;
                    case ColumnPurpose.CategoricalFeature:
                        columnInfo.CategoricalColumnNames.Add(column.name);
                        break;
                    case ColumnPurpose.Ignore:
                        columnInfo.IgnoredColumnNames.Add(column.name);
                        break;
                    case ColumnPurpose.NumericFeature:
                        columnInfo.NumericColumnNames.Add(column.name);
                        break;
                    case ColumnPurpose.UserId:
                        columnInfo.UserIdColumnName = column.name;
                        break;
                    case ColumnPurpose.ItemId:
                        columnInfo.ItemIdColumnName = column.name;
                        break;
                    case ColumnPurpose.GroupId:
                        columnInfo.GroupIdColumnName = column.name;
                        break;
                    case ColumnPurpose.TextFeature:
                        columnInfo.TextColumnNames.Add(column.name);
                        break;
                    case ColumnPurpose.ImagePath:
                        columnInfo.ImagePathColumnNames.Add(column.name);
                        break;
                }
            }

            return columnInfo;
        }

        /// <summary>
        /// Builds a <see cref="ColumnInformation"/> from dataset column descriptions, using each column's
        /// <c>Name</c> and <c>Purpose</c>.
        /// </summary>
        /// <param name="columns">Dataset columns to convert.</param>
        /// <returns>The populated column information.</returns>
        public static ColumnInformation BuildColumnInfo(IEnumerable<DatasetColumnInfo> columns)
        {
            return BuildColumnInfo(columns.Select(c => (c.Name, c.Purpose)));
        }

        /// <summary>
        /// Gets all column names that are in <paramref name="columnInformation"/>.
        /// </summary>
        /// <remarks>
        /// Names are returned in role order. The single-column roles come first (label, user id, item id,
        /// group id, example weight, sampling key), followed by the categorical, ignored, numeric, text and
        /// image-path collections. Null names are skipped. Duplicates, including a name used by several
        /// roles, are kept.
        /// </remarks>
        /// <param name="columnInformation">Column information.</param>
        /// <returns>The referenced column names.</returns>
        public static IEnumerable<string> GetColumnNames(ColumnInformation columnInformation)
        {
            var columnNames = new List<string>();
            AddStringToListIfNotNull(columnNames, columnInformation.LabelColumnName);
            AddStringToListIfNotNull(columnNames, columnInformation.UserIdColumnName);
            AddStringToListIfNotNull(columnNames, columnInformation.ItemIdColumnName);
            AddStringToListIfNotNull(columnNames, columnInformation.GroupIdColumnName);
            AddStringToListIfNotNull(columnNames, columnInformation.ExampleWeightColumnName);
            AddStringToListIfNotNull(columnNames, columnInformation.SamplingKeyColumnName);
            AddStringsToListIfNotNull(columnNames, columnInformation.CategoricalColumnNames);
            AddStringsToListIfNotNull(columnNames, columnInformation.IgnoredColumnNames);
            AddStringsToListIfNotNull(columnNames, columnInformation.NumericColumnNames);
            AddStringsToListIfNotNull(columnNames, columnInformation.TextColumnNames);
            AddStringsToListIfNotNull(columnNames, columnInformation.ImagePathColumnNames);
            return columnNames;
        }

        /// <summary>
        /// Counts how many columns <paramref name="columnInformation"/> assigns to each purpose.
        /// </summary>
        /// <remarks>
        /// Null names are not counted. Purposes with no columns are absent from the result
        /// rather than mapped to zero.
        /// </remarks>
        /// <param name="columnInformation">Column information to summarize.</param>
        /// <returns>A map from purpose to the number of columns with that purpose.</returns>
        public static IDictionary<ColumnPurpose, int> CountColumnsByPurpose(ColumnInformation columnInformation)
        {
            var result = new Dictionary<ColumnPurpose, int>();

            // Tally each role straight from its own property instead of re-deriving the purpose from each
            // name with GetColumnPurpose. Re-deriving costs linear Contains scans per name, which is quadratic
            // overall. It also credits a name shared by several roles to the first matching role every time
            // it appears. Roles are visited in the same order as GetColumnNames.
            SetCountIfPositive(result, ColumnPurpose.Label, CountIfSet(columnInformation.LabelColumnName));
            SetCountIfPositive(result, ColumnPurpose.UserId, CountIfSet(columnInformation.UserIdColumnName));
            SetCountIfPositive(result, ColumnPurpose.ItemId, CountIfSet(columnInformation.ItemIdColumnName));
            SetCountIfPositive(result, ColumnPurpose.GroupId, CountIfSet(columnInformation.GroupIdColumnName));
            SetCountIfPositive(result, ColumnPurpose.Weight, CountIfSet(columnInformation.ExampleWeightColumnName));
            SetCountIfPositive(result, ColumnPurpose.SamplingKey, CountIfSet(columnInformation.SamplingKeyColumnName));
            SetCountIfPositive(result, ColumnPurpose.CategoricalFeature, CountNonNull(columnInformation.CategoricalColumnNames));
            SetCountIfPositive(result, ColumnPurpose.Ignore, CountNonNull(columnInformation.IgnoredColumnNames));
            SetCountIfPositive(result, ColumnPurpose.NumericFeature, CountNonNull(columnInformation.NumericColumnNames));
            SetCountIfPositive(result, ColumnPurpose.TextFeature, CountNonNull(columnInformation.TextColumnNames));
            SetCountIfPositive(result, ColumnPurpose.ImagePath, CountNonNull(columnInformation.ImagePathColumnNames));
            return result;
        }

        /// <summary>Returns 1 if a single-column role is set (non-null), otherwise 0.</summary>
        private static int CountIfSet(string columnName)
        {
            return columnName == null ? 0 : 1;
        }

        /// <summary>Returns the number of non-null names in a multi-column role.</summary>
        private static int CountNonNull(IEnumerable<string> columnNames)
        {
            return columnNames.Count(name => name != null);
        }

        /// <summary>
        /// Records <paramref name="count"/> for <paramref name="purpose"/> only when it is positive,
        /// so purposes without columns never appear as keys.
        /// </summary>
        private static void SetCountIfPositive(IDictionary<ColumnPurpose, int> counts, ColumnPurpose purpose, int count)
        {
            if (count > 0)
            {
                counts[purpose] = count;
            }
        }

        /// <summary>Appends every non-null string from <paramref name="strings"/> to <paramref name="list"/>.</summary>
        private static void AddStringsToListIfNotNull(List<string> list, IEnumerable<string> strings)
        {
            foreach (var str in strings)
            {
                AddStringToListIfNotNull(list, str);
            }
        }

        /// <summary>Appends <paramref name="str"/> to <paramref name="list"/> unless it is null.</summary>
        private static void AddStringToListIfNotNull(List<string> list, string str)
        {
            if (str != null)
            {
                list.Add(str);
            }
        }
    }
}