File: Templates\Console\ModelBuilder.cs
Web Access
Project: src\src\Microsoft.ML.CodeGenerator\Microsoft.ML.CodeGenerator.csproj (Microsoft.ML.CodeGenerator)
// ------------------------------------------------------------------------------
// <auto-generated>
//     This code was generated by a tool.
//     Runtime Version: 16.0.0.0
//  
//     Changes to this file may cause incorrect behavior and will be lost if
//     the code is regenerated.
// </auto-generated>
// ------------------------------------------------------------------------------
namespace Microsoft.ML.CodeGenerator.Templates.Console
{
    using System.Linq;
    using System.Text;
    using System.Text.RegularExpressions;
    using System.Collections.Generic;
    using Microsoft.ML.CodeGenerator.Utilities;
    using System;
    
    /// <summary>
    /// Class to produce the template output
    /// </summary>
    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "16.0.0.0")]
    internal partial class ModelBuilder : ModelBuilderBase
    {
        /// <summary>
        /// Create the template output
        /// </summary>
        public virtual string TransformText()
        {
if(Target == CSharp.GenerateTarget.Cli){ 
CLI_Annotation();
 } else if(Target == CSharp.GenerateTarget.ModelBuilder){ 
MB_Annotation();
 } 
            this.Write("\r\nusing System;\r\nusing System.Collections.Generic;\r\nusing System.IO;\r\nusing Syste" +
                    "m.Linq;\r\nusing Microsoft.ML;\r\nusing Microsoft.ML.Data;\r\nusing ");
            this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
            this.Write(".Model;\r\n");
            this.Write(this.ToStringHelper.ToStringWithCulture(GeneratedUsings));
            this.Write("\r\nnamespace ");
            this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
            this.Write(".ConsoleApp\r\n{\r\n    public static class ModelBuilder\r\n    {\r\n        private stat" +
                    "ic string TRAIN_DATA_FILEPATH = @\"");
            this.Write(this.ToStringHelper.ToStringWithCulture(Path));
            this.Write("\";\r\n");
if(!string.IsNullOrEmpty(TestPath)){ 
            this.Write("        private static string TEST_DATA_FILEPATH = @\"");
            this.Write(this.ToStringHelper.ToStringWithCulture(TestPath));
            this.Write("\";\r\n");
 } 
            this.Write(@"        private static string MODEL_FILE = ConsumeModel.MLNetModelPath;
 
        // Create MLContext to be shared across the model creation workflow objects 
        // Set a random seed for repeatable/deterministic results across multiple trainings.
        private static MLContext mlContext = new MLContext(seed: 1);
 
        public static void CreateModel()
        {
            // Load Data
            IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
                                            path: TRAIN_DATA_FILEPATH,
                                            hasHeader : ");
            this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
            this.Write(",\r\n                                            separatorChar : \'");
            this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString())));
            this.Write("\',\r\n                                            allowQuoting : ");
            this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant()));
            this.Write(",\r\n                                            allowSparse: ");
            this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant()));
            this.Write(");\r\n\r\n");
 if(!string.IsNullOrEmpty(TestPath)){ 
            this.Write("            IDataView testDataView = mlContext.Data.LoadFromTextFile<ModelInput>(" +
                    "\r\n                                            path: TEST_DATA_FILEPATH,\r\n       " +
                    "                                     hasHeader : ");
            this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
            this.Write(",\r\n                                            separatorChar : \'");
            this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString())));
            this.Write("\',\r\n                                            allowQuoting : ");
            this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant()));
            this.Write(",\r\n                                            allowSparse: ");
            this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant()));
            this.Write(");\r\n");
}
            this.Write(@"            // Build training pipeline
            IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);
 
            // Train Model
            ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);
 
