File: TrainerEstimators\CalibratorEstimators.cs
Web Access
Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.IO;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Xunit;
 
namespace Microsoft.ML.Tests.TrainerEstimators
{
    public partial class TrainerEstimators
    {
        /// <summary>
        /// Fits a Platt calibrator on scored binary classification data, validates the
        /// calibrated output, and passes the estimator through TestEstimatorCore.
        /// </summary>
        [Fact]
        public void PlattCalibratorEstimator()
        {
            var scoredData = GetCalibratorTestData().ScoredData;

            // Train the Platt calibrator on the scored data.
            var estimator = new PlattCalibratorEstimator(Env);
            var calibrator = estimator.Fit(scoredData);

            // The fitted transformer must yield a valid calibrated data view.
            CheckValidCalibratedData(scoredData, calibrator);

            // Exercise the estimator itself against the same data.
            TestEstimatorCore(estimator, scoredData);

            Done();
        }
 
        /// <summary>
        /// Fits a fixed Platt calibrator on scored binary classification data, validates the
        /// calibrated output, and passes the estimator through TestEstimatorCore.
        /// </summary>
        [Fact]
        public void FixedPlattCalibratorEstimator()
        {
            var scoredData = GetCalibratorTestData().ScoredData;

            // Build the calibrator with its default settings and fit it on the scored data.
            var estimator = new FixedPlattCalibratorEstimator(Env);
            var calibrator = estimator.Fit(scoredData);

            // The fitted transformer must yield a valid calibrated data view.
            CheckValidCalibratedData(scoredData, calibrator);

            // Exercise the estimator itself against the same data.
            TestEstimatorCore(estimator, scoredData);

            Done();
        }
 
        /// <summary>
        /// Fits a naive calibrator on scored binary classification data, validates the
        /// calibrated output, and passes the estimator through TestEstimatorCore.
        /// </summary>
        [Fact]
        public void NaiveCalibratorEstimator()
        {
            var scoredData = GetCalibratorTestData().ScoredData;

            // Train the naive calibrator on the scored data.
            var estimator = new NaiveCalibratorEstimator(Env);
            var calibrator = estimator.Fit(scoredData);

            // The fitted transformer must yield a valid calibrated data view.
            CheckValidCalibratedData(scoredData, calibrator);

            // Exercise the estimator itself against the same data.
            TestEstimatorCore(estimator, scoredData);

            Done();
        }
        /// <summary>
        /// Fits an isotonic (PAV) calibrator on scored binary classification data, validates the
        /// calibrated output, and passes the estimator through TestEstimatorCore.
        /// </summary>
        [Fact]
        public void PavCalibratorEstimator()
        {
            var scoredData = GetCalibratorTestData().ScoredData;

            // Train the isotonic calibrator on the scored data.
            var estimator = new IsotonicCalibratorEstimator(Env);
            var calibrator = estimator.Fit(scoredData);

            // The fitted transformer must yield a valid calibrated data view.
            CheckValidCalibratedData(scoredData, calibrator);

            // Exercise the estimator itself against the same data.
            TestEstimatorCore(estimator, scoredData);

            Done();
        }
 
        /// <summary>
        /// Appends an averaged perceptron trainer to the shared binary classification pipeline,
        /// fits it, scores the training data, and bundles the results for the calibrator tests.
        /// </summary>
        /// <returns>
        /// The input data, the scored data, the full pipeline, and the last transformer of the
        /// fitted chain (the binary prediction transformer).
        /// </returns>
        private CalibratorTestData GetCalibratorTestData()
        {
            var (pipeline, data) = GetBinaryClassificationPipeline();
            var binaryTrainer = ML.BinaryClassification.Trainers.AveragedPerceptron();

            pipeline = pipeline.Append(binaryTrainer);

            var transformer = pipeline.Fit(data);
            var scoredData = transformer.Transform(data);

            // The calibrator tests address columns by position, so fail early (and report the
            // actual count) if the scored schema does not have the expected 6 columns.
            var scoredDataPreview = scoredData.Preview();
            Assert.Equal(6, scoredDataPreview.ColumnView.Length);

            // The fitted pipeline is a chain whose last link is the trained predictor; the hard cast
            // throws if that assumption is ever broken, and LastTransformer is already strongly typed.
            var chain = (TransformerChain<BinaryPredictionTransformer<LinearBinaryModelParameters>>)transformer;

            return new CalibratorTestData
            {
                Data = data,
                ScoredData = scoredData,
                Pipeline = pipeline,
                Transformer = chain.LastTransformer,
            };
        }
 
