File: Utils.cs
Web Access
Project: src\src\Microsoft.ML.CodeGenerator\Microsoft.ML.CodeGenerator.csproj (Microsoft.ML.CodeGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.ML.AutoML;
using Microsoft.ML.CodeGenerator.CSharp;
using Microsoft.ML.Data;
 
namespace Microsoft.ML.CodeGenerator.Utilities
{
    internal class Utils
    {
        /// <summary>
        /// Replaces every character of <paramref name="name"/> that is neither a letter nor a digit
        /// with an underscore.
        /// </summary>
        /// <param name="name">Text to sanitize.</param>
        /// <returns>The sanitized text; each input character yields exactly one output character.</returns>
        internal static string Sanitize(string name)
        {
            IEnumerable<char> safeChars = name.Select(ch => char.IsLetterOrDigit(ch) ? ch : '_');
            return string.Concat(safeChars);
        }
 
        /// <summary>
        /// Loads <paramref name="inputFile"/> with the text loader options from
        /// <paramref name="columnInference"/> and builds sample data from the first row of the loaded data.
        /// </summary>
        /// <param name="inputFile">Path to the input file.</param>
        /// <param name="columnInference">Column inference result that supplies the text loader options.</param>
        /// <returns>
        /// A dictionary whose keys are normalized column names and whose values come from the first row
        /// of data, or <c>null</c> if any exception is thrown while loading or reading the file.
        /// </returns>
        internal static IDictionary<string, string> GenerateSampleData(string inputFile, ColumnInferenceResults columnInference)
        {
            IDictionary<string, string> sample;
            try
            {
                var loader = new MLContext().Data.CreateTextLoader(columnInference.TextLoaderOptions);
                var loadedData = loader.Load(inputFile);
                sample = GenerateSampleData(loadedData, columnInference);
            }
            catch (Exception)
            {
                // Best effort: every failure is reported to the caller as "no sample data".
                sample = null;
            }

            return sample;
        }
 
        /// <summary>
        /// Reads the first row of <paramref name="dataView"/> and returns its values keyed by normalized
        /// column name. The label column and the ignored columns named in
        /// <paramref name="columnInference"/> are left out.
        /// </summary>
        /// <param name="dataView">Data to take the sample row from.</param>
        /// <param name="columnInference">Column inference result naming the label and ignored columns.</param>
        /// <returns>
        /// A dictionary with one entry per remaining column. Every value is <c>null</c> when
        /// <paramref name="dataView"/> has no rows.
        /// </returns>
        internal static IDictionary<string, string> GenerateSampleData(IDataView dataView, ColumnInferenceResults columnInference)
        {
            var featureColumns = dataView.Schema.ToList().FindAll(
                col => col.Name != columnInference.ColumnInformation.LabelColumnName &&
                       !columnInference.ColumnInformation.IgnoredColumnNames.Contains(col.Name));

            OrderedDictionary sampleData = new OrderedDictionary();
            // Get normalized and unique column names. If there are duplicate column names, the
            // differentiator suffix '_col_x' will be added to each column name, where 'x' is
            // the load order for a given column.
            List<string> normalizedColumnNames = GenerateColumnNames(featureColumns.Select(column => column.Name).ToList());
            foreach (string columnName in normalizedColumnNames)
            {
                sampleData[columnName] = null;
            }

            // The cursor is disposable; release it as soon as the first row has been read.
            using (var rowCursor = dataView.GetRowCursor(featureColumns))
            {
                if (rowCursor.MoveNext())
                {
                    var valueReaderDefinition = typeof(Utils).GetMethod(nameof(Utils.GetValueFromColumn), BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);

                    // Address each entry of sampleData by position rather than by name: duplicate
                    // column names may have received the '_col_x' suffix, so the schema name is
                    // not necessarily the dictionary key.
                    for (int i = 0; i < featureColumns.Count; i++)
                    {
                        // GetValueFromColumn<T> must be closed over the column's raw CLR type.
                        var valueReader = valueReaderDefinition.MakeGenericMethod(featureColumns[i].Type.RawType);
                        string val = valueReader.Invoke(null, new object[] { rowCursor, featureColumns[i] }) as string;
                        sampleData[i] = val;
                    }
                }
            }

            return sampleData.Cast<DictionaryEntry>().ToDictionary(k => (string)k.Key, v => (string)v.Value);
        }
 
        /// <summary>
        /// Reads the current value of <paramref name="column"/> from <paramref name="rowCursor"/> and
        /// renders it as text suitable for pasting into generated C# source.
        /// </summary>
        /// <typeparam name="T">Type of the column's values.</typeparam>
        /// <param name="rowCursor">Cursor positioned on the row to read.</param>
        /// <param name="column">Column to read.</param>
        /// <returns>
        /// A verbatim string literal for text columns, a <c>Single.*</c> constant or an <c>F</c>-suffixed
        /// literal for floats, <c>true</c>/<c>false</c> for booleans, the quoted text <c>"null"</c> for
        /// null values, and otherwise the value's string form (culture-invariant where the type supports it).
        /// </returns>
        internal static string GetValueFromColumn<T>(DataViewRowCursor rowCursor, DataViewSchema.Column column)
        {
            T val = default;
            var getter = rowCursor.GetGetter<T>(column);
            getter(ref val);

            // Text columns become verbatim string literals. Inside @"..." a double quote is
            // escaped by doubling it; a backslash does not escape anything there.
            if (typeof(T) == typeof(ReadOnlyMemory<Char>))
            {
                string text = val.ToString().Replace("\"", "\"\"");
                return "@\"" + text + "\"";
            }

            object boxed = val;
            if (boxed == null)
            {
                return "\"null\"";
            }

            if (boxed is float single)
            {
                // NaN and the infinities have no literal form, so emit the named constants.
                if (Single.IsNaN(single))
                {
                    return "Single.NaN";
                }

                if (Single.IsPositiveInfinity(single))
                {
                    return "Single.PositiveInfinity";
                }

                if (Single.IsNegativeInfinity(single))
                {
                    return "Single.NegativeInfinity";
                }

                return single.ToString(CultureInfo.InvariantCulture) + "F";
            }

            if (boxed is bool flag)
            {
                return flag ? "true" : "false";
            }

            // The result is emitted into source code, so it must not depend on the current
            // culture (for example a ',' decimal separator).
            if (boxed is IFormattable formattable)
            {
                return formattable.ToString(null, CultureInfo.InvariantCulture);
            }

            return boxed.ToString();
        }
 
        /// <summary>
        /// Turns <paramref name="input"/> into a name usable as a C# property name: prefixes an
        /// underscore when it starts with a digit, replaces characters that are neither letters nor
        /// digits with underscores, and upper-cases the first character.
        /// </summary>
        /// <param name="input">Raw name, for example a column name.</param>
        /// <returns>The normalized name.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="input"/> is empty.</exception>
        internal static string Normalize(string input)
        {
            if (input == null)
            {
                throw new ArgumentNullException(nameof(input));
            }

            if (input.Length == 0)
            {
                throw new ArgumentException($"{nameof(input)} cannot be empty", nameof(input));
            }

            // An identifier cannot start with a digit. char.IsDigit also covers non-ASCII decimal
            // digits, which Sanitize keeps because they count as "letter or digit".
            if (char.IsDigit(input[0]))
            {
                input = "_" + input;
            }

            var sanitizedInput = Sanitize(input);

            // Invariant casing keeps the generated name independent of the current culture.
            return char.ToUpperInvariant(sanitizedInput[0]) + sanitizedInput.Substring(1);
        }
 
        /// <summary>
        /// Maps a <see cref="DataKind"/> to the corresponding CLR type.
        /// </summary>
        /// <param name="labelType">Data kind to map.</param>
        /// <returns>The CLR type for <paramref name="labelType"/>.</returns>
        /// <exception cref="ArgumentException"><paramref name="labelType"/> is not one of the handled kinds.</exception>
        internal static Type GetCSharpType(DataKind labelType)
        {
            Type clrType;
            switch (labelType)
            {
                case DataKind.String:
                    clrType = typeof(string);
                    break;
                case DataKind.Boolean:
                    clrType = typeof(bool);
                    break;
                case DataKind.Single:
                    clrType = typeof(float);
                    break;
                case DataKind.Double:
                    clrType = typeof(double);
                    break;
                case DataKind.Int32:
                    clrType = typeof(int);
                    break;
                case DataKind.UInt32:
                    clrType = typeof(uint);
                    break;
                case DataKind.Int64:
                    clrType = typeof(long);
                    break;
                case DataKind.UInt64:
                    clrType = typeof(ulong);
                    break;
                default:
                    throw new ArgumentException($"The data type '{labelType}' is not handled currently.");
            }

            return clrType;
        }
 
        /// <summary>
        /// Writes <paramref name="fileContent"/> to the file <paramref name="fileName"/> inside
        /// <paramref name="outPutBaseDir"/>, creating the directory first if it does not exist.
        /// An existing file is overwritten.
        /// </summary>
        /// <param name="fileContent">Text to write.</param>
        /// <param name="fileName">Name of the file to write.</param>
        /// <param name="outPutBaseDir">Directory that receives the file.</param>
        internal static void WriteOutputToFiles(string fileContent, string fileName, string outPutBaseDir)
        {
            // CreateDirectory does nothing when the directory already exists, so no prior
            // existence check is needed.
            Directory.CreateDirectory(outPutBaseDir);
            File.WriteAllText(Path.Combine(outPutBaseDir, fileName), fileContent);
        }
 
        /// <summary>
        /// Parses <paramref name="trainProgramCSFileContent"/> as C# and returns it re-formatted by the
        /// Roslyn formatter.
        /// </summary>
        /// <param name="trainProgramCSFileContent">C# source text to format.</param>
        /// <returns>The formatted source text.</returns>
        internal static string FormatCode(string trainProgramCSFileContent)
        {
            var syntaxRoot = CSharpSyntaxTree.ParseText(trainProgramCSFileContent).GetRoot();

            // The workspace is only needed for this one formatting call; dispose it afterwards.
            using (var workspace = new AdhocWorkspace())
            {
                return Formatter.Format(syntaxRoot, workspace).ToFullString();
            }
        }
 
        /// <summary>
        /// Runs <c>dotnet sln "solutionPath" add ...</c> for the given projects and waits for it to finish.
        /// Each project name is combined with the directory of <paramref name="solutionPath"/>.
        /// </summary>
        /// <param name="solutionPath">Path of the solution file.</param>
        /// <param name="projects">Project paths relative to the solution's directory.</param>
        /// <returns>The exit code of the <c>dotnet</c> process.</returns>
        internal static int AddProjectsToSolution(string solutionPath, string[] projects)
        {
            var quotedProjectPaths = projects.Select(
                project => $"\"{Path.Combine(Path.GetDirectoryName(solutionPath), project)}\"");

            using (var dotnet = new System.Diagnostics.Process())
            {
                var startInfo = dotnet.StartInfo;
                startInfo.FileName = "dotnet";
                startInfo.Arguments = $"sln \"{solutionPath}\" add {string.Join(" ", quotedProjectPaths)}";
                startInfo.UseShellExecute = false;
                startInfo.RedirectStandardOutput = true;

                dotnet.Start();
                // Standard output is redirected, so read it to the end before waiting for exit.
                dotnet.StandardOutput.ReadToEnd();
                dotnet.WaitForExit();
                return dotnet.ExitCode;
            }
        }
 
        /// <summary>
        /// Runs <c>dotnet sln "solutionPath" add</c> with the console app project followed by the model
        /// project, and waits for it to finish.
        /// </summary>
        /// <param name="modelprojectDir">Directory of the model project.</param>
        /// <param name="modelProjectName">File name of the model project.</param>
        /// <param name="consoleAppProjectDir">Directory of the console app project.</param>
        /// <param name="consoleAppProjectName">File name of the console app project.</param>
        /// <param name="solutionPath">Path of the solution file.</param>
        /// <returns>The exit code of the <c>dotnet</c> process.</returns>
        internal static int AddProjectsToSolution(string modelprojectDir,
            string modelProjectName,
            string consoleAppProjectDir,
            string consoleAppProjectName,
            string solutionPath)
        {
            using (var dotnet = new System.Diagnostics.Process())
            {
                var startInfo = dotnet.StartInfo;
                startInfo.FileName = "dotnet";
                startInfo.Arguments = $"sln \"{solutionPath}\" add  \"{Path.Combine(consoleAppProjectDir, consoleAppProjectName)}\" \"{Path.Combine(modelprojectDir, modelProjectName)}\"";
                startInfo.UseShellExecute = false;
                startInfo.RedirectStandardOutput = true;

                dotnet.Start();
                // Standard output is redirected, so read it to the end before waiting for exit.
                dotnet.StandardOutput.ReadToEnd();
                dotnet.WaitForExit();
                return dotnet.ExitCode;
            }
        }
 
        /// <summary>
        /// Runs <c>dotnet new sln --name "solutionFile" --output "outputPath" --force</c> and waits for
        /// it to finish.
        /// </summary>
        /// <param name="solutionFile">Value passed to <c>--name</c>.</param>
        /// <param name="outputPath">Value passed to <c>--output</c>.</param>
        /// <returns>The exit code of the <c>dotnet</c> process.</returns>
        internal static int CreateSolutionFile(string solutionFile, string outputPath)
        {
            using (var dotnet = new System.Diagnostics.Process())
            {
                var startInfo = dotnet.StartInfo;
                startInfo.FileName = "dotnet";
                startInfo.Arguments = $"new sln --name \"{solutionFile}\" --output \"{outputPath}\" --force";
                startInfo.UseShellExecute = false;
                startInfo.RedirectStandardOutput = true;

                dotnet.Start();
                // Standard output is redirected, so read it to the end before waiting for exit.
                dotnet.StandardOutput.ReadToEnd();
                dotnet.WaitForExit();
                return dotnet.ExitCode;
            }
        }
 
        /// <summary>
        /// Builds the source-text fragments for the properties of a generated class, one property per
        /// column in the text loader options of <paramref name="columnInferenceResults"/>.
        /// </summary>
        /// <param name="columnInferenceResults">Column inference result whose text loader columns are rendered.</param>
        /// <param name="columnMapping">
        /// Optional overrides keyed by column name. A matching entry supplies the data kind and the name
        /// used in the <c>ColumnName</c> attribute; the property name is still derived from the original
        /// column name.
        /// </param>
        /// <returns>
        /// Three entries per column, in column order: the attribute line, the property declaration,
        /// and a line break.
        /// </returns>
        internal static IList<string> GenerateClassLabels(ColumnInferenceResults columnInferenceResults, IDictionary<string, CodeGeneratorSettings.ColumnMapping> columnMapping = default)
        {
            var attributeLines = new List<string>();
            var propertyPrefixes = new List<string>();
            var columnNames = new List<string>();

            foreach (var column in columnInferenceResults.TextLoaderOptions.Columns)
            {
                var source = column.Source[0];
                // .Value throws when the difference is null, i.e. when the range has no Max.
                int range = (source.Max - source.Min).Value;
                bool isArray = range > 0;

                // If the column is in columnMapping, use the type and name from there.
                DataKind dataKind = column.DataKind;
                string columnName = column.Name;
                if (columnMapping != null && columnMapping.TryGetValue(column.Name, out var mapping))
                {
                    dataKind = mapping.ColumnType;
                    columnName = mapping.ColumnName;
                }

                StringBuilder prefix = new StringBuilder();
                prefix.Append(Symbols.PublicSymbol);
                prefix.Append(Symbols.Space);
                prefix.Append(GetSymbolOfDataKind(dataKind));

                // Accommodate VectorType (array) columns.
                if (isArray)
                {
                    // Attributes within one [...] list must be separated by commas.
                    attributeLines.Add($"[ColumnName(\"{columnName}\"),LoadColumn({source.Min}, {source.Max}), VectorType({(range + 1)})]");
                    prefix.Append("[]");
                }
                else
                {
                    attributeLines.Add($"[ColumnName(\"{columnName}\"), LoadColumn({source.Min})]");
                }

                prefix.Append(" ");
                propertyPrefixes.Add(prefix.ToString());
                columnNames.Add(column.Name);
            }

            // Get normalized and unique column names. If there are duplicate column names, the
            // differentiator suffix '_col_x' will be added to each column name, where 'x' is
            // the load order for a given column.
            List<string> normalizedColumnNames = GenerateColumnNames(columnNames);

            IList<string> result = new List<string>(attributeLines.Count * 3);
            for (int i = 0; i < attributeLines.Count; i++)
            {
                result.Add(attributeLines[i]);
                // The normalized column name is the property name.
                result.Add(propertyPrefixes[i] + normalizedColumnNames[i] + "{get; set;}");
                result.Add("\r\n");
            }

            return result;
        }
 
        /// <summary>
        /// Normalizes every name in <paramref name="columnNames"/> so it fits property-name rules and
        /// makes the names unique. The list is modified in place and the same list instance is returned.
        /// </summary>
        /// <param name="columnNames">Column names to normalize; may contain duplicates.</param>
        /// <returns>
        /// The normalized names. If any two normalized names are equal, every name gets the suffix
        /// '_col_x', where 'x' is the position of the name in the list.
        /// </returns>
        internal static List<string> GenerateColumnNames(List<string> columnNames)
        {
            int total = columnNames.Count;
            for (int index = 0; index < total; index++)
            {
                columnNames[index] = Normalize(columnNames[index]);
            }

            // Fewer distinct names than entries means at least one name occurs more than once.
            bool hasDuplicates = columnNames.Distinct().Count() != total;
            if (hasDuplicates)
            {
                // Suffix every name, not only the duplicated ones.
                for (int index = 0; index < total; index++)
                {
                    columnNames[index] = columnNames[index] + "_col_" + index;
                }
            }

            return columnNames;
        }
 
        /// <summary>
        /// Maps a <see cref="DataKind"/> to the type symbol used for it in generated code.
        /// </summary>
        /// <param name="dataKind">Data kind to map.</param>
        /// <returns>The matching member of <c>Symbols</c>.</returns>
        /// <exception cref="ArgumentException"><paramref name="dataKind"/> is not one of the handled kinds.</exception>
        internal static string GetSymbolOfDataKind(DataKind dataKind)
        {
            if (dataKind == DataKind.String)
            {
                return Symbols.StringSymbol;
            }

            if (dataKind == DataKind.Boolean)
            {
                return Symbols.BoolSymbol;
            }

            if (dataKind == DataKind.Single)
            {
                return Symbols.FloatSymbol;
            }

            if (dataKind == DataKind.Double)
            {
                return Symbols.DoubleSymbol;
            }

            if (dataKind == DataKind.Int32)
            {
                return Symbols.IntSymbol;
            }

            if (dataKind == DataKind.UInt32)
            {
                return Symbols.UIntSymbol;
            }

            if (dataKind == DataKind.Int64)
            {
                return Symbols.LongSymbol;
            }

            if (dataKind == DataKind.UInt64)
            {
                return Symbols.UlongSymbol;
            }

            throw new ArgumentException($"The data type '{dataKind}' is not handled currently.");
        }
    }
}