");
 if(string.IsNullOrEmpty(TestPath) && !HasOnnxModel){ 
            this.Write("            // Evaluate quality of Model\r\n            Evaluate(mlContext, trainin" +
                    "gDataView, trainingPipeline);\r\n\r\n");
}
 if(!string.IsNullOrEmpty(TestPath) && !HasOnnxModel){ 
            this.Write("            // Evaluate quality of Model\r\n            EvaluateModel(mlContext, ml" +
                    "Model, testDataView);\r\n\r\n");
}
            this.Write("            // Save model\r\n            SaveModel(mlContext, mlModel, MODEL_FILE, " +
                    "trainingDataView.Schema);\r\n        }\r\n\r\n        public static IEstimator<ITransf" +
                    "ormer> BuildTrainingPipeline(MLContext mlContext)\r\n        {\r\n");
 if(PreTrainerTransforms.Count >0 ) {
            this.Write("            // Data process configuration with pipeline data transformations \r\n  " +
                    "          var dataProcessPipeline = ");
 for(int i=0;i<PreTrainerTransforms.Count;i++) 
                                         { 
                                             if(i>0)
                                             { Write("\r\n                                      .Append(");
                                             }
                                             Write("mlContext.Transforms."+PreTrainerTransforms[i]);
                                             if(i>0)
                                             { Write(")");
                                             }
                                         }
                                      if(CacheBeforeTrainer){ 
                                           Write("\r\n                                      .AppendCacheCheckpoint(mlContext)");
                                           } 
            this.Write(";\r\n");
}
 if(Trainer != String.Empty ) {
            this.Write("            // Set the training algorithm \r\n            var trainer = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
if("Recommendation".Equals(TaskType)){ 
            this.Write("()");
 } 
            this.Write(".Trainers.");
            this.Write(this.ToStringHelper.ToStringWithCulture(Trainer));
 for(int i=0;i<PostTrainerTransforms.Count;i++) 
                                         { 
                                             Write("\r\n                                      .Append(");
                                             Write("mlContext.Transforms."+PostTrainerTransforms[i]);
                                             Write(")");
                                         }
            this.Write(";\r\n");
}
            this.Write("\r\n");
 if(PreTrainerTransforms.Count >0 && Trainer != String.Empty ) {
            this.Write("            var trainingPipeline = dataProcessPipeline.Append(trainer);\r\n");
 }
else if (PreTrainerTransforms.Count >0 && Trainer == String.Empty) {
            this.Write("\t\t\tvar trainingPipeline = dataProcessPipeline;\r\n");
 }
else{
            this.Write("            var trainingPipeline = trainer;\r\n");
}
            this.Write(@"
            return trainingPipeline;
        }
 
        public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            Console.WriteLine(""=============== Training  model ==============="");
 
            ITransformer model = trainingPipeline.Fit(trainingDataView);
 
            Console.WriteLine(""=============== End of training process ==============="");
            return model;
        }
 