        /// <summary>
        /// Holder for the objects that GetCalibratorTestData produces for the calibrator tests.
        /// </summary>
        private sealed class CalibratorTestData
        {
            /// <summary>Estimator pipeline, including the appended binary trainer.</summary>
            public IEstimator<ITransformer> Pipeline { get; set; }

            /// <summary>Binary prediction transformer taken from the end of the fitted pipeline.</summary>
            public BinaryPredictionTransformer<LinearBinaryModelParameters> Transformer { get; set; }

            /// <summary>Data the pipeline was fitted on.</summary>
            public IDataView Data { get; set; }

            /// <summary>Result of applying the fitted pipeline to <see cref="Data"/>.</summary>
            public IDataView ScoredData { get; set; }
        }
 
 
        /// <summary>
        /// Applies <paramref name="transformer"/> to <paramref name="scoredData"/> and asserts that
        /// the previewed result has 7 columns and that, for the first 10 rows, the value in the last
        /// column (read as the calibrated probability) lies in the range [0, 1].
        /// </summary>
        /// <param name="scoredData">Scored data to run through the calibrator.</param>
        /// <param name="transformer">Fitted calibrator transformer under test.</param>
        private void CheckValidCalibratedData(IDataView scoredData, ITransformer transformer)
        {
            const int expectedColumnCount = 7;
            // The probability is read from the last column of the calibrated output.
            const int probabilityColumnIndex = expectedColumnCount - 1;
            const int rowsToCheck = 10;

            var calibratedData = transformer.Transform(scoredData).Preview();

            Assert.Equal(expectedColumnCount, calibratedData.ColumnView.Length);

            // Fail with a readable message rather than an IndexOutOfRangeException if the preview is short.
            Assert.True(calibratedData.RowView.Length >= rowsToCheck,
                $"Expected at least {rowsToCheck} preview rows but found {calibratedData.RowView.Length}.");

            for (int i = 0; i < rowsToCheck; i++)
            {
                var probability = calibratedData.RowView[i].Values[probabilityColumnIndex];
                Assert.InRange((float)probability.Value, 0, 1);
            }
        }
 
        /// <summary>
        /// Verifies that every calibrator estimator can be fitted and applied when the input class
        /// declares the score column before the label column, when the score column has a
        /// non-default name, and when both variations are combined.
        /// </summary>
        [Fact]
        public void TestNonStandardCalibratorEstimatorClasses()
        {
            var mlContext = new MLContext(0);

            // Score declared before label; default column names.
            IDataView reversedOrderData = mlContext.Data.LoadFromEnumerable<CalibratorTestInputReversedOrder>(
                new CalibratorTestInputReversedOrder[]
                {
                    new CalibratorTestInputReversedOrder { Score = 10, Label = true },
                    new CalibratorTestInputReversedOrder { Score = 15, Label = false }
                });

            // Label declared before score; score column named "ScoreX".
            IDataView uniqueScoreNameData = mlContext.Data.LoadFromEnumerable<CalibratorTestInputUniqueScoreColumnName>(
                new CalibratorTestInputUniqueScoreColumnName[]
                {
                    new CalibratorTestInputUniqueScoreColumnName { Label = true, ScoreX = 10 },
                    new CalibratorTestInputUniqueScoreColumnName { Label = false, ScoreX = 15 }
                });

            // Score declared before label and named "ScoreX".
            IDataView reversedAndUniqueData = mlContext.Data.LoadFromEnumerable<CalibratorTestInputReversedOrderAndUniqueScoreColumnName>(
                new CalibratorTestInputReversedOrderAndUniqueScoreColumnName[]
                {
                    new CalibratorTestInputReversedOrderAndUniqueScoreColumnName { ScoreX = 10, Label = true },
                    new CalibratorTestInputReversedOrderAndUniqueScoreColumnName { ScoreX = 15, Label = false }
                });

            IDataView[] dataArray = { reversedOrderData, uniqueScoreNameData, reversedAndUniqueData };

            // A score column that does not use the default name must be passed to the estimator
            // explicitly. The test passes when every fit and transform completes without throwing.
            for (int index = 0; index < dataArray.Length; index++)
            {
                IDataView input = dataArray[index];

                // Only the first data view keeps the default score column name.
                string scoreName = index == 0 ? DefaultColumnNames.Score : "ScoreX";

                // PlattCalibratorEstimator
                new PlattCalibratorEstimator(Env, scoreColumnName: scoreName)
                    .Fit(input)
                    .Transform(input);

                // FixedPlattCalibratorEstimator
                new FixedPlattCalibratorEstimator(Env, scoreColumn: scoreName)
                    .Fit(input)
                    .Transform(input);

                // NaiveCalibratorEstimator
                new NaiveCalibratorEstimator(Env, scoreColumn: scoreName)
                    .Fit(input)
                    .Transform(input);

                // IsotonicCalibratorEstimator
                new IsotonicCalibratorEstimator(Env, scoreColumn: scoreName)
                    .Fit(input)
                    .Transform(input);
            }
        }
 
