File: PredictionEngineBench.cs
Web Access
Project: src\test\Microsoft.ML.PerformanceTests\Microsoft.ML.PerformanceTests.csproj (Microsoft.ML.PerformanceTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using BenchmarkDotNet.Attributes;
using Microsoft.ML.Data;
using Microsoft.ML.PerformanceTests.Harness;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
 
namespace Microsoft.ML.PerformanceTests
{
    /// <summary>
    /// Benchmarks single-row scoring through <c>PredictionEngine</c> for three small pipelines:
    /// Iris (multiclass SDCA maximum entropy), sentiment (text featurization + non-calibrated SDCA)
    /// and breast cancer (non-calibrated SDCA on a 9-element feature vector).
    /// </summary>
    /// <remarks>
    /// Each benchmark has a targeted <c>[GlobalSetup]</c> that trains its model and builds the
    /// prediction engine once, so training time is excluded from the measurement. A matching
    /// targeted <c>[GlobalCleanup]</c> disposes the engine afterwards. Every MLContext is seeded
    /// with 1 and every trainer runs with a single thread.
    /// </remarks>
    [CIBenchmark]
    public class PredictionEngineBench : BenchmarkBase
    {
        // Number of single-row predictions per benchmark invocation; the reported time covers all of them.
        private const int PredictionCount = 10000;

        private IrisData _irisExample;
        private PredictionEngine<IrisData, IrisPrediction> _irisModel;

        private SentimentData _sentimentExample;
        private PredictionEngine<SentimentData, SentimentPrediction> _sentimentModel;

        private BreastCancerData _breastCancerExample;
        private PredictionEngine<BreastCancerData, BreastCancerPrediction> _breastCancerModel;

        /// <summary>
        /// Trains the Iris multiclass pipeline on iris.txt and creates the engine used by
        /// <see cref="MakeIrisPredictions"/>.
        /// </summary>
        [GlobalSetup(Target = nameof(MakeIrisPredictions))]
        public void SetupIrisPipeline()
        {
            // NOTE(review): PetalWidth = 5.1 is much larger than the other measurements, so these
            // values may be shifted by one field. They only change which class is predicted, which
            // presumably does not affect the timed code path; left as-is to keep the benchmark input stable.
            _irisExample = new IrisData()
            {
                SepalLength = 3.3f,
                SepalWidth = 1.6f,
                PetalLength = 0.2f,
                PetalWidth = 5.1f,
            };

            string irisDataPath = GetBenchmarkDataPathAndEnsureData("iris.txt");

            var mlContext = new MLContext(seed: 1);

            // Column 0 holds the label, columns 1-4 the four measurements; the file has a header row.
            var options = new TextLoader.Options()
            {
                Columns = new[]
                {
                    new TextLoader.Column("Label", DataKind.Single, 0),
                    new TextLoader.Column("SepalLength", DataKind.Single, 1),
                    new TextLoader.Column("SepalWidth", DataKind.Single, 2),
                    new TextLoader.Column("PetalLength", DataKind.Single, 3),
                    new TextLoader.Column("PetalWidth", DataKind.Single, 4),
                },
                HasHeader = true,
            };
            var loader = new TextLoader(mlContext, options: options);

            IDataView data = loader.Load(irisDataPath);

            // Concatenate the measurements into "Features" and convert the float label to a key,
            // as the multiclass trainer requires.
            var pipeline = new ColumnConcatenatingEstimator(mlContext, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" })
                .Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
                .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(
                    new SdcaMaximumEntropyMulticlassTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, }));

            var model = pipeline.Fit(data);

            _irisModel = mlContext.Model.CreatePredictionEngine<IrisData, IrisPrediction>(model);
        }

        /// <summary>
        /// Trains a text-featurization + binary SDCA pipeline on wikipedia-detox-250-line-data.tsv
        /// and creates the engine used by <see cref="MakeSentimentPredictions"/>.
        /// </summary>
        [GlobalSetup(Target = nameof(MakeSentimentPredictions))]
        public void SetupSentimentPipeline()
        {
            _sentimentExample = new SentimentData()
            {
                SentimentText = "Not a big fan of this."
            };

            string sentimentDataPath = GetBenchmarkDataPathAndEnsureData("wikipedia-detox-250-line-data.tsv");

            var mlContext = new MLContext(seed: 1);

            // Column 0 is the boolean label, column 1 the raw text; the file has a header row.
            var options = new TextLoader.Options()
            {
                Columns = new[]
                {
                    new TextLoader.Column("Label", DataKind.Boolean, 0),
                    new TextLoader.Column("SentimentText", DataKind.String, 1)
                },
                HasHeader = true,
            };
            var loader = new TextLoader(mlContext, options: options);

            IDataView data = loader.Load(sentimentDataPath);

            var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
                .Append(mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(
                    new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, }));

            var model = pipeline.Fit(data);

            _sentimentModel = mlContext.Model.CreatePredictionEngine<SentimentData, SentimentPrediction>(model);
        }

        /// <summary>
        /// Trains the binary SDCA trainer directly on breast-cancer.txt and creates the engine used by
        /// <see cref="MakeBreastCancerPredictions"/>.
        /// </summary>
        [GlobalSetup(Target = nameof(MakeBreastCancerPredictions))]
        public void SetupBreastCancerPipeline()
        {
            _breastCancerExample = new BreastCancerData()
            {
                Features = new[] { 5f, 1f, 1f, 1f, 2f, 1f, 3f, 1f, 1f }
            };

            string breastCancerDataPath = GetBenchmarkDataPathAndEnsureData("breast-cancer.txt");

            var mlContext = new MLContext(seed: 1);

            // Column 0 is the label; columns 1-9 form a single 9-element "Features" vector. No header row.
            var options = new TextLoader.Options()
            {
                Columns = new[]
                {
                    new TextLoader.Column("Label", DataKind.Boolean, 0),
                    new TextLoader.Column("Features", DataKind.Single, new[] { new TextLoader.Range(1, 9) })
                },
                HasHeader = false,
            };
            var loader = new TextLoader(mlContext, options: options);

            IDataView data = loader.Load(breastCancerDataPath);

            // The loaded columns are already named "Label" and "Features", so no transforms are needed.
            var pipeline = mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(
                new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, });

            var model = pipeline.Fit(data);

            _breastCancerModel = mlContext.Model.CreatePredictionEngine<BreastCancerData, BreastCancerPrediction>(model);
        }

        /// <summary>Scores the same Iris row <see cref="PredictionCount"/> times; results are discarded.</summary>
        [Benchmark]
        public void MakeIrisPredictions()
        {
            for (int i = 0; i < PredictionCount; i++)
            {
                _irisModel.Predict(_irisExample);
            }
        }

        /// <summary>Scores the same sentiment row <see cref="PredictionCount"/> times; results are discarded.</summary>
        [Benchmark]
        public void MakeSentimentPredictions()
        {
            for (int i = 0; i < PredictionCount; i++)
            {
                _sentimentModel.Predict(_sentimentExample);
            }
        }

        /// <summary>Scores the same breast-cancer row <see cref="PredictionCount"/> times; results are discarded.</summary>
        [Benchmark]
        public void MakeBreastCancerPredictions()
        {
            for (int i = 0; i < PredictionCount; i++)
            {
                _breastCancerModel.Predict(_breastCancerExample);
            }
        }

        // The prediction engines are IDisposable. Each cleanup targets a single benchmark, mirroring
        // the setups, so only the engine that benchmark created is released.

        /// <summary>Disposes the engine created by <see cref="SetupIrisPipeline"/>.</summary>
        [GlobalCleanup(Target = nameof(MakeIrisPredictions))]
        public void CleanupIrisPipeline() => _irisModel?.Dispose();

        /// <summary>Disposes the engine created by <see cref="SetupSentimentPipeline"/>.</summary>
        [GlobalCleanup(Target = nameof(MakeSentimentPredictions))]
        public void CleanupSentimentPipeline() => _sentimentModel?.Dispose();

        /// <summary>Disposes the engine created by <see cref="SetupBreastCancerPipeline"/>.</summary>
        [GlobalCleanup(Target = nameof(MakeBreastCancerPredictions))]
        public void CleanupBreastCancerPipeline() => _breastCancerModel?.Dispose();
    }
 
    /// <summary>
    /// Input row for the sentiment benchmark: a boolean label plus the raw text that gets featurized.
    /// </summary>
    public class SentimentData
    {
        /// <summary>Sentiment label, exposed to ML.NET as the "Label" column (source index 0).</summary>
        [LoadColumn(0)]
        [ColumnName("Label")]
        public bool Sentiment;

        /// <summary>Raw text to featurize (source index 1).</summary>
        [LoadColumn(1)]
        public string SentimentText;
    }
 
    /// <summary>
    /// Output row for the sentiment benchmark.
    /// </summary>
    public class SentimentPrediction
    {
        /// <summary>Predicted class, read from the "PredictedLabel" output column.</summary>
        [ColumnName("PredictedLabel")]
        public bool Sentiment;

        /// <summary>Raw score, read from the "Score" output column.</summary>
        public float Score;
    }
 
    /// <summary>
    /// Input row for the breast-cancer benchmark: a boolean label and a fixed-size feature vector.
    /// </summary>
    public class BreastCancerData
    {
        /// <summary>Ground-truth label ("Label" column).</summary>
        [ColumnName("Label")]
        public bool Label;

        /// <summary>Measurements as a 9-element vector ("Features" column).</summary>
        [ColumnName("Features")]
        [VectorType(9)]
        public float[] Features;
    }
 
    /// <summary>
    /// Output row for the breast-cancer benchmark; only the raw score is read back.
    /// </summary>
    public class BreastCancerPrediction
    {
        /// <summary>Raw score produced by the trainer ("Score" column).</summary>
        [ColumnName("Score")]
        public float Score;
    }
}