File: CodeGenerator\CSharp\CodeGenerator.cs
Web Access
Project: src\src\Microsoft.ML.CodeGenerator\Microsoft.ML.CodeGenerator.csproj (Microsoft.ML.CodeGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.ML.AutoML;
using Microsoft.ML.CodeGenerator.CodeGenerator;
using Microsoft.ML.CodeGenerator.Templates.Azure.Model;
using Microsoft.ML.CodeGenerator.Templates.Console;
using Microsoft.ML.CodeGenerator.Utilities;
using Microsoft.ML.Data;
 
namespace Microsoft.ML.CodeGenerator.CSharp
{
    internal class CodeGenerator : IProjectGenerator
    {
        private readonly Pipeline _pipeline;
        private readonly CodeGeneratorSettings _settings;
        private readonly ColumnInferenceResults _columnInferenceResult;
        private static readonly HashSet<string> _recommendationTrainers = new HashSet<string>() { TrainerName.MatrixFactorization.ToString() };
        private static readonly HashSet<string> _lightGbmTrainers = new HashSet<string>() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(),
            TrainerName.LightGbmRegression.ToString(), TrainerName.LightGbmRanking.ToString() };
        private static readonly HashSet<string> _mklComponentsTrainers = new HashSet<string>() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() };
        private static readonly HashSet<string> _fastTreeTrainers = new HashSet<string>() { TrainerName.FastForestBinary.ToString(), TrainerName.FastForestRegression.ToString(),
            TrainerName.FastTreeBinary.ToString(), TrainerName.FastTreeRegression.ToString(), TrainerName.FastTreeTweedieRegression.ToString(), TrainerName.FastTreeRanking.ToString() };
        private static readonly HashSet<string> _imageTransformers = new HashSet<string>() { EstimatorName.RawByteImageLoading.ToString() };
        private static readonly HashSet<string> _imageClassificationTrainers = new HashSet<string>() { TrainerName.ImageClassification.ToString() };
 
        internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings)
        {
            _pipeline = pipeline;
            _columnInferenceResult = columnInferenceResult;
            _settings = settings;
        }
 
        public void GenerateOutput()
        {
            // Get the extra nuget packages to be included in the generated project.
            var trainerNodes = _pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer);
 
            bool includeLightGbmPackage = false;
            bool includeMklComponentsPackage = false;
            bool includeFastTreeePackage = false;
            bool includeImageTransformerPackage = false;
            bool includeImageClassificationPackage = false;
            bool includeRecommenderPackage = false;
            // Get the extra nuget packages to be included in the generated project.
            SetRequiredNugetPackages(_pipeline.Nodes, ref includeLightGbmPackage, ref includeMklComponentsPackage,
                ref includeFastTreeePackage, ref includeImageTransformerPackage, ref includeImageClassificationPackage, ref includeRecommenderPackage);
 
            // Get Namespace
            var namespaceValue = Utils.Normalize(_settings.OutputName);
            var labelType = _columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == _settings.LabelName).First().DataKind;
            Type labelTypeCsharp = Utils.GetCSharpType(labelType);
 
            // Generate Model Project
            var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp,
                includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage,
                includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage);
 
            // Write files to disk.
            var modelprojectDir = Path.Combine(_settings.OutputBaseDir, $"{_settings.OutputName}.Model");
            var dataModelsDir = modelprojectDir;
            var modelProjectName = $"{_settings.OutputName}.Model.csproj";
 
            Utils.WriteOutputToFiles(modelProjectContents.ModelInputCSFileContent, "ModelInput.cs", dataModelsDir);
            Utils.WriteOutputToFiles(modelProjectContents.ModelOutputCSFileContent, "ModelOutput.cs", dataModelsDir);
            Utils.WriteOutputToFiles(modelProjectContents.ConsumeModelCSFileContent, "ConsumeModel.cs", dataModelsDir);
            Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);
 
            // Generate ConsoleApp Project
            var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp,
                includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage,
                includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage);
 
            // Write files to disk.
            var consoleAppProjectDir = Path.Combine(_settings.OutputBaseDir, $"{_settings.OutputName}.ConsoleApp");
            var consoleAppProjectName = $"{_settings.OutputName}.ConsoleApp.csproj";
 
            Utils.WriteOutputToFiles(consoleAppProjectContents.ConsoleAppProgramCSFileContent, "Program.cs", consoleAppProjectDir);
            Utils.WriteOutputToFiles(consoleAppProjectContents.modelBuilderCSFileContent, "ModelBuilder.cs", consoleAppProjectDir);
            Utils.WriteOutputToFiles(consoleAppProjectContents.ConsoleAppProjectFileContent, consoleAppProjectName, consoleAppProjectDir);
 
            // New solution file.
            Utils.CreateSolutionFile(_settings.OutputName, _settings.OutputBaseDir);
 
            // Add projects to solution
            var solutionPath = Path.Combine(_settings.OutputBaseDir, $"{_settings.OutputName}.sln");
            Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath);
        }
 
        private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage,
            ref bool includeMklComponentsPackage, ref bool includeFastTreePackage,
            ref bool includeImageTransformerPackage, ref bool includeImageClassificationPackage, ref bool includeRecommenderPackage)
        {
            foreach (var node in trainerNodes)
            {
                PipelineNode currentNode = node;
                if (currentNode.Name == TrainerName.Ova.ToString())
                {
                    currentNode = (PipelineNode)currentNode.Properties["BinaryTrainer"];
                }
 
                if (_lightGbmTrainers.Contains(currentNode.Name))
                {
                    includeLightGbmPackage = true;
                }
                else if (_mklComponentsTrainers.Contains(currentNode.Name))
                {
                    includeMklComponentsPackage = true;
                }
                else if (_fastTreeTrainers.Contains(currentNode.Name))
                {
                    includeFastTreePackage = true;
                }
                else if (_imageTransformers.Contains(currentNode.Name))
                {
                    includeImageTransformerPackage = true;
                }
                else if (_imageClassificationTrainers.Contains(currentNode.Name))
                {
                    includeImageClassificationPackage = true;
                }
                else if (_recommendationTrainers.Contains(currentNode.Name))
                {
                    includeRecommenderPackage = true;
                }
            }
        }
 
        internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent,
            string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue,
                Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage,
                bool includeFastTreePackage, bool includeImageTransformerPackage,
                bool includeImageClassificationPackage, bool includeRecommenderPackage, bool includeOnnxPackage = false, bool includeResNet18Package = false)
        {
            var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
            predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);
 
            var predictProjectFileContent = GeneratPredictProjectFileContent(_settings.OutputName,
                includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage,
                includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage, includeOnnxPackage, includeResNet18Package,
                _settings.StablePackageVersion, _settings.UnstablePackageVersion, _settings.Target);
 
            var transformsAndTrainers = GenerateTransformsAndTrainers();
            var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod,
                transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, _pipeline.CacheBeforeTrainer, labelTypeCsharp.Name, includeOnnxPackage);
            modelBuilderCSFileContent = Utils.FormatCode(modelBuilderCSFileContent);
 
            return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
        }
 
        internal (string ModelInputCSFileContent, string ModelOutputCSFileContent, string ConsumeModelCSFileContent,
            string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue,
                Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage,
                bool includeFastTreePackage, bool includeImageTransformerPackage,
                bool includeImageClassificationPackage, bool includeRecommenderPackage, bool includeOnnxModel = false)
        {
            var classLabels = GenerateClassLabels();
 
            // generate ModelInput.cs
            var modelInputCSFileContent = GenerateModelInputCSFileContent(namespaceValue, classLabels);
            modelInputCSFileContent = Utils.FormatCode(modelInputCSFileContent);
 
            // generate ModelOutput.cs
            var modelOutputCSFileContent = GenerateModelOutputCSFileContent(labelTypeCsharp.Name, namespaceValue);
            modelOutputCSFileContent = Utils.FormatCode(modelOutputCSFileContent);
 
            // generate ConsumeModel.cs
            var consumeModelCSFileContent = GenerateConsumeModelCSFileContent(namespaceValue);
            consumeModelCSFileContent = Utils.FormatCode(consumeModelCSFileContent);
            var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage,
                includeMklComponentsPackage, includeFastTreePackage, includeImageTransformerPackage,
                includeImageClassificationPackage, includeRecommenderPackage, includeOnnxModel,
                _settings.StablePackageVersion, _settings.UnstablePackageVersion, _settings.Target);
 
            return (modelInputCSFileContent, modelOutputCSFileContent, consumeModelCSFileContent, modelProjectFileContent);
        }
 
        private string GenerateConsumeModelCSFileContent(string namespaceValue)
        {
            ConsumeModel consumeModel = new ConsumeModel()
            {
                Namespace = namespaceValue,
                Target = _settings.Target,
                MLNetModelName = _settings.ModelName,
            };
            return consumeModel.TransformText();
        }
 
        internal (string Usings, string TrainerMethod, List<string> PreTrainerTransforms, List<string> PostTrainerTransforms) GenerateTransformsAndTrainers()
        {
            StringBuilder usingsBuilder = new StringBuilder();
            var usings = new List<string>();
 
            // Get pre-trainer transforms
            var nodes = _pipeline.Nodes.TakeWhile(t => t.NodeType == PipelineNodeType.Transform);
            var preTrainerTransformsAndUsings = GenerateTransformsAndUsings(nodes);
 
            // Get post trainer transforms
            nodes = _pipeline.Nodes.SkipWhile(t => t.NodeType == PipelineNodeType.Transform)
                .SkipWhile(t => t.NodeType == PipelineNodeType.Trainer) //skip the trainer
                .TakeWhile(t => t.NodeType == PipelineNodeType.Transform); //post trainer transforms
            var postTrainerTransformsAndUsings = GenerateTransformsAndUsings(nodes);
 
            //Get trainer code and its associated usings.
            (string trainerMethod, string[] trainerUsings) = GenerateTrainerAndUsings();
            if (trainerUsings != null)
            {
                usings.AddRange(trainerUsings);
            }
 
            //Get transforms code and its associated (unique) usings.
            var preTrainerTransforms = preTrainerTransformsAndUsings?.Select(t => t.Item1).ToList();
            var postTrainerTransforms = postTrainerTransformsAndUsings?.Select(t => t.Item1).ToList();
            usings.AddRange(preTrainerTransformsAndUsings.Where(t => t.Item2 != null).SelectMany(t => t.Item2));
            usings.AddRange(postTrainerTransformsAndUsings.Where(t => t.Item2 != null).SelectMany(t => t.Item2));
            usings = usings.Distinct().ToList();
 
            //Combine all using statements to actual text.
            usingsBuilder = new StringBuilder();
            usings.ForEach(t =>
            {
                if (t != null)
                    usingsBuilder.Append(t);
            });
 
            return (usingsBuilder.ToString(), trainerMethod, preTrainerTransforms, postTrainerTransforms);
        }
 
        internal IList<(string, string[])> GenerateTransformsAndUsings(IEnumerable<PipelineNode> nodes)
        {
            //var nodes = pipeline.Nodes.TakeWhile(t => t.NodeType == PipelineNodeType.Transform);
            //var nodes = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Transform);
            var results = new List<(string, string[])>();
            foreach (var node in nodes)
            {
                ITransformGenerator generator = TransformGeneratorFactory.GetInstance(node);
                results.Add((generator.GenerateTransformer(), generator.GenerateUsings()));
            }
 
            return results;
        }
 
        internal (string, string[]) GenerateTrainerAndUsings()
        {
            if (_pipeline == null)
                throw new ArgumentNullException(nameof(_pipeline));
            var node = _pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer).First();
            ITrainerGenerator generator = TrainerGeneratorFactory.GetInstance(node);
            var trainerString = generator.GenerateTrainer();
            var trainerUsings = generator.GenerateUsings();
            return (trainerString, trainerUsings);
        }
 
        /// <summary>
        /// Utilize <see cref="ML.CodeGenerator.Utilities.Utils.GenerateClassLabels(ColumnInferenceResults, IDictionary{string, CodeGeneratorSettings.ColumnMapping})"/>
        /// </summary>
        internal IList<string> GenerateClassLabels(IDictionary<string, CodeGeneratorSettings.ColumnMapping> columnMapping = default)
        {
            return Utils.GenerateClassLabels(_columnInferenceResult, columnMapping);
        }
 
        #region Model project
        private string GenerateModelProjectFileContent(bool includeLightGbmPackage,
            bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeImageTransformerPackage,
                bool includeImageClassificationPackage, bool includeRecommenderPackage, bool includeOnnxModel,
                string stablePackageVersion, string unstablePackageVersion, GenerateTarget target)
        {
            ModelProject modelProject = new ModelProject()
            {
                IncludeLightGBMPackage = includeLightGbmPackage,
                IncludeMklComponentsPackage = includeMklComponentsPackage,
                IncludeFastTreePackage = includeFastTreePackage,
                IncludeImageTransformerPackage = includeImageTransformerPackage,
                IncludeImageClassificationPackage = includeImageClassificationPackage,
                IncludeOnnxModel = includeOnnxModel,
                IncludeRecommenderPackage = includeRecommenderPackage,
                StablePackageVersion = stablePackageVersion,
                UnstablePackageVersion = unstablePackageVersion,
                OnnxRuntimePackageVersion = _settings.OnnxRuntimePackageVersion,
                Target = target,
            };
 
            return modelProject.TransformText();
        }
 
        private string GenerateModelOutputCSFileContent(string predictionLabelType, string namespaceValue)
        {
            ModelOutputClass modelOutputClass = new ModelOutputClass()
            {
                TaskType = _settings.MlTask.ToString(),
                PredictionLabelType = predictionLabelType,
                Namespace = namespaceValue,
                Target = _settings.Target,
            };
            return modelOutputClass.TransformText();
        }
 
        private string GenerateModelInputCSFileContent(string namespaceValue, IList<string> classLabels)
        {
            ModelInputClass modelInputClass = new ModelInputClass()
            {
                Namespace = namespaceValue,
                ClassLabels = classLabels,
                Target = _settings.Target,
            };
            return modelInputClass.TransformText();
        }
        #endregion
 
        #region Predict Project
        private string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage,
            bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeImageTransformerPackage,
                bool includeImageClassificationPackage, bool includeRecommenderPackage,
                bool includeOnnxPackage, bool includeResNet18Package,
                string stablePackageVersion, string unstablePackageVersion, GenerateTarget target)
        {
            var predictProjectFileContent = new PredictProject()
            {
                Namespace = namespaceValue,
                IncludeMklComponentsPackage = includeMklComponentsPackage,
                IncludeLightGBMPackage = includeLightGbmPackage,
                IncludeFastTreePackage = includeFastTreePackage,
                IncludeImageTransformerPackage = includeImageTransformerPackage,
                IncludeImageClassificationPackage = includeImageClassificationPackage,
                IncludeOnnxPackage = includeOnnxPackage,
                IncludeResNet18Package = includeResNet18Package,
                IncludeRecommenderPackage = includeRecommenderPackage,
                StablePackageVersion = stablePackageVersion,
                UnstablePackageVersion = unstablePackageVersion,
                OnnxRuntimePackageVersion = _settings.OnnxRuntimePackageVersion,
                Target = target,
            };
            return predictProjectFileContent.TransformText();
        }
 
        private string GeneratePredictProgramCSFileContent(string namespaceValue)
        {
            var sampleData = Utils.GenerateSampleData(_settings.TrainDataset, _columnInferenceResult);
            PredictProgram predictProgram = new PredictProgram()
            {
                TaskType = _settings.MlTask.ToString(),
                LabelName = _settings.LabelName,
                Namespace = namespaceValue,
                HasHeader = _columnInferenceResult.TextLoaderOptions.HasHeader,
                Separator = _columnInferenceResult.TextLoaderOptions.Separators.FirstOrDefault(),
                AllowQuoting = _columnInferenceResult.TextLoaderOptions.AllowQuoting,
                AllowSparse = _columnInferenceResult.TextLoaderOptions.AllowSparse,
                Target = _settings.Target,
                SampleData = sampleData,
            };
            return predictProgram.TransformText();
        }
 
        private string GenerateModelBuilderCSFileContent(string usings,
            string trainerMethod,
            List<string> preTrainerTransforms,
            List<string> postTrainerTransforms,
            string namespaceValue,
            bool cacheBeforeTrainer,
            string predictionLabelType,
            bool hasOnnxModel = false)
        {
            var modelBuilder = new ModelBuilder()
            {
                PreTrainerTransforms = preTrainerTransforms,
                PostTrainerTransforms = postTrainerTransforms,
                HasHeader = _columnInferenceResult.TextLoaderOptions.HasHeader,
                Separator = _columnInferenceResult.TextLoaderOptions.Separators.FirstOrDefault(),
                AllowQuoting = _columnInferenceResult.TextLoaderOptions.AllowQuoting,
                AllowSparse = _columnInferenceResult.TextLoaderOptions.AllowSparse,
                Trainer = trainerMethod,
                GeneratedUsings = usings,
                Path = _settings.TrainDataset,
                TestPath = _settings.TestDataset,
                TaskType = _settings.MlTask.ToString(),
                Namespace = namespaceValue,
                LabelName = _settings.LabelName,
                CacheBeforeTrainer = cacheBeforeTrainer,
                Target = _settings.Target,
                HasOnnxModel = hasOnnxModel,
                MLNetModelName = _settings.ModelName,
            };
 
            return modelBuilder.TransformText();
        }
        #endregion
    }
}