// File: MetricTest.cs
// Web Access
// Project: src\test\Microsoft.ML.Fairlearn.Tests\Microsoft.ML.Fairlearn.Tests.csproj (Microsoft.ML.Fairlearn.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.Data.Analysis;
using Xunit;
 
namespace Microsoft.ML.Fairlearn.Tests
{
    /// <summary>
    /// Tests for the Fairlearn group metrics exposed via <c>mlContext.Fairlearn().Metric</c>:
    /// per-group (<c>ByGroup</c>), overall (<c>Overall</c>) and between-group
    /// (<c>DifferenceBetweenGroups</c>) values are checked against small hand-built datasets.
    /// </summary>
    public class MetricTest
    {
        private readonly MLContext mlContext;
        private readonly IDataView data;

        public MetricTest()
        {
            mlContext = new MLContext();
            // houseData is set by a field initializer, which runs before this constructor body.
            data = mlContext.Data.LoadFromEnumerable(houseData);
        }

        /// <summary>Row type for the regression fixture.</summary>
        public class HouseData
        {
            public float Size { get; set; }
            public float Price { get; set; }
            public float Score { get; set; }
            public string Gender { get; set; }
        }

        // Regression fixture: Price is used as the label, Score as the prediction, Gender as the
        // sensitive feature. Per-row errors (Score - Price): Male 1.0, 1.0; Female 22.0, 4.0.
        private readonly HouseData[] houseData =
        {
            new HouseData() { Size = 1.1F, Price = 0.2F, Gender = "Male", Score = 1.2F },
            new HouseData() { Size = 1.9F, Price = 1.3F, Gender = "Male", Score = 2.3F },
            new HouseData() { Size = 2.8F, Price = 3.0F, Gender = "Female", Score = 25.0F },
            new HouseData() { Size = 3.4F, Price = 3.7F, Gender = "Female", Score = 7.7F },
        };

        [Fact]
        public void RegressionMetricTest()
        {
            RegressionGroupMetric regressionMetric = mlContext.Fairlearn().Metric.Regression(eval: data, labelColumn: "Price", scoreColumn: "Score", sensitiveFeatureColumn: "Gender");

            // Per-group metrics. The asserted values imply row 0 is the "Male" group and
            // row 1 the "Female" group.
            var metricByGroup = regressionMetric.ByGroup();
            // Male: SS_res = 1 + 1 = 2, SS_tot = 2 * 0.55^2 = 0.605 -> R^2 = 1 - 2 / 0.605 ~ -2.3058.
            Assert.Equal(-2.30578, Convert.ToSingle(metricByGroup["RSquared"][0]), 0.001);
            // Female: SS_res = 484 + 16 = 500, SS_tot ~ 0.245 -> R^2 ~ -2039.8
            // (expected value reflects the library's float/double arithmetic).
            Assert.Equal(-2039.81453, Convert.ToSingle(metricByGroup["RSquared"][1]), 0.001);
            // Male: MSE = (1 + 1) / 2 = 1 -> RMS = 1.
            Assert.Equal(1.00000, Convert.ToSingle(metricByGroup["RMS"][0]), 0.001);
            // Female: MSE = (484 + 16) / 2 = 250 -> RMS = sqrt(250) ~ 15.8114.
            Assert.Equal(15.811388, Convert.ToSingle(metricByGroup["RMS"][1]), 0.001);

            // Smoke check only: the result is discarded, so this just verifies Description() does not throw.
            metricByGroup.Description();

            // Overall: MSE = (1 + 1 + 484 + 16) / 4 = 125.5 -> RMS = sqrt(125.5) ~ 11.2027.
            Dictionary<string, double> metricOverall = regressionMetric.Overall();
            Assert.Equal(125.5, metricOverall["MSE"], 0.1);
            Assert.Equal(11.202678, metricOverall["RMS"], 0.0001);

            // Between-group differences: RMS 15.8114 - 1 ~ 14.8114; R^2 ~ 2039.8 - 2.3 ~ 2037.5.
            Dictionary<string, double> diff = regressionMetric.DifferenceBetweenGroups();
            Assert.Equal(14.81138, diff["RMS"], 0.0001);
            Assert.Equal(2037.5, diff["RSquared"], 0.1);
        }

        [Fact]
        public void BinaryClassificationMetricTest()
        {
            // Arrange: a 10-row DataFrame, split into two sensitive groups.
            //   group "a" = rows 0, 2, 3, 5, 8 -> 4 of 5 PredictedLabel values match the label (0.8)
            //   group "b" = rows 1, 4, 6, 7, 9 -> 3 of 5 PredictedLabel values match the label (0.6)
            //   overall                        -> 7 of 10 match (0.7)
            bool[] labels = { true, true, true, true, true, true, true, false, false, false };
            string[] groups = { "a", "b", "a", "a", "b", "a", "b", "b", "a", "b" };
            bool[] predictedLabels = { true, true, true, true, false, false, false, false, false, false };
            float[] scores = { 1.0F, 1.0F, 1.0F, 1.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F };
            // Row 7 has Probability 1.0 while its PredictedLabel is false; the expected accuracies
            // above are obtained by counting matching PredictedLabel values.
            float[] probabilities = { 1.0F, 1.0F, 1.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F };

            PrimitiveDataFrameColumn<bool> label = new PrimitiveDataFrameColumn<bool>("label", labels);
            StringDataFrameColumn groupId = new StringDataFrameColumn("group_id", groups);
            PrimitiveDataFrameColumn<bool> pred = new PrimitiveDataFrameColumn<bool>("PredictedLabel", predictedLabels);
            PrimitiveDataFrameColumn<float> score = new PrimitiveDataFrameColumn<float>("Score", scores);
            PrimitiveDataFrameColumn<float> prob = new PrimitiveDataFrameColumn<float>("Probability", probabilities);
            DataFrame df = new DataFrame(label, groupId, pred, score, prob);

            // Act
            BinaryGroupMetric metrics = mlContext.Fairlearn().Metric.BinaryClassification(eval: df, labelColumn: "label", predictedColumn: "PredictedLabel", sensitiveFeatureColumn: "group_id");
            var metricByGroup = metrics.ByGroup();
            var metricOverall = metrics.Overall();

            // Assert. The expected accuracies are exact fractions, so a tight tolerance is used:
            // a 0.1 tolerance would let a one-sample miss in the overall accuracy (0.6 vs 0.7) pass.
            Assert.Equal(0.8, Convert.ToDouble(metricByGroup["Accuracy"][0]), 0.001);
            Assert.Equal(0.6, Convert.ToDouble(metricByGroup["Accuracy"][1]), 0.001);
            Assert.Equal(0.7, Convert.ToDouble(metricOverall["Accuracy"]), 0.001);
        }
    }
}