");
 if(!string.IsNullOrEmpty(TestPath)){ 
            this.Write(@"        private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView)
        {
            // Evaluate the model and show accuracy stats
            Console.WriteLine(""===== Evaluating Model's accuracy with Test data ====="");
            IDataView predictions = mlModel.Transform(testDataView);
");
if("BinaryClassification".Equals(TaskType)){ 
            this.Write("            var metrics = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(".EvaluateNonCalibrated(predictions, \"");
            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
            this.Write("\", \"Score\");\r\n            PrintBinaryClassificationMetrics(metrics);\r\n");
} if("MulticlassClassification".Equals(TaskType)){ 
            this.Write("            var metrics = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(".Evaluate(predictions, \"");
            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
            this.Write("\", \"Score\");\r\n            PrintMulticlassClassificationMetrics(metrics);\r\n");
}if("Regression".Equals(TaskType) || "Recommendation".Equals(TaskType)){ 
            this.Write("            var metrics = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
if("Recommendation".Equals(TaskType)){ 
            this.Write("()");
 } 
            this.Write(".Evaluate(predictions, \"");
            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
            this.Write("\", \"Score\");\r\n            PrintRegressionMetrics(metrics);\r\n");
}if("Ranking".Equals(TaskType)){ 
            this.Write("            var metrics = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(".Evaluate(predictions, \"");
            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
            this.Write("\", \"Score\");\r\n            PrintRankingMetrics(metrics);\r\n");
}
            this.Write("        }\r\n");
}else{
            this.Write(@"        private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            // Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
            // in order to evaluate and get the model's accuracy metrics
            Console.WriteLine(""=============== Cross-validating to get model's accuracy metrics ==============="");
");
if("BinaryClassification".Equals(TaskType)){ 
            this.Write("            var crossValidationResults = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(".CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numberOfFolds: ");
            this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds));
            this.Write(", labelColumnName:\"");
            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
            this.Write("\");\r\n            PrintBinaryClassificationFoldsAverageMetrics(crossValidationResu" +
                    "lts);\r\n");
}
if("MulticlassClassification".Equals(TaskType)){ 
            this.Write("            var crossValidationResults = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(".CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: ");
            this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds));
            this.Write(", labelColumnName:\"");
            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
            this.Write("\");\r\n            PrintMulticlassClassificationFoldsAverageMetrics(crossValidation" +
                    "Results);\r\n");
}
if("Regression".Equals(TaskType)){ 
            this.Write("            var crossValidationResults = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(".CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: ");
            this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds));
            this.Write(", labelColumnName:\"");
            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
            this.Write("\");\r\n            PrintRegressionFoldsAverageMetrics(crossValidationResults);\r\n");
}
if("Ranking".Equals(TaskType)){ 
            this.Write("            var crossValidationResults = mlContext.");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(".CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: ");
            this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds));
            this.Write(", labelColumnName:\"");
            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
            this.Write("\");\r\n            PrintRankingFoldsAverageMetrics(crossValidationResults);\r\n");
}
            this.Write("        }\r\n");
}
            this.Write(@"
        private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
        {
            // Save/persist the trained model to a .ZIP file
            Console.WriteLine($""=============== Saving the model  ==============="");
            mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath));
            Console.WriteLine(""The model is saved to {0}"", GetAbsolutePath(modelRelativePath));
        }
 
        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;
 
            string fullPath = Path.Combine(assemblyFolderPath, relativePath);
 
            return fullPath;
        }
 
");
if("Regression".Equals(TaskType) || "Recommendation".Equals(TaskType)){ 
            this.Write("        public static void PrintRegressionMetrics(RegressionMetrics metrics)\r\n   " +
                    "     {\r\n            Console.WriteLine($\"****************************************" +
                    "*********\");\r\n            Console.WriteLine($\"*       Metrics for ");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(@" model      "");
            Console.WriteLine($""*------------------------------------------------"");
            Console.WriteLine($""*       LossFn:        {metrics.LossFunction:0.##}"");
            Console.WriteLine($""*       R2 Score:      {metrics.RSquared:0.##}"");
            Console.WriteLine($""*       Absolute loss: {metrics.MeanAbsoluteError:#.##}"");
            Console.WriteLine($""*       Squared loss:  {metrics.MeanSquaredError:#.##}"");
            Console.WriteLine($""*       RMS loss:      {metrics.RootMeanSquaredError:#.##}"");
            Console.WriteLine($""*************************************************"");
        }
 
        public static void PrintRegressionFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<RegressionMetrics>> crossValidationResults)
        {
            var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
            var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
            var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
            var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
            var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);
 
            Console.WriteLine($""*************************************************************************************************************"");
            Console.WriteLine($""*       Metrics for ");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(@" model      "");
            Console.WriteLine($""*------------------------------------------------------------------------------------------------------------"");
            Console.WriteLine($""*       Average L1 Loss:       {L1.Average():0.###} "");
            Console.WriteLine($""*       Average L2 Loss:       {L2.Average():0.###}  "");
            Console.WriteLine($""*       Average RMS:           {RMS.Average():0.###}  "");
            Console.WriteLine($""*       Average Loss Function: {lossFunction.Average():0.###}  "");
            Console.WriteLine($""*       Average R-squared:     {R2.Average():0.###}  "");
            Console.WriteLine($""*************************************************************************************************************"");
        }
");
 } if("BinaryClassification".Equals(TaskType)){ 
            this.Write("        public static void PrintBinaryClassificationMetrics(BinaryClassificationM" +
                    "etrics metrics)\r\n        {\r\n            Console.WriteLine($\"********************" +
                    "****************************************\");\r\n            Console.WriteLine($\"*  " +
                    "     Metrics for binary classification model      \");\r\n            Console.Write" +
                    "Line($\"*-----------------------------------------------------------\");\r\n        " +
                    "    Console.WriteLine($\"*       Accuracy: {metrics.Accuracy:P2}\");\r\n            " +
                    "Console.WriteLine($\"*       Auc:      {metrics.AreaUnderRocCurve:P2}\");\r\n       " +
                    "     Console.WriteLine($\"*******************************************************" +
                    "*****\");\r\n        }\r\n\r\n\r\n        public static void PrintBinaryClassificationFol" +
                    "dsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<BinaryClassi" +
                    "ficationMetrics>> crossValResults)\r\n        {\r\n            var metricsInMultiple" +
                    "Folds = crossValResults.Select(r => r.Metrics);\r\n\r\n            var AccuracyValue" +
                    "s = metricsInMultipleFolds.Select(m => m.Accuracy);\r\n            var AccuracyAve" +
                    "rage = AccuracyValues.Average();\r\n            var AccuraciesStdDeviation = Calcu" +
                    "lateStandardDeviation(AccuracyValues);\r\n            var AccuraciesConfidenceInte" +
                    "rval95 = CalculateConfidenceInterval95(AccuracyValues);\r\n\r\n\r\n            Console" +
                    ".WriteLine($\"*******************************************************************" +
                    "******************************************\");\r\n            Console.WriteLine($\"*" +
                    "       Metrics for Binary Classification model      \");\r\n            Console.Wri" +
                    "teLine($\"*----------------------------------------------------------------------" +
                    "--------------------------------------\");\r\n            Console.WriteLine($\"*    " +
                    "   Average Accuracy:    {AccuracyAverage:0.###}  - Standard deviation: ({Accurac" +
                    "iesStdDeviation:#.###})  - Confidence Interval 95%: ({AccuraciesConfidenceInterv" +
                    "al95:#.###})\");\r\n            Console.WriteLine($\"*******************************" +
                    "******************************************************************************\")" +
                    ";\r\n        }\r\n\r\n        public static double CalculateStandardDeviation(IEnumera" +
                    "ble<double> values)\r\n        {\r\n            double average = values.Average();\r\n" +
                    "            double sumOfSquaresOfDifferences = values.Select(val => (val - avera" +
                    "ge) * (val - average)).Sum();\r\n            double standardDeviation = Math.Sqrt(" +
                    "sumOfSquaresOfDifferences / (values.Count() - 1));\r\n            return standardD" +
                    "eviation;\r\n        }\r\n\r\n        public static double CalculateConfidenceInterval" +
                    "95(IEnumerable<double> values)\r\n        {\r\n            double confidenceInterval" +
                    "95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1))" +
                    ";\r\n            return confidenceInterval95;\r\n        }\r\n");
} if("MulticlassClassification".Equals(TaskType)){
            this.Write("        public static void PrintMulticlassClassificationMetrics(MulticlassClassif" +
                    "icationMetrics metrics)\r\n        {\r\n            Console.WriteLine($\"************" +
                    "************************************************\");\r\n            Console.WriteLi" +
                    "ne($\"*    Metrics for multi-class classification model   \");\r\n            Consol" +
                    "e.WriteLine($\"*-----------------------------------------------------------\");\r\n " +
                    "           Console.WriteLine($\"    MacroAccuracy = {metrics.MacroAccuracy:0.####" +
                    "}, a value between 0 and 1, the closer to 1, the better\");\r\n            Console." +
                    "WriteLine($\"    MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between " +
                    "0 and 1, the closer to 1, the better\");\r\n            Console.WriteLine($\"    Log" +
                    "Loss = {metrics.LogLoss:0.####}, the closer to 0, the better\");\r\n            for" +
                    " (int i = 0; i < metrics.PerClassLogLoss.Count; i++)\r\n            {\r\n           " +
                    "     Console.WriteLine($\"    LogLoss for class {i + 1} = {metrics.PerClassLogLos" +
                    "s[i]:0.####}, the closer to 0, the better\");\r\n            }\r\n            Console" +
                    ".WriteLine($\"************************************************************\");\r\n  " +
                    "      }\r\n\r\n        public static void PrintMulticlassClassificationFoldsAverageM" +
                    "etrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificati" +
                    "onMetrics>> crossValResults)\r\n        {\r\n            var metricsInMultipleFolds " +
                    "= crossValResults.Select(r => r.Metrics);\r\n\r\n            var microAccuracyValues" +
                    " = metricsInMultipleFolds.Select(m => m.MicroAccuracy);\r\n            var microAc" +
                    "curacyAverage = microAccuracyValues.Average();\r\n            var microAccuraciesS" +
                    "tdDeviation = CalculateStandardDeviation(microAccuracyValues);\r\n            var " +
                    "microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccurac" +
                    "yValues);\r\n\r\n            var macroAccuracyValues = metricsInMultipleFolds.Select" +
                    "(m => m.MacroAccuracy);\r\n            var macroAccuracyAverage = macroAccuracyVal" +
                    "ues.Average();\r\n            var macroAccuraciesStdDeviation = CalculateStandardD" +
                    "eviation(macroAccuracyValues);\r\n            var macroAccuraciesConfidenceInterva" +
                    "l95 = CalculateConfidenceInterval95(macroAccuracyValues);\r\n\r\n            var log" +
                    "LossValues = metricsInMultipleFolds.Select(m => m.LogLoss);\r\n            var log" +
                    "LossAverage = logLossValues.Average();\r\n            var logLossStdDeviation = Ca" +
                    "lculateStandardDeviation(logLossValues);\r\n            var logLossConfidenceInter" +
                    "val95 = CalculateConfidenceInterval95(logLossValues);\r\n\r\n            var logLoss" +
                    "ReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);\r\n     " +
                    "       var logLossReductionAverage = logLossReductionValues.Average();\r\n        " +
                    "    var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReducti" +
                    "onValues);\r\n            var logLossReductionConfidenceInterval95 = CalculateConf" +
                    "idenceInterval95(logLossReductionValues);\r\n\r\n            Console.WriteLine($\"***" +
                    "********************************************************************************" +
                    "**************************\");\r\n            Console.WriteLine($\"*       Metrics f" +
                    "or Multi-class Classification model      \");\r\n            Console.WriteLine($\"*-" +
                    "--------------------------------------------------------------------------------" +
                    "---------------------------\");\r\n            Console.WriteLine($\"*       Average " +
                    "MicroAccuracy:    {microAccuracyAverage:0.###}  - Standard deviation: ({microAcc" +
                    "uraciesStdDeviation:#.###})  - Confidence Interval 95%: ({microAccuraciesConfide" +
                    "nceInterval95:#.###})\");\r\n            Console.WriteLine($\"*       Average MacroA" +
                    "ccuracy:    {macroAccuracyAverage:0.###}  - Standard deviation: ({macroAccuracie" +
                    "sStdDeviation:#.###})  - Confidence Interval 95%: ({macroAccuraciesConfidenceInt" +
                    "erval95:#.###})\");\r\n            Console.WriteLine($\"*       Average LogLoss:    " +
                    "      {logLossAverage:#.###}  - Standard deviation: ({logLossStdDeviation:#.###}" +
                    ")  - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})\");\r\n        " +
                    "    Console.WriteLine($\"*       Average LogLossReduction: {logLossReductionAvera" +
                    "ge:#.###}  - Standard deviation: ({logLossReductionStdDeviation:#.###})  - Confi" +
                    "dence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})\");\r\n         " +
                    "   Console.WriteLine($\"*********************************************************" +
                    "****************************************************\");\r\n\r\n        }\r\n\r\n        " +
                    "public static double CalculateStandardDeviation(IEnumerable<double> values)\r\n   " +
                    "     {\r\n            double average = values.Average();\r\n            double sumOf" +
                    "SquaresOfDifferences = values.Select(val => (val - average) * (val - average)).S" +
                    "um();\r\n            double standardDeviation = Math.Sqrt(sumOfSquaresOfDifference" +
                    "s / (values.Count() - 1));\r\n            return standardDeviation;\r\n        }\r\n\r\n" +
                    "        public static double CalculateConfidenceInterval95(IEnumerable<double> v" +
                    "alues)\r\n        {\r\n            double confidenceInterval95 = 1.96 * CalculateSta" +
                    "ndardDeviation(values) / Math.Sqrt((values.Count() - 1));\r\n            return co" +
                    "nfidenceInterval95;\r\n        }\r\n");
} if("Ranking".Equals(TaskType)){ 
            this.Write("        public static void PrintRankingMetrics(RankingMetrics metrics)\r\n        {" +
                    "\r\n            Console.WriteLine($\"**********************************************" +
                    "***\");\r\n            Console.WriteLine($\"*       Metrics for ");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(@" model      "");
            var max = (metrics.NormalizedDiscountedCumulativeGains.Count < 10) ? metrics.NormalizedDiscountedCumulativeGains.Count-1 : 9;
            Console.WriteLine($""*------------------------------------------------"");
            Console.WriteLine($""*       Normalized Discounted Cumulative Gains @10:        {metrics.NormalizedDiscountedCumulativeGains[max]:0.##}"");
            Console.WriteLine($""*       Discounted Cumulative Gains @10:      {metrics.DiscountedCumulativeGains[max]:#.##}"");
            Console.WriteLine($""*************************************************"");
        }
 
        public static void PrintRankingFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<RankingMetrics>> crossValidationResults)
        {
            var max = (crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count < 10) ? crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count-1 : 9;
            var NDCG = crossValidationResults.Select(r => r.Metrics.NormalizedDiscountedCumulativeGains[max]);
            var DCG = crossValidationResults.Select(r => r.Metrics.DiscountedCumulativeGains[max]);
            Console.WriteLine($""*************************************************************************************************************"");
            Console.WriteLine($""*       Metrics for ");
            this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
            this.Write(@" model      "");
            Console.WriteLine($""*------------------------------------------------------------------------------------------------------------"");
            Console.WriteLine($""*       Average Normalized Discounted Cumulative Gains @10:       {NDCG.Average():0.###}"");
            Console.WriteLine($""*       Average Discounted Cumulative Gains @10:       {DCG.Average():#.###}"");
            Console.WriteLine($""*************************************************************************************************************"");
        }
");
 }
            this.Write("    }\r\n}\r\n");
            return this.GenerationEnvironment.ToString();
        }
 
public string Path {get;set;}
public string TestPath {get;set;}
public bool HasHeader {get;set;}
public char Separator {get;set;}
public IList<string> PreTrainerTransforms {get;set;}
public string Trainer {get;set;}
public string TaskType {get;set;}
public string GeneratedUsings {get;set;}
public bool AllowQuoting {get;set;}
public bool AllowSparse {get;set;}
public int Kfolds {get;set;} = 5;
public string Namespace {get;set;}
public string LabelName {get;set;}
public bool CacheBeforeTrainer {get;set;}
public IList<string> PostTrainerTransforms {get;set;}
internal CSharp.GenerateTarget Target {get;set;}
public bool HasOnnxModel {get;set;} = false;
public string MLNetModelName {get; set;}
 
 
void CLI_Annotation()
{
this.Write(@"//*****************************************************************************************
//*                                                                                       *
//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. *
//*                                                                                       *
//*****************************************************************************************
");
 
 
}
 
 
void MB_Annotation()
{
this.Write("// This file was auto-generated by ML.NET Model Builder. \r\n");
 
 
}
 
    }
    #region Base class
    /// <summary>
    /// Base class for this transformation
    /// </summary>
    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "16.0.0.0")]
    internal class ModelBuilderBase
    {
        #region Fields
        private global::System.Text.StringBuilder generationEnvironmentField;
        private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField;
        private global::System.Collections.Generic.List<int> indentLengthsField;
        private string currentIndentField = "";
        private bool endsWithNewline;
        private global::System.Collections.Generic.IDictionary<string, object> sessionField;
        #endregion
        #region Properties
        /// <summary>
        /// The string builder that generation-time code is using to assemble generated output
        /// </summary>
        protected System.Text.StringBuilder GenerationEnvironment
        {
            get
            {
                if ((this.generationEnvironmentField == null))
                {
                    this.generationEnvironmentField = new global::System.Text.StringBuilder();
                }
                return this.generationEnvironmentField;
            }
            set
            {
                this.generationEnvironmentField = value;
            }
        }
        /// <summary>
        /// The error collection for the generation process
        /// </summary>
        public System.CodeDom.Compiler.CompilerErrorCollection Errors
        {
            get
            {
                if ((this.errorsField == null))
                {
                    this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection();
                }
                return this.errorsField;
            }
        }
        /// <summary>
        /// A list of the lengths of each indent that was added with PushIndent
        /// </summary>
        private System.Collections.Generic.List<int> indentLengths
        {
            get
            {
                if ((this.indentLengthsField == null))
                {
                    this.indentLengthsField = new global::System.Collections.Generic.List<int>();
                }
                return this.indentLengthsField;
            }
        }
        /// <summary>
        /// Gets the current indent we use when adding lines to the output
        /// </summary>
        public string CurrentIndent
        {
            get
            {
                return this.currentIndentField;
            }
        }
        /// <summary>
        /// Current transformation session
        /// </summary>
        public virtual global::System.Collections.Generic.IDictionary<string, object> Session
        {
            get
            {
                return this.sessionField;
            }
            set
            {
                this.sessionField = value;
            }
        }
        #endregion
        #region Transform-time helpers
        /// <summary>
        /// Write text directly into the generated output
        /// </summary>
        public void Write(string textToAppend)
        {
            if (string.IsNullOrEmpty(textToAppend))
            {
                return;
            }
            // If we're starting off, or if the previous text ended with a newline,
            // we have to append the current indent first.
            if (((this.GenerationEnvironment.Length == 0) 
                        || this.endsWithNewline))
            {
                this.GenerationEnvironment.Append(this.currentIndentField);
                this.endsWithNewline = false;
            }
            // Check if the current text ends with a newline
            if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture))
            {
                this.endsWithNewline = true;
            }
            // This is an optimization. If the current indent is "", then we don't have to do any
            // of the more complex stuff further down.
            if ((this.currentIndentField.Length == 0))
            {
                this.GenerationEnvironment.Append(textToAppend);
                return;
            }
            // Everywhere there is a newline in the text, add an indent after it
            textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField));
            // If the text ends with a newline, then we should strip off the indent added at the very end
            // because the appropriate indent will be added when the next time Write() is called
            if (this.endsWithNewline)
            {
                this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length));
            }
            else
            {
                this.GenerationEnvironment.Append(textToAppend);
            }
        }
        /// <summary>
        /// Write text directly into the generated output
        /// </summary>
        public void WriteLine(string textToAppend)
        {
            this.Write(textToAppend);
            this.GenerationEnvironment.AppendLine();
            this.endsWithNewline = true;
        }
        /// <summary>
        /// Write formatted text directly into the generated output
        /// </summary>
        public void Write(string format, params object[] args)
        {
            this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
        }
        /// <summary>
        /// Write formatted text directly into the generated output
        /// </summary>
        public void WriteLine(string format, params object[] args)
        {
            this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
        }
        /// <summary>
        /// Raise an error
        /// </summary>
        public void Error(string message)
        {
            System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
            error.ErrorText = message;
            this.Errors.Add(error);
        }
        /// <summary>
        /// Raise a warning
        /// </summary>
        public void Warning(string message)
        {
            System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
            error.ErrorText = message;
            error.IsWarning = true;
            this.Errors.Add(error);
        }
        /// <summary>
        /// Increase the indent
        /// </summary>
        public void PushIndent(string indent)
        {
            if ((indent == null))
            {
                throw new global::System.ArgumentNullException("indent");
            }
            this.currentIndentField = (this.currentIndentField + indent);
            this.indentLengths.Add(indent.Length);
        }
        /// <summary>
        /// Remove the last indent that was added with PushIndent
        /// </summary>
        public string PopIndent()
        {
            string returnValue = "";
            if ((this.indentLengths.Count > 0))
            {
                int indentLength = this.indentLengths[(this.indentLengths.Count - 1)];
                this.indentLengths.RemoveAt((this.indentLengths.Count - 1));
                if ((indentLength > 0))
                {
                    returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength));
                    this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength));
                }
            }
            return returnValue;
        }
        /// <summary>
        /// Remove any indentation
        /// </summary>
        public void ClearIndent()
        {
            this.indentLengths.Clear();
            this.currentIndentField = "";
        }
        #endregion
        #region ToString Helpers
        /// <summary>
        /// Utility class to produce culture-oriented representation of an object as a string.
        /// </summary>
        public class ToStringInstanceHelper
        {
            private System.IFormatProvider formatProviderField  = global::System.Globalization.CultureInfo.InvariantCulture;
            /// <summary>
            /// Gets or sets format provider to be used by ToStringWithCulture method.
            /// </summary>
            public System.IFormatProvider FormatProvider
            {
                get
                {
                    return this.formatProviderField ;
                }
                set
                {
                    if ((value != null))
                    {
                        this.formatProviderField  = value;
                    }
                }
            }
            /// <summary>
            /// This is called from the compile/run appdomain to convert objects within an expression block to a string
            /// </summary>
            public string ToStringWithCulture(object objectToConvert)
            {
                if ((objectToConvert == null))
                {
                    throw new global::System.ArgumentNullException("objectToConvert");
                }
                System.Type t = objectToConvert.GetType();
                System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] {
                            typeof(System.IFormatProvider)});
                if ((method == null))
                {
                    return objectToConvert.ToString();
                }
                else
                {
                    return ((string)(method.Invoke(objectToConvert, new object[] {
                                this.formatProviderField })));
                }
            }
        }
        private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper();
        /// <summary>
        /// Helper to produce culture-oriented representation of an object as a string
        /// </summary>
        public ToStringInstanceHelper ToStringHelper
        {
            get
            {
                return this.toStringHelperField;
            }
        }
        #endregion
    }
    #endregion
}