// File: TrainerExtensionsTests.cs
// Project: src\test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj (Microsoft.ML.AutoML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.TestFramework;
using Microsoft.ML.TestFramework.Attributes;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.AutoML.Test
{
 
    /// <summary>
    /// Tests for AutoML trainer extensions: creating trainer instances and pipeline nodes,
    /// the JSON shape of the pipeline nodes they produce, building parameter sets, and
    /// mapping public trainer enums to internal <see cref="TrainerName"/> values.
    /// </summary>
    public class TrainerExtensionsTests : BaseTestClass
    {
        public TrainerExtensionsTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// For every trainer except OVA and ImageClassification, the extension must expose
        /// sweep ranges and be able to create both a trainer instance and a pipeline node.
        /// ImageClassification is exercised by the TensorFlow-gated test instead.
        /// </summary>
        [Fact]
        public void TrainerExtensionInstanceTests()
        {
            var context = new MLContext(1);
            var columnInfo = new ColumnInformation();

            // ImageClassification needs TensorFlow (see TrainerExtensionTensorFlowInstanceTests);
            // excluding it here is equivalent to the original per-iteration skip.
            var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>()
                .Except(new[] { TrainerName.Ova, TrainerName.ImageClassification });

            foreach (var trainerName in trainerNames)
            {
                var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);

                var sweepParams = extension.GetHyperparamSweepRanges();
                Assert.NotNull(sweepParams);

                // Pin every sweepable parameter to a concrete value so CreateInstance can read it.
                foreach (var sweepParam in sweepParams)
                {
                    sweepParam.RawValue = 1;
                }

                var instance = extension.CreateInstance(context, sweepParams, columnInfo);
                Assert.NotNull(instance);

                var pipelineNode = extension.CreatePipelineNode(null, columnInfo);
                Assert.NotNull(pipelineNode);
            }
        }

        /// <summary>
        /// The ImageClassification extension can create an instance and a pipeline node
        /// without sweep parameters. Runs only where TensorFlow is available.
        /// </summary>
        [TensorFlowFact]
        public void TrainerExtensionTensorFlowInstanceTests()
        {
            var context = new MLContext(1);
            var columnInfo = new ColumnInformation();
            var extension = TrainerExtensionCatalog.GetTrainerExtension(TrainerName.ImageClassification);
            var instance = extension.CreateInstance(context, null, columnInfo);
            Assert.NotNull(instance);
            var pipelineNode = extension.CreatePipelineNode(null, columnInfo);
            Assert.NotNull(pipelineNode);
        }

        /// <summary>
        /// Serialized MatrixFactorization pipeline node when all sweep params have RawValue = 1.
        /// </summary>
        [Fact]
        public void BuildMatrixFactorizationPipelineNode()
        {
            var sweepParams = SweepableParams.BuildMatrixFactorizationParams();
            foreach (var sweepParam in sweepParams)
            {
                sweepParam.RawValue = 1;
            }

            var pipelineNode = new MatrixFactorizationExtension().CreatePipelineNode(sweepParams, new ColumnInformation());

            var expectedJson = @"{
  ""Name"": ""MatrixFactorization"",
  ""NodeType"": ""Trainer"",
  ""InColumns"": [
    ""Features""
  ],
  ""OutColumns"": [
    ""Score""
  ],
  ""Properties"": {
    ""NumberOfIterations"": 20,
    ""LearningRate"": 0.01,
    ""ApproximationRank"": 16,
    ""Lambda"": 0.05,
    ""LossFunction"": ""SquareLossOneClass"",
    ""Alpha"": 0.01,
    ""C"": 0.0001,
    ""LabelColumnName"": ""Label"",
    ""MatrixColumnIndexColumnName"": null,
    ""MatrixRowIndexColumnName"": null
  }
}";
            Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
        }

        /// <summary>
        /// Serialized LightGbmBinary pipeline node when all sweep params have RawValue = 1,
        /// including the nested booster options.
        /// </summary>
        [Fact]
        public void BuildLightGbmPipelineNode()
        {
            var sweepParams = SweepableParams.BuildLightGbmParams();
            foreach (var sweepParam in sweepParams)
            {
                sweepParam.RawValue = 1;
            }

            var pipelineNode = new LightGbmBinaryExtension().CreatePipelineNode(sweepParams, new ColumnInformation());

            var expectedJson = @"{
  ""Name"": ""LightGbmBinary"",
  ""NodeType"": ""Trainer"",
  ""InColumns"": [
    ""Features""
  ],
  ""OutColumns"": [
    ""Score""
  ],
  ""Properties"": {
    ""NumberOfIterations"": 20,
    ""LearningRate"": 1,
    ""NumberOfLeaves"": 1,
    ""MinimumExampleCountPerLeaf"": 10,
    ""UseCategoricalSplit"": false,
    ""HandleMissingValue"": false,
    ""UseZeroAsMissingValue"": false,
    ""MinimumExampleCountPerGroup"": 50,
    ""MaximumCategoricalSplitPointCount"": 16,
    ""CategoricalSmoothing"": 10,
    ""L2CategoricalRegularization"": 0.5,
    ""Booster"": {
      ""Name"": ""GradientBooster.Options"",
      ""Properties"": {
        ""L2Regularization"": 0.5,
        ""L1Regularization"": 0.5
      }
    },
    ""LabelColumnName"": ""Label""
  }
}";
            Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
        }

        /// <summary>
        /// Serialized SdcaLogisticRegressionBinary pipeline node when all sweep params have RawValue = 1.
        /// </summary>
        [Fact]
        public void BuildSdcaPipelineNode()
        {
            var sweepParams = SweepableParams.BuildSdcaParams();
            foreach (var sweepParam in sweepParams)
            {
                sweepParam.RawValue = 1;
            }

            var pipelineNode = new SdcaLogisticRegressionBinaryExtension().CreatePipelineNode(sweepParams, new ColumnInformation());
            var expectedJson = @"{
  ""Name"": ""SdcaLogisticRegressionBinary"",
  ""NodeType"": ""Trainer"",
  ""InColumns"": [
    ""Features""
  ],
  ""OutColumns"": [
    ""Score""
  ],
  ""Properties"": {
    ""L2Regularization"": 1E-07,
    ""L1Regularization"": 0.0,
    ""ConvergenceTolerance"": 0.01,
    ""MaximumNumberOfIterations"": 10,
    ""Shuffle"": true,
    ""BiasLearningRate"": 0.01,
    ""LabelColumnName"": ""Label""
  }
}";
            Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
        }

        /// <summary>
        /// With an empty sweep-param list, the LightGbmBinary node carries only the label column.
        /// </summary>
        [Fact]
        public void BuildLightGbmPipelineNodeDefaultParams()
        {
            var pipelineNode = new LightGbmBinaryExtension().CreatePipelineNode(
                new List<SweepableParam>(),
                new ColumnInformation());
            var expectedJson = @"{
  ""Name"": ""LightGbmBinary"",
  ""NodeType"": ""Trainer"",
  ""InColumns"": [
    ""Features""
  ],
  ""OutColumns"": [
    ""Score""
  ],
  ""Properties"": {
    ""LabelColumnName"": ""Label""
  }
}";
            Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
        }

        /// <summary>
        /// Custom label and example-weight column names flow through to the FastForestBinary node.
        /// </summary>
        [Fact]
        public void BuildPipelineNodeWithCustomColumns()
        {
            var columnInfo = new ColumnInformation()
            {
                LabelColumnName = "L",
                ExampleWeightColumnName = "W"
            };
            var sweepParams = SweepableParams.BuildFastForestParams();
            foreach (var sweepParam in sweepParams)
            {
                sweepParam.RawValue = 1;
            }

            var pipelineNode = new FastForestBinaryExtension().CreatePipelineNode(sweepParams, columnInfo);
            var expectedJson = @"{
  ""Name"": ""FastForestBinary"",
  ""NodeType"": ""Trainer"",
  ""InColumns"": [
    ""Features""
  ],
  ""OutColumns"": [
    ""Score""
  ],
  ""Properties"": {
    ""NumberOfLeaves"": 1,
    ""MinimumExampleCountPerLeaf"": 10,
    ""NumberOfTrees"": 100,
    ""LabelColumnName"": ""L"",
    ""ExampleWeightColumnName"": ""W""
  }
}";
            Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
        }

        /// <summary>
        /// With null sweep params, the AveragedPerceptronBinary node still includes NumberOfIterations = 10.
        /// </summary>
        [Fact]
        public void BuildDefaultAveragedPerceptronPipelineNode()
        {
            var pipelineNode = new AveragedPerceptronBinaryExtension().CreatePipelineNode(null, new ColumnInformation() { LabelColumnName = "L" });
            var expectedJson = @"{
  ""Name"": ""AveragedPerceptronBinary"",
  ""NodeType"": ""Trainer"",
  ""InColumns"": [
    ""Features""
  ],
  ""OutColumns"": [
    ""Score""
  ],
  ""Properties"": {
    ""LabelColumnName"": ""L"",
    ""NumberOfIterations"": 10
  }
}";
            Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
        }

        /// <summary>
        /// The OVA node has no in/out columns of its own and nests the FastForestBinary node
        /// under its BinaryTrainer property.
        /// </summary>
        [Fact]
        public void BuildOvaPipelineNode()
        {
            var pipelineNode = new FastForestOvaExtension().CreatePipelineNode(null, new ColumnInformation());
            var expectedJson = @"{
  ""Name"": ""Ova"",
  ""NodeType"": ""Trainer"",
  ""InColumns"": null,
  ""OutColumns"": null,
  ""Properties"": {
    ""LabelColumnName"": ""Label"",
    ""BinaryTrainer"": {
      ""Name"": ""FastForestBinary"",
      ""NodeType"": ""Trainer"",
      ""InColumns"": [
        ""Features""
      ],
      ""OutColumns"": [
        ""Score""
      ],
      ""Properties"": {
        ""LabelColumnName"": ""Label""
      }
    }
  }
}";
            Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
        }

        /// <summary>
        /// The group-id column is emitted as RowGroupColumnName on the FastTreeRanking node.
        /// </summary>
        [Fact]
        public void BuildFastTreeRankingPipelineNode()
        {
            var columnInfo = new ColumnInformation()
            {
                LabelColumnName = "L",
                GroupIdColumnName = "GId"
            };
            var pipelineNode = new FastTreeRankingExtension().CreatePipelineNode(null, columnInfo);
            var expectedJson = @"{
  ""Name"": ""FastTreeRanking"",
  ""NodeType"": ""Trainer"",
  ""InColumns"": [
    ""Features""
  ],
  ""OutColumns"": [
    ""Score""
  ],
  ""Properties"": {
    ""LabelColumnName"": ""L"",
    ""RowGroupColumnName"": ""GId""
  }
}";
            Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
        }

        /// <summary>
        /// For every LightGbm trainer, the nested Booster properties are flattened into the
        /// parameter set alongside the top-level ones, giving four entries.
        /// </summary>
        [Fact]
        public void BuildParameterSetLightGbm()
        {
            var props = new Dictionary<string, object>()
            {
                {"NumberOfIterations", 1 },
                {"LearningRate", 1 },
                {"Booster", new CustomProperty() {
                    Name = "GradientBooster.Options",
                    Properties = new Dictionary<string, object>()
                    {
                        {"L2Regularization", 1 },
                        {"L1Regularization", 1 },
                    }
                } },
            };
            var binaryParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmBinary, props);
            var multiParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmMulti, props);
            var regressionParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmRegression, props);
            var rankingParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmRanking, props);

            foreach (var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams, rankingParams })
            {
                Assert.Equal(4, paramSet.Count);
                Assert.Equal("1", paramSet["NumberOfIterations"].ValueText);
                Assert.Equal("1", paramSet["LearningRate"].ValueText);
                Assert.Equal("1", paramSet["L2Regularization"].ValueText);
                Assert.Equal("1", paramSet["L1Regularization"].ValueText);
            }
        }

        /// <summary>
        /// A single property becomes a single-entry parameter set for SDCA.
        /// </summary>
        [Fact]
        public void BuildParameterSetSdca()
        {
            var props = new Dictionary<string, object>()
            {
                {"LearningRate", 1 },
            };

            var sdcaParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.SdcaLogisticRegressionBinary, props);

            Assert.Equal(1, sdcaParams.Count);
            Assert.Equal("1", sdcaParams["LearningRate"].ValueText);
        }

        // The PublicToPrivate* tests check that each distinct public trainer enum value maps to
        // a distinct internal TrainerName, i.e. the distinct counts match.

        [Fact]
        public void PublicToPrivateTrainerNamesBinaryTest()
        {
            var publicNames = Enum.GetValues(typeof(BinaryClassificationTrainer)).Cast<BinaryClassificationTrainer>();
            var internalNames = TrainerExtensionUtil.GetTrainerNames(publicNames);
            Assert.Equal(publicNames.Distinct().Count(), internalNames.Distinct().Count());
        }

        [Fact]
        public void PublicToPrivateTrainerNamesMultiTest()
        {
            var publicNames = Enum.GetValues(typeof(MulticlassClassificationTrainer)).Cast<MulticlassClassificationTrainer>();
            var internalNames = TrainerExtensionUtil.GetTrainerNames(publicNames);
            Assert.Equal(publicNames.Distinct().Count(), internalNames.Distinct().Count());
        }

        [Fact]
        public void PublicToPrivateTrainerNamesRegressionTest()
        {
            var publicNames = Enum.GetValues(typeof(RegressionTrainer)).Cast<RegressionTrainer>();
            var internalNames = TrainerExtensionUtil.GetTrainerNames(publicNames);
            Assert.Equal(publicNames.Distinct().Count(), internalNames.Distinct().Count());
        }

        [Fact]
        public void PublicToPrivateTrainerNamesRecommendationTest()
        {
            var publicNames = Enum.GetValues(typeof(RecommendationTrainer)).Cast<RecommendationTrainer>();
            var internalNames = TrainerExtensionUtil.GetTrainerNames(publicNames);
            Assert.Equal(publicNames.Distinct().Count(), internalNames.Distinct().Count());
        }

        [Fact]
        public void PublicToPrivateTrainerNamesRankingTest()
        {
            var publicNames = Enum.GetValues(typeof(RankingTrainer)).Cast<RankingTrainer>();
            var internalNames = TrainerExtensionUtil.GetTrainerNames(publicNames);
            Assert.Equal(publicNames.Distinct().Count(), internalNames.Distinct().Count());
        }

        /// <summary>
        /// A null public-name sequence maps to a null internal-name sequence.
        /// </summary>
        [Fact]
        public void PublicToPrivateTrainerNamesNullTest()
        {
            var internalNames = TrainerExtensionUtil.GetTrainerNames(null as IEnumerable<BinaryClassificationTrainer>);
            Assert.Null(internalNames);
        }

        /// <summary>
        /// With no allow list, at least one binary-classification trainer is allowed.
        /// </summary>
        [Fact]
        public void AllowedTrainersAllowListNullTest()
        {
            var trainers = RecipeInference.AllowedTrainers(new MLContext(1), TaskKind.BinaryClassification, new ColumnInformation(), null);
            Assert.True(trainers.Any());
        }

        /// <summary>
        /// With an allow list, exactly the listed trainers are returned (by count).
        /// </summary>
        [Fact]
        public void AllowedTrainersAllowListTest()
        {
            var allowList = new[] { TrainerName.AveragedPerceptronBinary, TrainerName.FastForestBinary };
            var trainers = RecipeInference.AllowedTrainers(new MLContext(1), TaskKind.BinaryClassification, new ColumnInformation(), allowList);
            Assert.Equal(allowList.Length, trainers.Count());
        }
    }
}