// File: Program.cs
// Web Access
// Project: src\docs\samples\Microsoft.ML.Samples.OneDal\Microsoft.ML.Samples.OneDal.csproj (Microsoft.ML.Samples.OneDal)
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Newtonsoft.Json;
 
namespace Microsoft.ML.Samples.OneDal
{
    class Program
    {
        /// <summary>
        /// Builds a text loader from the header line of <paramref name="trainingFile"/> and uses it
        /// to load both the training file and the testing file.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to create the text loader.</param>
        /// <param name="trainingFile">Path of the delimited training file; its first line supplies the column names.</param>
        /// <param name="testingFile">Path of the delimited testing file, loaded with the same column layout.</param>
        /// <param name="task">Machine learning task. "binary" makes the label column Boolean; any other value keeps it Single.</param>
        /// <param name="label">Name of the label column as it appears in the header.</param>
        /// <param name="separator">Character that separates columns in both files.</param>
        /// <returns>A two-element array: index 0 is the training data, index 1 is the testing data.</returns>
        /// <exception cref="System.IO.InvalidDataException">The training file contains no header line.</exception>
        public static IDataView[] LoadData(
            MLContext mlContext, string trainingFile, string testingFile,
            string task, string label = "target", char separator = ',')
        {
            string header;
            // The using block closes the reader even when ReadLine throws.
            using (var file = new System.IO.StreamReader(trainingFile))
            {
                header = file.ReadLine();
            }

            // ReadLine returns null for an empty file; fail with a clear message instead of
            // a NullReferenceException on the Split call below.
            if (header == null)
                throw new System.IO.InvalidDataException($"Training file '{trainingFile}' is empty; a header line is required.");

            string[] headerArray = header.Split(separator);
            var columns = new TextLoader.Column[headerArray.Length];
            for (int i = 0; i < headerArray.Length; i++)
            {
                // The loop position is the source column index. Looking the name up again with
                // Array.IndexOf would map every duplicated column name to its first occurrence.
                bool isBinaryLabel = headerArray[i] == label && task == "binary";
                DataKind kind = isBinaryLabel ? DataKind.Boolean : DataKind.Single;
                columns[i] = new TextLoader.Column(headerArray[i], kind, i);
            }

            var loader = mlContext.Data.CreateTextLoader(
                separatorChar: separator,
                hasHeader: true,
                columns: columns
            );
            return new[] { loader.Load(trainingFile), loader.Load(testingFile) };
        }
 
        /// <summary>
        /// Collects the names of all columns in the schema of <paramref name="data"/> except the label column.
        /// </summary>
        /// <param name="data">Data view whose schema is inspected.</param>
        /// <param name="labelName">Column name to leave out of the result.</param>
        /// <returns>The remaining column names, in schema order.</returns>
        public static string[] GetFeaturesArray(IDataView data, string labelName = "target")
        {
            // Enumerating through LINQ replaces the hand-driven enumerator, whose MoveNext
            // result was ignored and which was never disposed.
            return data.Schema
                .Select(column => column.Name)
                .Where(name => name != labelName)
                .ToArray();
        }
 
        /// <summary>
        /// Trains a FastForest binary classifier on the training data and evaluates it on both subsets.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to build the pipeline, trainer and evaluator.</param>
        /// <param name="trainingData">Data the model is fitted on.</param>
        /// <param name="testingData">Data used only for evaluation.</param>
        /// <param name="labelName">Name of the label column; every other column is used as a feature.</param>
        /// <param name="numberOfTrees">Value assigned to the trainer's NumberOfTrees option.</param>
        /// <param name="numberOfLeaves">Value assigned to the trainer's NumberOfLeaves option.</param>
        /// <returns>
        /// Four values: training accuracy, testing accuracy, training F1 score, testing F1 score.
        /// </returns>
        public static double[] RunRandomForestClassification(MLContext mlContext, IDataView trainingData, IDataView testingData, string labelName, int numberOfTrees, int numberOfLeaves)
        {
            var featuresArray = GetFeaturesArray(trainingData, labelName);
            var preprocessingPipeline = mlContext.Transforms.Concatenate("Features", featuresArray);

            // Fit the preprocessing step once and reuse the fitted transformer for both subsets
            // instead of fitting the same pipeline on the training data twice.
            var preprocessingModel = preprocessingPipeline.Fit(trainingData);
            var preprocessedTrainingData = preprocessingModel.Transform(trainingData);
            var preprocessedTestingData = preprocessingModel.Transform(testingData);

            var options = new FastForestBinaryTrainer.Options
            {
                LabelColumnName = labelName,
                FeatureColumnName = "Features",
                NumberOfTrees = numberOfTrees,
                NumberOfLeaves = numberOfLeaves,
                MinimumExampleCountPerLeaf = 5,
                FeatureFraction = 1.0
            };

            var trainer = mlContext.BinaryClassification.Trainers.FastForest(options);
            ITransformer model = trainer.Fit(preprocessedTrainingData);

            IDataView trainingPredictions = model.Transform(preprocessedTrainingData);
            var trainingMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(trainingPredictions, labelColumnName: labelName);
            IDataView testingPredictions = model.Transform(preprocessedTestingData);
            var testingMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(testingPredictions, labelColumnName: labelName);

            return new[]
            {
                trainingMetrics.Accuracy,
                testingMetrics.Accuracy,
                trainingMetrics.F1Score,
                testingMetrics.F1Score
            };
        }
 
