// File: MetricsAgentsTests.cs
// Web Access
// Project: src\test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj (Microsoft.ML.AutoML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.AutoML.Test
{
    /// <summary>
    /// Tests for the AutoML metrics agents (<c>BinaryMetricsAgent</c>, <c>MultiMetricsAgent</c>,
    /// <c>RegressionMetricsAgent</c>, <c>RankingMetricsAgent</c>). Each task type has three tests:
    /// one checking that <c>GetScore</c> returns the value of the selected metric, one checking
    /// that non-ideal metric values are not reported as a perfect model, and one checking that
    /// ideal values are.
    /// </summary>
    public class MetricsAgentsTests : BaseTestClass
    {
        public MetricsAgentsTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// Every binary classification metric must be read back as the value it was created with.
        /// </summary>
        [Fact]
        public void BinaryMetricsGetScoreTest()
        {
            // Eight distinct values so that a mix-up between two metrics cannot go unnoticed.
            var metrics = MetricsUtil.CreateBinaryClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8);
            Assert.Equal(0.1, GetScore(metrics, BinaryClassificationMetric.AreaUnderRocCurve));
            Assert.Equal(0.2, GetScore(metrics, BinaryClassificationMetric.Accuracy));
            Assert.Equal(0.3, GetScore(metrics, BinaryClassificationMetric.PositivePrecision));
            Assert.Equal(0.4, GetScore(metrics, BinaryClassificationMetric.PositiveRecall));
            Assert.Equal(0.5, GetScore(metrics, BinaryClassificationMetric.NegativePrecision));
            Assert.Equal(0.6, GetScore(metrics, BinaryClassificationMetric.NegativeRecall));
            Assert.Equal(0.7, GetScore(metrics, BinaryClassificationMetric.F1Score));
            Assert.Equal(0.8, GetScore(metrics, BinaryClassificationMetric.AreaUnderPrecisionRecallCurve));
        }

        /// <summary>
        /// Binary metrics that are all below 1 must not be reported as a perfect model.
        /// </summary>
        [Fact]
        public void BinaryMetricsNonPerfectTest()
        {
            var metrics = MetricsUtil.CreateBinaryClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8);
            Assert.False(IsPerfectModel(metrics, BinaryClassificationMetric.Accuracy));
            Assert.False(IsPerfectModel(metrics, BinaryClassificationMetric.AreaUnderRocCurve));
            Assert.False(IsPerfectModel(metrics, BinaryClassificationMetric.AreaUnderPrecisionRecallCurve));
            Assert.False(IsPerfectModel(metrics, BinaryClassificationMetric.F1Score));
            Assert.False(IsPerfectModel(metrics, BinaryClassificationMetric.NegativePrecision));
            Assert.False(IsPerfectModel(metrics, BinaryClassificationMetric.NegativeRecall));
            Assert.False(IsPerfectModel(metrics, BinaryClassificationMetric.PositivePrecision));
            Assert.False(IsPerfectModel(metrics, BinaryClassificationMetric.PositiveRecall));
        }

        /// <summary>
        /// Binary metrics that are all exactly 1 must be reported as a perfect model.
        /// </summary>
        [Fact]
        public void BinaryMetricsPerfectTest()
        {
            var metrics = MetricsUtil.CreateBinaryClassificationMetrics(1, 1, 1, 1, 1, 1, 1, 1);
            Assert.True(IsPerfectModel(metrics, BinaryClassificationMetric.Accuracy));
            Assert.True(IsPerfectModel(metrics, BinaryClassificationMetric.AreaUnderRocCurve));
            Assert.True(IsPerfectModel(metrics, BinaryClassificationMetric.AreaUnderPrecisionRecallCurve));
            Assert.True(IsPerfectModel(metrics, BinaryClassificationMetric.F1Score));
            Assert.True(IsPerfectModel(metrics, BinaryClassificationMetric.NegativePrecision));
            Assert.True(IsPerfectModel(metrics, BinaryClassificationMetric.NegativeRecall));
            Assert.True(IsPerfectModel(metrics, BinaryClassificationMetric.PositivePrecision));
            Assert.True(IsPerfectModel(metrics, BinaryClassificationMetric.PositiveRecall));
        }

        /// <summary>
        /// Every multiclass metric must be read back as the value it was created with.
        /// </summary>
        [Fact]
        public void MulticlassMetricsGetScoreTest()
        {
            // The single-element array supplies the value expected for TopKAccuracy below.
            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, new double[] { 0.5 }, new double[] { });
            Assert.Equal(0.1, GetScore(metrics, MulticlassClassificationMetric.MicroAccuracy));
            Assert.Equal(0.2, GetScore(metrics, MulticlassClassificationMetric.MacroAccuracy));
            Assert.Equal(0.3, GetScore(metrics, MulticlassClassificationMetric.LogLoss));
            Assert.Equal(0.4, GetScore(metrics, MulticlassClassificationMetric.LogLossReduction));
            Assert.Equal(0.5, GetScore(metrics, MulticlassClassificationMetric.TopKAccuracy));
        }

        /// <summary>
        /// Non-ideal multiclass metric values must not be reported as a perfect model.
        /// </summary>
        [Fact]
        public void MulticlassMetricsNonPerfectTest()
        {
            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, new double[] { 0.5 }, new double[] { });
            Assert.False(IsPerfectModel(metrics, MulticlassClassificationMetric.MacroAccuracy));
            Assert.False(IsPerfectModel(metrics, MulticlassClassificationMetric.MicroAccuracy));
            Assert.False(IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss));
            Assert.False(IsPerfectModel(metrics, MulticlassClassificationMetric.LogLossReduction));
            Assert.False(IsPerfectModel(metrics, MulticlassClassificationMetric.TopKAccuracy));
        }

        /// <summary>
        /// Ideal multiclass metric values must be reported as a perfect model.
        /// </summary>
        [Fact]
        public void MulticlassMetricsPerfectTest()
        {
            // Unlike the other metrics, the ideal value used for LogLoss (third argument) is 0, not 1.
            var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(1, 1, 0, 1, 0, new double[] { 1 }, new double[] { });
            Assert.True(IsPerfectModel(metrics, MulticlassClassificationMetric.MicroAccuracy));
            Assert.True(IsPerfectModel(metrics, MulticlassClassificationMetric.MacroAccuracy));
            Assert.True(IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss));
            Assert.True(IsPerfectModel(metrics, MulticlassClassificationMetric.LogLossReduction));
            Assert.True(IsPerfectModel(metrics, MulticlassClassificationMetric.TopKAccuracy));
        }

        /// <summary>
        /// Every regression metric exposed through <c>RegressionMetric</c> must be read back as the
        /// value it was created with.
        /// </summary>
        [Fact]
        public void RegressionMetricsGetScoreTest()
        {
            // The fourth argument (0.5) is not checked by any assertion in this test.
            var metrics = MetricsUtil.CreateRegressionMetrics(0.2, 0.3, 0.4, 0.5, 0.6);
            Assert.Equal(0.2, GetScore(metrics, RegressionMetric.MeanAbsoluteError));
            Assert.Equal(0.3, GetScore(metrics, RegressionMetric.MeanSquaredError));
            Assert.Equal(0.4, GetScore(metrics, RegressionMetric.RootMeanSquaredError));
            Assert.Equal(0.6, GetScore(metrics, RegressionMetric.RSquared));
        }

        /// <summary>
        /// Non-ideal regression metric values must not be reported as a perfect model.
        /// </summary>
        [Fact]
        public void RegressionMetricsNonPerfectTest()
        {
            var metrics = MetricsUtil.CreateRegressionMetrics(0.2, 0.3, 0.4, 0.5, 0.6);
            Assert.False(IsPerfectModel(metrics, RegressionMetric.MeanAbsoluteError));
            Assert.False(IsPerfectModel(metrics, RegressionMetric.MeanSquaredError));
            Assert.False(IsPerfectModel(metrics, RegressionMetric.RootMeanSquaredError));
            Assert.False(IsPerfectModel(metrics, RegressionMetric.RSquared));
        }

        /// <summary>
        /// Ideal regression metric values must be reported as a perfect model.
        /// </summary>
        [Fact]
        public void RegressionMetricsPerfectTest()
        {
            // The error metrics are ideal at 0, while RSquared (last argument) is ideal at 1.
            var metrics = MetricsUtil.CreateRegressionMetrics(0, 0, 0, 0, 1);
            Assert.True(IsPerfectModel(metrics, RegressionMetric.MeanAbsoluteError));
            Assert.True(IsPerfectModel(metrics, RegressionMetric.MeanSquaredError));
            Assert.True(IsPerfectModel(metrics, RegressionMetric.RootMeanSquaredError));
            Assert.True(IsPerfectModel(metrics, RegressionMetric.RSquared));
        }

        /// <summary>
        /// The ranking score must be the entry selected by the DCG truncation level, both when the
        /// metric arrays have exactly that many entries and when they have more.
        /// </summary>
        [Fact]
        public void RankingMetricsGetScoreTest()
        {
            // Arrays exactly as long as the truncation level: the expected score is the last entry.
            double[] ndcg = { 0.2, 0.3, 0.4 };
            double[] dcg = { 0.2, 0.3, 0.4 };
            var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
            Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg, 3));
            Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg, 3));

            // Arrays longer than the truncation level: the expected score is still the third entry.
            double[] largeNdcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 };
            double[] largeDcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 };
            metrics = MetricsUtil.CreateRankingMetrics(largeDcg, largeNdcg);
            Assert.Equal(0.3, GetScore(metrics, RankingMetric.Dcg, 3));
            Assert.Equal(0.3, GetScore(metrics, RankingMetric.Ndcg, 3));
        }

        /// <summary>
        /// Ranking metrics below 1 at the truncation level must not be reported as a perfect model.
        /// </summary>
        [Fact]
        public void RankingMetricsNonPerfectTest()
        {
            double[] ndcg = { 0.2, 0.3, 0.4 };
            double[] dcg = { 0.2, 0.3, 0.4 };
            var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
            Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, 3));
            Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg, 3));
        }

        /// <summary>
        /// An NDCG of 1 at the truncation level must be reported as a perfect model, while DCG is
        /// expected never to be.
        /// </summary>
        [Fact]
        public void RankingMetricsPerfectTest()
        {
            double[] ndcg = { 0.2, 0.3, 1 };
            double[] dcg = { 0.2, 0.3, 1 };
            var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
            Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, 3)); //REVIEW: No true Perfect model
            Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg, 3));
        }

        /// <summary>
        /// The helper that builds the "metric not supported" error must produce a
        /// <see cref="NotSupportedException"/>.
        /// </summary>
        [Fact]
        public void ThrowNotSupportedMetricException()
        {
            var ex = MetricsAgentUtil.BuildMetricNotSupportedException(BinaryClassificationMetric.Accuracy);

            // Assert.IsType performs the same exact-type check as comparing typeof/GetType, but
            // fails with an assertion (not a NullReferenceException) if the helper returns null.
            Assert.IsType<NotSupportedException>(ex);
        }

        // NOTE(review): all agents below are constructed with null as their first argument;
        // presumably score lookup and the perfect-model check do not need it - confirm.

        /// <summary>Returns the score a <c>BinaryMetricsAgent</c> optimizing <paramref name="metric"/> extracts from <paramref name="metrics"/>.</summary>
        private static double GetScore(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric)
        {
            var agent = new BinaryMetricsAgent(null, metric);
            return agent.GetScore(metrics);
        }

        /// <summary>Returns the score a <c>MultiMetricsAgent</c> optimizing <paramref name="metric"/> extracts from <paramref name="metrics"/>.</summary>
        private static double GetScore(MulticlassClassificationMetrics metrics, MulticlassClassificationMetric metric)
        {
            var agent = new MultiMetricsAgent(null, metric);
            return agent.GetScore(metrics);
        }

        /// <summary>Returns the score a <c>RegressionMetricsAgent</c> optimizing <paramref name="metric"/> extracts from <paramref name="metrics"/>.</summary>
        private static double GetScore(RegressionMetrics metrics, RegressionMetric metric)
        {
            var agent = new RegressionMetricsAgent(null, metric);
            return agent.GetScore(metrics);
        }

        /// <summary>
        /// Returns the score a <c>RankingMetricsAgent</c> optimizing <paramref name="metric"/> at
        /// <paramref name="dcgTruncationLevel"/> extracts from <paramref name="metrics"/>.
        /// </summary>
        private static double GetScore(RankingMetrics metrics, RankingMetric metric, uint dcgTruncationLevel)
        {
            var agent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
            return agent.GetScore(metrics);
        }

        /// <summary>Checks whether a <c>BinaryMetricsAgent</c> for <paramref name="metric"/> considers <paramref name="metrics"/> perfect.</summary>
        private static bool IsPerfectModel(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric)
        {
            return IsPerfectModel(new BinaryMetricsAgent(null, metric), metrics);
        }

        /// <summary>Checks whether a <c>MultiMetricsAgent</c> for <paramref name="metric"/> considers <paramref name="metrics"/> perfect.</summary>
        private static bool IsPerfectModel(MulticlassClassificationMetrics metrics, MulticlassClassificationMetric metric)
        {
            return IsPerfectModel(new MultiMetricsAgent(null, metric), metrics);
        }

        /// <summary>Checks whether a <c>RegressionMetricsAgent</c> for <paramref name="metric"/> considers <paramref name="metrics"/> perfect.</summary>
        private static bool IsPerfectModel(RegressionMetrics metrics, RegressionMetric metric)
        {
            return IsPerfectModel(new RegressionMetricsAgent(null, metric), metrics);
        }

        /// <summary>
        /// Checks whether a <c>RankingMetricsAgent</c> for <paramref name="metric"/> at
        /// <paramref name="dcgTruncationLevel"/> considers <paramref name="metrics"/> perfect.
        /// </summary>
        private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, uint dcgTruncationLevel)
        {
            return IsPerfectModel(new RankingMetricsAgent(null, metric, dcgTruncationLevel), metrics);
        }

        /// <summary>
        /// Shared implementation: extracts the score with <paramref name="metricsAgent"/> and asks
        /// the same agent whether that score denotes a perfect model.
        /// </summary>
        private static bool IsPerfectModel<TMetrics>(IMetricsAgent<TMetrics> metricsAgent, TMetrics metrics)
        {
            double extractedScore = metricsAgent.GetScore(metrics);
            return metricsAgent.IsModelPerfect(extractedScore);
        }
    }
}