File: Common.cs
Web Access
Project: src\test\Microsoft.ML.IntegrationTests\Microsoft.ML.IntegrationTests.csproj (Microsoft.ML.IntegrationTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.IntegrationTests.Datasets;
using Xunit;
using Xunit.Sdk;
 
namespace Microsoft.ML.IntegrationTests
{
    internal static class Common
    {
        /// <summary>
        /// Assert that the schema of an <see cref="IDataView"/> matches <see cref="TypeTestData"/>:
        /// exactly one column per public property of <see cref="TypeTestData"/>, each with the expected raw type.
        /// </summary>
        /// <param name="testTypeDataset">An <see cref="IDataView"/>.</param>
        public static void AssertTypeTestDataset(IDataView testTypeDataset)
        {
            var toyClassProperties = typeof(TypeTestData).GetProperties();

            // Check that the schema is of the right size.
            Assert.Equal(toyClassProperties.Length, testTypeDataset.Schema.Count);

            // Build a lookup of the expected column raw type for each property, plus a
            // per-property counter used to detect duplicated or missing columns.
            var types = new Dictionary<string, Type>();
            var counts = new Dictionary<string, int>();
            foreach (var property in toyClassProperties)
            {
                var propertyType = property.PropertyType;

                // Array properties are expected to surface in the IDataView as VBuffer<T> of the element type.
                types[property.Name] = propertyType.IsArray
                    ? typeof(VBuffer<>).MakeGenericType(propertyType.GetElementType())
                    : propertyType;
                counts[property.Name] = 0;
            }

            foreach (var column in testTypeDataset.Schema)
            {
                Type expectedType;
                Assert.True(types.TryGetValue(column.Name, out expectedType));
                Assert.Equal(1, ++counts[column.Name]);
                Assert.Equal(expectedType, column.Type.RawType);
            }

            // Make sure we didn't miss any columns.
            foreach (var value in counts.Values)
                Assert.Equal(1, value);
        }

        /// <summary>
        /// Assert that two <see cref="TypeTestData"/> datasets are equal.
        /// </summary>
        /// <param name="mlContext">The ML Context.</param>
        /// <param name="data1">A <see cref="IDataView"/> of <see cref="TypeTestData"/></param>
        /// <param name="data2">A <see cref="IDataView"/> of <see cref="TypeTestData"/></param>
        public static void AssertTestTypeDatasetsAreEqual(MLContext mlContext, IDataView data1, IDataView data2)
        {
            // Confirm that they are both of the property row type.
            AssertTypeTestDataset(data1);
            AssertTypeTestDataset(data2);

            // Validate that the two Schemas are the same.
            Common.AssertEqual(data1.Schema, data2.Schema);

            // Define how to serialize the IDataView to objects (reusing one row object per enumerable).
            var enumerable1 = mlContext.Data.CreateEnumerable<TypeTestData>(data1, true);
            var enumerable2 = mlContext.Data.CreateEnumerable<TypeTestData>(data2, true);

            AssertEqual(enumerable1, enumerable2);
        }

        /// <summary>
        /// Assert that two float arrays are equal, element by element.
        /// </summary>
        /// <param name="array1">An array of floats.</param>
        /// <param name="array2">An array of floats.</param>
        /// <param name="precision">Number of decimal places that must agree for each pair of elements.</param>
        public static void AssertEqual(float[] array1, float[] array2, int precision = 6)
        {
            Assert.NotNull(array1);
            Assert.NotNull(array2);
            Assert.Equal(array1.Length, array2.Length);

            for (int i = 0; i < array1.Length; i++)
                Assert.Equal(array1[i], array2[i], precision: precision);
        }

        /// <summary>
        /// Assert that two <see cref="DataViewSchema"/> objects are equal: same column count and, per column,
        /// same name, index, hidden flag, annotation schema, and type (or at least the same raw type).
        /// </summary>
        /// <param name="schema1">A <see cref="DataViewSchema"/> object.</param>
        /// <param name="schema2">A <see cref="DataViewSchema"/> object.</param>
        public static void AssertEqual(DataViewSchema schema1, DataViewSchema schema2)
        {
            Assert.NotNull(schema1);
            Assert.NotNull(schema2);

            Assert.Equal(schema1.Count, schema2.Count);

            for (int i = 0; i < schema1.Count; i++)
            {
                var column1 = schema1[i];
                var column2 = schema2[i];

                Assert.Equal(column1.Name, column2.Name);
                Assert.Equal(column1.Index, column2.Index);
                Assert.Equal(column1.IsHidden, column2.IsHidden);

                // Compare the annotation schemas of the two columns (this previously compared
                // column1's annotations with themselves). Annotation values are not compared.
                AssertEqual(column1.Annotations.Schema, column2.Annotations.Schema);

                // Accept either the same type instance or two types sharing the same raw type.
                Assert.True((column1.Type == column2.Type) ||
                    (column1.Type.RawType == column2.Type.RawType));
            }
        }

        /// <summary>
        /// Assert that two <see cref="TypeTestData"/> enumerables are equal, row by row.
        /// </summary>
        /// <param name="data1">An enumerable of <see cref="TypeTestData"/></param>
        /// <param name="data2">An enumerable of <see cref="TypeTestData"/></param>
        public static void AssertEqual(IEnumerable<TypeTestData> data1, IEnumerable<TypeTestData> data2)
        {
            Assert.NotNull(data1);
            Assert.NotNull(data2);
            Assert.Equal(data1.Count(), data2.Count());

            // Rows are compared as they are streamed rather than materialized into lists: enumerables
            // created with row-object reuse hand back the same instance for every row, so buffering
            // them would make every buffered element identical.
            foreach (var rowPair in data1.Zip(data2, Tuple.Create))
            {
                AssertEqual(rowPair.Item1, rowPair.Item2);
            }
        }

        /// <summary>
        /// Assert that two <see cref="TypeTestData"/> rows are equal, field by field.
        /// </summary>
        /// <param name="testType1">An <see cref="TypeTestData"/>.</param>
        /// <param name="testType2">An <see cref="TypeTestData"/>.</param>
        public static void AssertEqual(TypeTestData testType1, TypeTestData testType2)
        {
            Assert.Equal(testType1.Label, testType2.Label);
            Common.AssertEqual(testType1.Features, testType2.Features);
            Assert.Equal(testType1.I1, testType2.I1);
            Assert.Equal(testType1.U1, testType2.U1);
            Assert.Equal(testType1.I2, testType2.I2);
            Assert.Equal(testType1.U2, testType2.U2);
            Assert.Equal(testType1.I4, testType2.I4);
            Assert.Equal(testType1.U4, testType2.U4);
            Assert.Equal(testType1.I8, testType2.I8);
            Assert.Equal(testType1.U8, testType2.U8);
            Assert.Equal(testType1.R4, testType2.R4);
            Assert.Equal(testType1.R8, testType2.R8);
            Assert.Equal(testType1.Tx.ToString(), testType2.Tx.ToString());

            // Assert.Equal (rather than Assert.True(a.Equals(b))) reports both values on failure.
            Assert.Equal(testType1.Ts, testType2.Ts);
            Assert.Equal(testType1.Dt, testType2.Dt);
            Assert.Equal(testType1.Dz, testType2.Dz);
        }

        /// <summary>
        /// Check that a <see cref="AnomalyDetectionMetrics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetrics(AnomalyDetectionMetrics metrics)
        {
            Assert.InRange(metrics.AreaUnderRocCurve, 0, 1);
            Assert.InRange(metrics.DetectionRateAtFalsePositiveCount, 0, 1);
        }

        /// <summary>
        /// Check that a <see cref="BinaryClassificationMetrics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetrics(BinaryClassificationMetrics metrics)
        {
            Assert.InRange(metrics.Accuracy, 0, 1);
            Assert.InRange(metrics.AreaUnderRocCurve, 0, 1);
            Assert.InRange(metrics.AreaUnderPrecisionRecallCurve, 0, 1);
            Assert.InRange(metrics.F1Score, 0, 1);
            Assert.InRange(metrics.NegativePrecision, 0, 1);
            Assert.InRange(metrics.NegativeRecall, 0, 1);
            Assert.InRange(metrics.PositivePrecision, 0, 1);
            Assert.InRange(metrics.PositiveRecall, 0, 1);

            // Confusion matrix validations (AssertConfusionMatrix also asserts it is not null).
            AssertConfusionMatrix(metrics.ConfusionMatrix);
        }

        /// <summary>
        /// Check that a <see cref="CalibratedBinaryClassificationMetrics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetrics(CalibratedBinaryClassificationMetrics metrics)
        {
            // Entropy and log-loss are non-negative. Log-loss has no upper bound; a poor model can exceed 1.
            Assert.InRange(metrics.Entropy, 0, 1);
            Assert.True(metrics.LogLoss >= 0);
            Assert.InRange(metrics.LogLossReduction, double.NegativeInfinity, 100);

            // Validate the uncalibrated metrics, including the confusion matrix. The explicit cast
            // selects the BinaryClassificationMetrics overload instead of recursing into this one.
            AssertMetrics((BinaryClassificationMetrics)metrics);
        }

        /// <summary>
        /// Check that a <see cref="ClusteringMetrics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetrics(ClusteringMetrics metrics)
        {
            Assert.True(metrics.AverageDistance >= 0);
            Assert.True(metrics.DaviesBouldinIndex >= 0);

            // NMI may be NaN (e.g. when no label is available); only range-check it when defined.
            if (!double.IsNaN(metrics.NormalizedMutualInformation))
                Assert.True(metrics.NormalizedMutualInformation >= 0 && metrics.NormalizedMutualInformation <= 1);
        }

        /// <summary>
        /// Check that a <see cref="MulticlassClassificationMetrics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetrics(MulticlassClassificationMetrics metrics)
        {
            Assert.InRange(metrics.MacroAccuracy, 0, 1);
            Assert.InRange(metrics.MicroAccuracy, 0, 1);
            Assert.True(metrics.LogLoss >= 0);
            Assert.InRange(metrics.TopKAccuracy, 0, 1);

            // Confusion matrix validations (AssertConfusionMatrix also asserts it is not null).
            AssertConfusionMatrix(metrics.ConfusionMatrix);
        }

        /// <summary>
        /// Check that a <see cref="ConfusionMatrix"/> is present, non-empty, and that its per-class
        /// precision and recall values lie in [0, 1].
        /// </summary>
        /// <param name="confusionMatrix">The confusion matrix to check.</param>
        internal static void AssertConfusionMatrix(ConfusionMatrix confusionMatrix)
        {
            Assert.NotNull(confusionMatrix);
            Assert.NotEmpty(confusionMatrix.Counts);
            Assert.NotEmpty(confusionMatrix.PerClassPrecision);
            Assert.NotEmpty(confusionMatrix.PerClassRecall);

            foreach (var precision in confusionMatrix.PerClassPrecision)
                Assert.InRange(precision, 0, 1);

            foreach (var recall in confusionMatrix.PerClassRecall)
                Assert.InRange(recall, 0, 1);
        }

        /// <summary>
        /// Check that a <see cref="RankingMetrics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetrics(RankingMetrics metrics)
        {
            foreach (var dcg in metrics.DiscountedCumulativeGains)
                Assert.True(dcg >= 0);
            foreach (var ndcg in metrics.NormalizedDiscountedCumulativeGains)
                Assert.InRange(ndcg, 0, 100);
        }

        /// <summary>
        /// Check that a <see cref="RegressionMetrics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetrics(RegressionMetrics metrics)
        {
            Assert.True(metrics.RootMeanSquaredError >= 0);
            Assert.True(metrics.MeanAbsoluteError >= 0);
            Assert.True(metrics.MeanSquaredError >= 0);
            Assert.True(metrics.RSquared <= 1);
        }

        /// <summary>
        /// Check that a <see cref="MetricStatistics"/> object is valid.
        /// </summary>
        /// <param name="metric">The <see cref="MetricStatistics"/> object.</param>
        public static void AssertMetricStatistics(MetricStatistics metric)
        {
            Assert.True(metric.StandardDeviation >= 0);
            Assert.True(metric.StandardError >= 0);
        }

        /// <summary>
        /// Check that a <see cref="RegressionMetricsStatistics"/> object is valid.
        /// </summary>
        /// <param name="metrics">The metrics object.</param>
        public static void AssertMetricsStatistics(RegressionMetricsStatistics metrics)
        {
            AssertMetricStatistics(metrics.RootMeanSquaredError);
            AssertMetricStatistics(metrics.MeanAbsoluteError);
            AssertMetricStatistics(metrics.MeanSquaredError);
            AssertMetricStatistics(metrics.RSquared);
            AssertMetricStatistics(metrics.LossFunction);
        }

        /// <summary>
        /// Assert that two float arrays of the same length are not equal, i.e. at least one pair of
        /// elements differs under exact comparison.
        /// </summary>
        /// <param name="array1">An array of floats.</param>
        /// <param name="array2">An array of floats.</param>
        public static void AssertNotEqual(float[] array1, float[] array2)
        {
            Assert.NotNull(array1);
            Assert.NotNull(array2);
            Assert.Equal(array1.Length, array2.Length);

            // SequenceEqual uses the default float equality (float.Equals, under which NaN equals NaN).
            // This avoids catching assertion exceptions for control flow.
            Assert.False(array1.SequenceEqual(array2));
        }

        /// <summary>
        /// Verify that a float list has no NaNs or infinities.
        /// </summary>
        /// <param name="array">A list of floats.</param>
        /// <param name="ignoreElementAt">Index of one element to skip; the default of -1 checks every element.</param>
        public static void AssertFiniteNumbers(IList<float> array, int ignoreElementAt = -1)
        {
            for (int i = 0; i < array.Count; i++)
            {
                if (i == ignoreElementAt)
                    continue;
                Assert.False(float.IsNaN(array[i]));
                Assert.False(float.IsInfinity(array[i]));
            }
        }

        /// <summary>
        /// Verify that a double list has no NaNs or infinities.
        /// </summary>
        /// <param name="array">A list of doubles.</param>
        /// <param name="ignoreElementAt">Index of one element to skip; the default of -1 checks every element.</param>
        public static void AssertFiniteNumbers(IList<double> array, int ignoreElementAt = -1)
        {
            for (int i = 0; i < array.Count; i++)
            {
                if (i == ignoreElementAt)
                    continue;
                Assert.False(double.IsNaN(array[i]));
                Assert.False(double.IsInfinity(array[i]));
            }
        }
    }
}