File: Prediction.cs
Web Access
Project: src\test\Microsoft.ML.IntegrationTests\Microsoft.ML.IntegrationTests.csproj (Microsoft.ML.IntegrationTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Reflection;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.IntegrationTests.Datasets;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Trainers;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.IntegrationTests
{
    /// <summary>
    /// Integration scenarios for prediction-time behavior: changing the decision threshold of a
    /// trained binary classifier without retraining, and checking whether disposing a
    /// <c>PredictionEngine</c> also disposes the model it was created from.
    /// </summary>
    public class PredictionScenarios : IntegrationTestBaseClass
    {
        public PredictionScenarios(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// Output type for the prediction engines created in this class; receives the
        /// <c>Score</c> and <c>PredictedLabel</c> values produced by the binary classifiers.
        /// </summary>
        private class Prediction
        {
            public float Score { get; set; }
            public bool PredictedLabel { get; set; }
        }

        /// <summary>
        /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
        /// and through the test evaluator gets a PR curve, then based on the PR curve picks a new threshold
        /// and configures the scorer (or more precisely instantiates a new scorer over the same model parameters)
        /// with some threshold derived from that. This test stands in for the PR-curve step by using a fixed
        /// threshold of 0.7.
        /// </summary>
        [Fact]
        public void ReconfigurablePrediction()
        {
            var mlContext = new MLContext(seed: 1);

            var options = new TextLoader.Options
            {
                HasHeader = TestDatasets.Sentiment.fileHasHeader,
                Separators = new[] { TestDatasets.Sentiment.fileSeparator }
            };

            var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(TestCommon.GetDataPath(DataDir, TestDatasets.Sentiment.trainFilename),
                options);

            // Create a training pipeline.
            var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
                .AppendCacheCheckpoint(mlContext)
                .Append(mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
                    new LbfgsLogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }));

            // Train the model.
            var model = pipeline.Fit(data);
            var engine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(model);
            var pr = engine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
            // Score is 0.64 so predicted label is true.
            Assert.True(pr.PredictedLabel);
            Assert.True(pr.Score > 0);

            // Rebuild the chain: keep every transformer except the final predictor (matched by
            // reference), then append a copy of that predictor with the new threshold.
            // The featurization transformers are reused as-is, not retrained.
            var transformers = new List<ITransformer>();
            foreach (var transform in model)
            {
                if (transform != model.LastTransformer)
                {
                    transformers.Add(transform);
                }
            }
            transformers.Add(mlContext.BinaryClassification.ChangeModelThreshold(model.LastTransformer, 0.7f));
            var newModel = new TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>(transformers.ToArray());
            var newEngine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(newModel);
            pr = newEngine.Predict(new TweetSentiment() { SentimentText = "Good Bad job" });
            // Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false.

            Assert.False(pr.PredictedLabel);
            Assert.False(pr.Score > 0.7);
        }

        /// <summary>
        /// Same re-thresholding scenario as <see cref="ReconfigurablePrediction"/>, but applied directly
        /// to a bare trained predictor (no transformer chain): lowering the threshold to -2 flips a
        /// negative prediction to positive while the score stays unchanged.
        /// </summary>
        [Fact]
        public void ReconfigurablePredictionNoPipeline()
        {
            var mlContext = new MLContext(seed: 1);
            var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
            var pipeline = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
                     new Trainers.LbfgsLogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });
            var model = pipeline.Fit(data);
            var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, -2.0f);

            // Fixed seed so the same data point is scored by both engines on every run.
            var rnd = new Random(1);
            var randomDataPoint = TypeTestData.GetRandomInstance(rnd);
            var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model);
            var pr = engine.Predict(randomDataPoint);
            // Score is -1.38 so predicted label is false.
            Assert.False(pr.PredictedLabel);
            Assert.True(pr.Score <= 0);
            var newEngine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(newModel);
            pr = newEngine.Predict(randomDataPoint);
            // Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true.
            Assert.True(pr.PredictedLabel);
            Assert.True(pr.Score <= 0);
        }

        /// <summary>
        /// Verifies <see cref="PredictionEngineOptions.OwnsTransformer"/>: with default options, disposing
        /// the engine disposes the model; with <c>OwnsTransformer = false</c>, it leaves the model alive.
        /// </summary>
        [Fact]
        public void PredictionEngineModelDisposal()
        {
            var mlContext = new MLContext(seed: 1);
            var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
            var pipeline = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
                     new Trainers.LbfgsLogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });
            var model = pipeline.Fit(data);

            var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model, new PredictionEngineOptions());

            // Dispose of prediction engine, should dispose of model
            engine.Dispose();

            // Get disposed flag using reflection. The private "_disposed" field is looked up two levels
            // up the model's type hierarchy; if that hierarchy changes, fail with a clear message
            // instead of a NullReferenceException on field.GetValue below.
            var bfIsDisposed = BindingFlags.Instance | BindingFlags.NonPublic;
            var field = model.GetType().BaseType.BaseType.GetField("_disposed", bfIsDisposed);
            Assert.NotNull(field);

            // Make sure the model is actually disposed
            Assert.True((bool)field.GetValue(model));

            // Make a new model/prediction engine. Set the options so prediction engine doesn't dispose
            model = pipeline.Fit(data);

            var options = new PredictionEngineOptions()
            {
                OwnsTransformer = false
            };

            engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model, options);

            // Dispose of prediction engine, shouldn't dispose of model
            engine.Dispose();

            try
            {
                // Make sure model is not disposed of.
                Assert.False((bool)field.GetValue(model));
            }
            finally
            {
                // Dispose of the model for test cleanliness, even if the assertion above fails.
                model.Dispose();
            }
        }
    }
}