// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.IO;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Model;
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.Trainers;
using Xunit;

namespace Microsoft.ML.Tests.TrainerEstimators
{
public partial class TrainerEstimators
{
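/// <summary>
/// LbfgsLogisticRegression TrainerEstimator test: trains through the estimator pipeline
/// and then continues training from the previously fitted linear model.
/// </summary>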
[Fact]
public void TestEstimatorLogisticRegression()
{
(IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline();
var trainer = ML.BinaryClassification.Trainers.LbfgsLogisticRegression();
var pipeWithTrainer = pipe.Append(trainer);
TestEstimatorCore(pipeWithTrainer, dataView);
var transformedDataView = pipe.Fit(dataView).Transform(dataView);
var model = trainer.Fit(transformedDataView);
trainer.Fit(transformedDataView, model.Model.SubModel);
Done();
}
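
/// <summary>
/// LbfgsMaximumEntropy (multiclass logistic regression) TrainerEstimator test: trains through
/// the estimator pipeline and then continues training from the previously fitted model.
/// </summary>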
[Fact]
public void TestEstimatorMulticlassLogisticRegression()
{
(IEstimator<ITransformer> pipe, IDataView dataView) = GetMulticlassPipeline();
var trainer = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy();
var pipeWithTrainer = pipe.Append(trainer);
TestEstimatorCore(pipeWithTrainer, dataView);
var transformedDataView = pipe.Fit(dataView).Transform(dataView);
var model = trainer.Fit(transformedDataView);
trainer.Fit(transformedDataView, model.Model);
Done();
}
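
/// <summary>
/// LbfgsPoissonRegression TrainerEstimator test: trains on the regression pipeline
/// and then continues training from the previously fitted model.
/// </summary>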
[Fact]
public void TestEstimatorPoissonRegression()
{
var dataView = GetRegressionPipeline();
var trainer = ML.Regression.Trainers.LbfgsPoissonRegression();
TestEstimatorCore(trainer, dataView);
var model = trainer.Fit(dataView);
trainer.Fit(dataView, model.Model);
Done();
}
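
/// <summary>
/// Verifies that with ShowTrainingStatistics alone (no standard-deviation computer),
/// base model statistics are produced but per-parameter statistics are not.
/// </summary>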
[Fact]
public void TestLRNoStats()
{
(IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline();
pipe = pipe.Append(ML.BinaryClassification.Trainers.LbfgsLogisticRegression(new LbfgsLogisticRegressionBinaryTrainer.Options { ShowTrainingStatistics = true }));
var transformerChain = pipe.Fit(dataView) as TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>;
var linearModel = transformerChain.LastTransformer.Model.SubModel as LinearBinaryModelParameters;
var stats = linearModel.Statistics as ModelStatisticsBase;
Assert.NotNull(stats);
var stats2 = linearModel.Statistics as LinearModelParameterStatistics;
Assert.Null(stats2);
Done();
}
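
/// <summary>
/// Verifies that supplying ComputeLRTrainingStdThroughMkl produces per-parameter statistics
/// (bias and weight standard errors, z-scores) and that they survive a save/load round trip.
/// </summary>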
[NativeDependencyFact("MklImports")]
public void TestLRWithStats()
{
(IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline();
pipe = pipe.Append(ML.BinaryClassification.Trainers.LbfgsLogisticRegression(
new LbfgsLogisticRegressionBinaryTrainer.Options
{
ShowTrainingStatistics = true,
ComputeStandardDeviation = new ComputeLRTrainingStdThroughMkl(),
}));
var transformer = pipe.Fit(dataView) as TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>;
var linearModel = transformer.LastTransformer.Model.SubModel as LinearBinaryModelParameters;
Action<LinearBinaryModelParameters> validateStats = (modelParameters) =>
{
var stats = modelParameters.Statistics as LinearModelParameterStatistics;
var biasStats = stats?.GetBiasStatistics();
Assert.NotNull(biasStats);
biasStats = stats.GetBiasStatisticsForValue(2);
Assert.NotNull(biasStats);
CompareNumbersWithTolerance(biasStats.StandardError, 0.24, digitsOfPrecision: 2);
CompareNumbersWithTolerance(biasStats.ZScore, 8.32, digitsOfPrecision: 2);
var scoredData = transformer.Transform(dataView);
var coefficients = stats.GetWeightsCoefficientStatistics(100);
Assert.Equal(17, coefficients.Length);
foreach (var coefficient in coefficients)
Assert.True(coefficient.StandardError < 1.0);
};
validateStats(linearModel);
var modelAndSchemaPath = GetOutputPath("TestLRWithStats.zip");
// Save model.
ML.Model.Save(transformer, dataView.Schema, modelAndSchemaPath);
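// Load model.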
ITransformer transformerChain;
using (var fs = File.OpenRead(modelAndSchemaPath))
transformerChain = ML.Model.Load(fs, out var schema);
var lastTransformer = ((TransformerChain<ITransformer>)transformerChain).LastTransformer as BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>;
var model = lastTransformer.Model;
linearModel = model.SubModel as LinearBinaryModelParameters;
validateStats(linearModel);
Done();
}
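
/// <summary>
/// Backward compatibility test: loads a previously saved binary logistic regression model
/// and checks its training statistics.
/// </summary>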
[Fact]
public void TestLRWithStatsBackCompatibility()
{
string dropModelPath = GetDataPath("backcompat/LrWithStats.zip");
string trainData = GetDataPath("adult.tiny.with-schema.txt");
using (FileStream fs = File.OpenRead(dropModelPath))
{
var result = ModelFileUtils.LoadPredictorOrNull(Env, fs) as CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>;
var subPredictor = result?.SubModel as LinearBinaryModelParameters;
var stats = subPredictor?.Statistics;
CompareNumbersWithTolerance(stats.Deviance, 458.970917);
CompareNumbersWithTolerance(stats.NullDeviance, 539.276367);
Assert.Equal(7, stats.ParametersCount);
Assert.Equal(500, stats.TrainingExampleCount);
}
Done();
}
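
/// <summary>
/// Verifies that no statistics are computed when ShowTrainingStatistics is not requested.
/// </summary>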
[Fact]
public void TestMLRNoStats()
{
(IEstimator<ITransformer> pipe, IDataView dataView) = GetMulticlassPipeline();
var trainer = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy();
var pipeWithTrainer = pipe.Append(trainer);
TestEstimatorCore(pipeWithTrainer, dataView);
var transformer = pipeWithTrainer.Fit(dataView);
var model = transformer.LastTransformer.Model as MaximumEntropyModelParameters;
var stats = model.Statistics;
Assert.Null(stats);
Done();
}
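
/// <summary>
/// Verifies that maximum entropy training statistics (deviance, null deviance, example count)
/// are computed when requested and survive a save/load round trip.
/// </summary>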
[Fact]
public void TestMLRWithStats()
{
(IEstimator<ITransformer> pipe, IDataView dataView) = GetMulticlassPipeline();
var trainer = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy(new LbfgsMaximumEntropyMulticlassTrainer.Options
{
ShowTrainingStatistics = true
});
var pipeWithTrainer = pipe.Append(trainer);
TestEstimatorCore(pipeWithTrainer, dataView);
var transformer = pipeWithTrainer.Fit(dataView);
var model = transformer.LastTransformer.Model as MaximumEntropyModelParameters;
Action<MaximumEntropyModelParameters> validateStats = (modelParams) =>
{
var stats = modelParams.Statistics;
Assert.NotNull(stats);
#if NETCOREAPP3_1_OR_GREATER
CompareNumbersWithTolerance(stats.Deviance, 45.79, digitsOfPrecision: 0);
CompareNumbersWithTolerance(stats.NullDeviance, 329.58, digitsOfPrecision: 2);
#else
CompareNumbersWithTolerance(stats.Deviance, 45.35, digitsOfPrecision: 0);
CompareNumbersWithTolerance(stats.NullDeviance, 329.58, digitsOfPrecision: 2);
#endif
//Assert.Equal(14, stats.ParametersCount);
Assert.Equal(150, stats.TrainingExampleCount);
};
validateStats(model);
var modelAndSchemaPath = GetOutputPath("TestMLRWithStats.zip");
// Save model.
ML.Model.Save(transformer, dataView.Schema, modelAndSchemaPath);
// Load model.
ITransformer transformerChain;
using (var fs = File.OpenRead(modelAndSchemaPath))
transformerChain = ML.Model.Load(fs, out var schema);
var lastTransformer = ((TransformerChain<ITransformer>)transformerChain).LastTransformer as MulticlassPredictionTransformer<MaximumEntropyModelParameters>;
model = lastTransformer.Model;
validateStats(model);
Done();
}
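
/// <summary>
/// Backward compatibility test: loads a previously saved maximum entropy model
/// and checks its training statistics.
/// </summary>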
[Fact]
public void TestMLRWithStatsBackCompatibility()
{
string dropModelPath = GetDataPath("backcompat/MlrWithStats.zip");
string trainData = GetDataPath("iris.data");
using (FileStream fs = File.OpenRead(dropModelPath))
{
var result = ModelFileUtils.LoadPredictorOrNull(Env, fs) as MaximumEntropyModelParameters;
var stats = result?.Statistics;
Assert.Equal(132.012238f, stats.Deviance);
Assert.Equal(329.583679f, stats.NullDeviance);
Assert.Equal(11, stats.ParametersCount);
Assert.Equal(150, stats.TrainingExampleCount);
}
Done();
}
}
}