        /// <summary>
        /// Test input class in which the score column is declared before the
        /// label column (the usual order is label first, then score).
        /// </summary>
        /// <remarks>
        /// The declaration order of the properties is the point of this class; do not reorder them.
        /// </remarks>
        private sealed class CalibratorTestInputReversedOrder
        {
            public float Score { get; set; }
            public bool Label { get; set; }
        }
 
        /// <summary>
        /// Test input class whose score column is named "ScoreX" instead of
        /// the default column name "Score". Label is declared before the score.
        /// </summary>
        private sealed class CalibratorTestInputUniqueScoreColumnName
        {
            public bool Label { get; set; }
            public float ScoreX { get; set; }
        }
 
        /// <summary>
        /// Test input class in which the score column is declared before the
        /// label column (the usual order is label first, then score) and is
        /// also named "ScoreX" instead of the default column name "Score".
        /// </summary>
        /// <remarks>
        /// The declaration order of the properties is the point of this class; do not reorder them.
        /// </remarks>
        private sealed class CalibratorTestInputReversedOrderAndUniqueScoreColumnName
        {
            public float ScoreX { get; set; }
            public bool Label { get; set; }
        }
 
        /// <summary>
        /// Backwards-compatibility test: loads a previously saved Platt calibrator model (file tagged
        /// VerWritten 0x00010001), fits a new Platt calibrator on freshly scored data, and checks that
        /// both transformers produce previews of the same shape with equal values in the compared column.
        /// </summary>
        [Fact]
        public void TestCalibratorEstimatorBackwardsCompatibility()
        {
            // The legacy model being loaded below was trained and saved with
            // version as such:
            /* 
             * var mlContext = new MLContext(seed: 1);
             * var calibratorTestData = GetCalibratorTestData();
             * var plattCalibratorEstimator = new PlattCalibratorEstimator(Env);
             * var plattCalibratorTransformer = plattCalibratorEstimator.Fit(calibratorTestData.ScoredData);
             * mlContext.Model.Save(plattCalibratorTransformer, calibratorTestData.ScoredData.Schema, "calibrator-model_VerWritten_0x00010001xyz.zip");
             */

            // Column position compared between the two outputs, and how many leading rows are compared.
            // NOTE(review): the calibrated preview elsewhere in this file has 7 columns with the
            // probability read from the last one (index 6). Index 5 may therefore be a column that
            // is passed through unchanged rather than the probability - confirm the intended index.
            const int comparedColumnIndex = 5;
            const int rowsToCompare = 10;

            var modelPath = GetDataPath("backcompat", "Calibrator_Model_VerWritten_0x00010001.zip");
            ITransformer oldPlattCalibratorTransformer;
            using (var fs = File.OpenRead(modelPath))
            {
                // The schema stored alongside the model is not needed by this test.
                oldPlattCalibratorTransformer = ML.Model.Load(fs, out _);
            }

            var calibratorTestData = GetCalibratorTestData();
            var newPlattCalibratorEstimator = new PlattCalibratorEstimator(Env);
            var newPlattCalibratorTransformer = newPlattCalibratorEstimator.Fit(calibratorTestData.ScoredData);

            // Run both models over the same scored data.
            var oldCalibratedData = oldPlattCalibratorTransformer.Transform(calibratorTestData.ScoredData).Preview();
            var newCalibratedData = newPlattCalibratorTransformer.Transform(calibratorTestData.ScoredData).Preview();

            // Check first that the produced schemas and outputs are of the same size.
            Assert.Equal(oldCalibratedData.RowView.Length, newCalibratedData.RowView.Length);
            Assert.Equal(oldCalibratedData.ColumnView.Length, newCalibratedData.ColumnView.Length);

            // Fail with a readable message rather than an IndexOutOfRangeException if the preview is short.
            Assert.True(oldCalibratedData.RowView.Length >= rowsToCompare,
                $"Expected at least {rowsToCompare} preview rows but found {oldCalibratedData.RowView.Length}.");

            // Then compare the values row by row. The comparison is exact (no tolerance is applied).
            for (int i = 0; i < rowsToCompare; i++)
            {
                var oldValue = (float)oldCalibratedData.RowView[i].Values[comparedColumnIndex].Value;
                var newValue = (float)newCalibratedData.RowView[i].Values[comparedColumnIndex].Value;
                Assert.True(oldValue == newValue,
                    $"Row {i}: legacy model produced {oldValue} but the newly trained model produced {newValue}.");
            }

            Done();
        }
    }
}