File: TestGamPublicInterfaces.cs
Web Access
Project: src\test\Microsoft.ML.Predictor.Tests\Microsoft.ML.Predictor.Tests.csproj (Microsoft.ML.Predictor.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Linq;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Trainers.FastTree;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.RunTests
{
    public class TestGamPublicInterfaces : BaseTestClass
    {
        public TestGamPublicInterfaces(ITestOutputHelper output) : base(output)
        {
        }
 
        [Fact]
        [TestCategory("FastTree")]
        public void TestGamDirectInstantiation()
        {
            var mlContext = new MLContext(seed: 1);
 
            double intercept = 1;
 
            var binUpperBounds = new double[2][]
            {
                new double[] { 1, 2, 3 },
                new double[] { 4, 5, 6 }
            };
 
            var binEffects = new double[2][]
            {
                new double[] { 0, 1, 2 },
                new double[] { 2, 1, 0 }
            };
 
            var gam = new GamRegressionModelParameters(mlContext, binUpperBounds, binEffects, intercept);
 
            // Check that the model has the right number of shape functions
            Assert.Equal(binUpperBounds.Length, gam.NumberOfShapeFunctions);
 
            // Check the intercept
            Assert.Equal(intercept, gam.Bias, 0.000001);
 
            // Check that the binUpperBounds were made correctly
            CheckArrayOfArrayEquality(binUpperBounds, gam.GetBinUpperBounds());
            for (int i = 0; i < gam.NumberOfShapeFunctions; i++)
                Utils.AreEqual(binUpperBounds[i], gam.GetBinUpperBounds(i).ToArray());
 
            // Check that the bin effects were made correctly
            CheckArrayOfArrayEquality(binEffects, gam.GetBinEffects());
            for (int i = 0; i < gam.NumberOfShapeFunctions; i++)
                Utils.AreEqual(binEffects[i], gam.GetBinEffects(i).ToArray());
 
            // Check that the constructor handles null inputs properly
            Assert.Throws<System.ArgumentNullException>(() => new GamRegressionModelParameters(mlContext, binUpperBounds, null, intercept));
            Assert.Throws<System.ArgumentNullException>(() => new GamRegressionModelParameters(mlContext, null, binEffects, intercept));
            Assert.Throws<System.ArgumentNullException>(() => new GamRegressionModelParameters(mlContext, null, null, intercept));
 
            // Check that the constructor handles mismatches in length between bin upper bounds and bin effects
            var misMatchArray = new double[1][];
            misMatchArray[0] = new double[] { 0 };
            Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, binUpperBounds, misMatchArray, intercept));
            Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, misMatchArray, binEffects, intercept));
 
            // Check that the constructor handles a mismatch in bin upper bounds and bin effects for a feature
            var fewerBinEffects = new double[2][]
            {
                new double[] { 0, 1 },
                new double[] { 2, 1, 0 }
            };
            Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, binUpperBounds, fewerBinEffects, intercept));
            var moreBinEffects = new double[2][]
            {
                new double[] { 0, 1, 2, 3 },
                new double[] { 2, 1, 0 }
            };
            Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, binUpperBounds, moreBinEffects, intercept));
 
            // Check that the constructor handles bin upper bounds that are not sorted
            var unsortedUpperBounds = new double[2][]
            {
                new double[] { 1, 3, 2 },
                new double[] { 4, 5, 6 }
            };
            Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, unsortedUpperBounds, binEffects, intercept));
        }
 
        private void CheckArrayOfArrayEquality(double[][] array1, double[][] array2)
        {
            Assert.Equal(array1.Length, array2.Length);
            for (int i = 0; i < array1.Length; i++)
                Utils.AreEqual(array1[i], array2[i]);
        }
    }
}