        /// <summary>
        /// Trains a FastForest regression model on the training data and evaluates it on both subsets.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to build the pipeline, trainer and evaluator.</param>
        /// <param name="trainingData">Data the model is fitted on.</param>
        /// <param name="testingData">Data used only for evaluation.</param>
        /// <param name="labelName">Name of the label column; every other column is used as a feature.</param>
        /// <param name="numberOfTrees">Value assigned to the trainer's NumberOfTrees option.</param>
        /// <param name="numberOfLeaves">Value assigned to the trainer's NumberOfLeaves option.</param>
        /// <returns>
        /// Four values: training RMSE, testing RMSE, training R-squared, testing R-squared.
        /// </returns>
        public static double[] RunRandomForestRegression(MLContext mlContext, IDataView trainingData, IDataView testingData, string labelName, int numberOfTrees, int numberOfLeaves)
        {
            var featuresArray = GetFeaturesArray(trainingData, labelName);
            var preprocessingPipeline = mlContext.Transforms.Concatenate("Features", featuresArray);

            // Fit the preprocessing step once and reuse the fitted transformer for both subsets
            // instead of fitting the same pipeline on the training data twice.
            var preprocessingModel = preprocessingPipeline.Fit(trainingData);
            var preprocessedTrainingData = preprocessingModel.Transform(trainingData);
            var preprocessedTestingData = preprocessingModel.Transform(testingData);

            var options = new FastForestRegressionTrainer.Options
            {
                LabelColumnName = labelName,
                FeatureColumnName = "Features",
                NumberOfTrees = numberOfTrees,
                NumberOfLeaves = numberOfLeaves,
                MinimumExampleCountPerLeaf = 5,
                FeatureFraction = 1.0
            };

            var trainer = mlContext.Regression.Trainers.FastForest(options);
            ITransformer model = trainer.Fit(preprocessedTrainingData);

            IDataView trainingPredictions = model.Transform(preprocessedTrainingData);
            var trainingMetrics = mlContext.Regression.Evaluate(trainingPredictions, labelColumnName: labelName);
            IDataView testingPredictions = model.Transform(preprocessedTestingData);
            var testingMetrics = mlContext.Regression.Evaluate(testingPredictions, labelColumnName: labelName);

            return new[]
            {
                trainingMetrics.RootMeanSquaredError,
                testingMetrics.RootMeanSquaredError,
                trainingMetrics.RSquared,
                testingMetrics.RSquared
            };
        }
 
        /// <summary>
        /// Trains an OLS regression model (<c>OlsTrainer</c>) on the training data and evaluates it on both subsets.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to build the pipeline, trainer and evaluator.</param>
        /// <param name="trainingData">Data the model is fitted on.</param>
        /// <param name="testingData">Data used only for evaluation.</param>
        /// <param name="labelName">Name of the label column; every other column is used as a feature.</param>
        /// <returns>
        /// Four values: training RMSE, testing RMSE, training R-squared, testing R-squared.
        /// </returns>
        public static double[] RunOLSRegression(MLContext mlContext, IDataView trainingData, IDataView testingData, string labelName)
        {
            var featuresArray = GetFeaturesArray(trainingData, labelName);
            var preprocessingPipeline = mlContext.Transforms.Concatenate("Features", featuresArray);

            // Fit the preprocessing step once and reuse the fitted transformer for both subsets
            // instead of fitting the same pipeline on the training data twice.
            var preprocessingModel = preprocessingPipeline.Fit(trainingData);
            var preprocessedTrainingData = preprocessingModel.Transform(trainingData);
            var preprocessedTestingData = preprocessingModel.Transform(testingData);

            var options = new OlsTrainer.Options
            {
                LabelColumnName = labelName,
                FeatureColumnName = "Features"
            };

            var trainer = mlContext.Regression.Trainers.Ols(options);
            ITransformer model = trainer.Fit(preprocessedTrainingData);

            IDataView trainingPredictions = model.Transform(preprocessedTrainingData);
            var trainingMetrics = mlContext.Regression.Evaluate(trainingPredictions, labelColumnName: labelName);
            IDataView testingPredictions = model.Transform(preprocessedTestingData);
            var testingMetrics = mlContext.Regression.Evaluate(testingPredictions, labelColumnName: labelName);

            return new[]
            {
                trainingMetrics.RootMeanSquaredError,
                testingMetrics.RootMeanSquaredError,
                trainingMetrics.RSquared,
                testingMetrics.RSquared
            };
        }
 
        /// <summary>
        /// Entry point: loads the two data files, runs the requested algorithm and prints a CSV
        /// header line followed by one CSV line with the elapsed time and four metrics.
        /// </summary>
        /// <param name="args">Command-line arguments, laid out as described in the comment below.</param>
        static void Main(string[] args)
        {
            // args[0] - training data filename
            // args[1] - testing data filename
            // args[2] - machine learning task (regression, binary)
            // args[3] - machine learning algorithm (RandomForest, OLS)
            // Random Forest parameters (read only when args[3] is RandomForest):
            //     args[4] - NumberOfTrees
            //     args[5] - NumberOfLeaves
            const string usage = "Usage: <trainingFile> <testingFile> <regression|binary> <RandomForest|OLS> [<numberOfTrees> <numberOfLeaves>]";
            if (args.Length < 4)
            {
                Console.Error.WriteLine(usage);
                Environment.ExitCode = 1;
                return;
            }

            string task = args[2];
            string algorithm = args[3];
            // Numbers are parsed and printed with the invariant culture so the CSV output keeps
            // '.' as the decimal separator regardless of the machine's regional settings.
            var invariant = System.Globalization.CultureInfo.InvariantCulture;

            int numberOfTrees = 0;
            int numberOfLeaves = 0;
            if (algorithm == "RandomForest")
            {
                if (args.Length < 6
                    || !int.TryParse(args[4], System.Globalization.NumberStyles.Integer, invariant, out numberOfTrees)
                    || !int.TryParse(args[5], System.Globalization.NumberStyles.Integer, invariant, out numberOfLeaves))
                {
                    Console.Error.WriteLine("RandomForest requires two integer arguments: <numberOfTrees> <numberOfLeaves>.");
                    Console.Error.WriteLine(usage);
                    Environment.ExitCode = 1;
                    return;
                }
            }
            else if (algorithm != "OLS")
            {
                // Previously an unknown algorithm fell through silently and printed nothing.
                Console.Error.WriteLine($"Unknown algorithm '{algorithm}'; expected RandomForest or OLS.");
                Environment.ExitCode = 1;
                return;
            }

            var mlContext = new MLContext(seed: 42);
            // data[0] - training subset
            // data[1] - testing subset
            IDataView[] data = LoadData(mlContext, args[0], args[1], task);
            string labelName = "target";

            var mainWatch = System.Diagnostics.Stopwatch.StartNew();
            double[] metrics;
            string header;
            string algorithmLabel;
            if (algorithm == "RandomForest")
            {
                if (task == "binary")
                {
                    metrics = RunRandomForestClassification(mlContext, data[0], data[1], labelName, numberOfTrees, numberOfLeaves);
                    header = "algorithm,all workflow time[ms],training accuracy,testing accuracy,training F1 score,testing F1 score";
                    algorithmLabel = "Random Forest Binary";
                }
                else
                {
                    metrics = RunRandomForestRegression(mlContext, data[0], data[1], labelName, numberOfTrees, numberOfLeaves);
                    header = "algorithm,all workflow time[ms],training RMSE,testing RMSE,training R2 score,testing R2 score";
                    algorithmLabel = "Random Forest Regression";
                }
            }
            else
            {
                // NOTE(review): with task "binary" LoadData types the label as Boolean while this
                // path uses a regression trainer - presumably OLS is only meant for "regression"; confirm.
                metrics = RunOLSRegression(mlContext, data[0], data[1], labelName);
                header = "algorithm,all workflow time[ms],training RMSE,testing RMSE,training R2 score,testing R2 score";
                algorithmLabel = "OLS Regression";
            }
            mainWatch.Stop();

            Console.WriteLine(header);
            Console.WriteLine(string.Format(
                invariant,
                "{0},{1},{2},{3},{4},{5}",
                algorithmLabel,
                mainWatch.Elapsed.TotalMilliseconds,
                metrics[0],
                metrics[1],
                metrics[2],
                metrics[3]));
        }
    }
}