File: Numeric\Ranking.cs
Web Access
Project: src\test\Microsoft.ML.PerformanceTests\Microsoft.ML.PerformanceTests.csproj (Microsoft.ML.PerformanceTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.IO;
using BenchmarkDotNet.Attributes;
using Microsoft.ML.Data;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers.LightGbm;
using Microsoft.ML.Transforms;
 
namespace Microsoft.ML.PerformanceTests
{
    [Config(typeof(TrainConfig))]
    public class RankingTrain : BenchmarkBase
    {
        private string _mslrWeb10kValidate;
        private string _mslrWeb10kTrain;
 
        [GlobalSetup]
        public void SetupTrainingSpeedTests()
        {
            _mslrWeb10kValidate = GetBenchmarkDataPathAndEnsureData(TestDatasets.MSLRWeb.validFilename, TestDatasets.MSLRWeb.path);
            _mslrWeb10kTrain = GetBenchmarkDataPathAndEnsureData(TestDatasets.MSLRWeb.trainFilename, TestDatasets.MSLRWeb.path);
 
            if (!File.Exists(_mslrWeb10kValidate))
                throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _mslrWeb10kValidate));
 
            if (!File.Exists(_mslrWeb10kTrain))
                throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _mslrWeb10kTrain));
        }
 
        [Benchmark]
        public void TrainTest_Ranking_MSLRWeb10K_RawNumericFeatures_FastTreeRanking()
        {
            string cmd = @"TrainTest test=" + _mslrWeb10kValidate +
                " eval=RankingEvaluator{t=10}" +
                " data=" + _mslrWeb10kTrain +
                " loader=TextLoader{col=Label:R4:0 col=GroupId:TX:1 col=Features:R4:2-138}" +
                " xf=HashTransform{col=GroupId} xf=NAHandleTransform{col=Features}" +
                " tr=FastTreeRanking{}";
 
            var environment = EnvironmentFactory.CreateRankingEnvironment<RankingEvaluator, TextLoader, HashingTransformer, FastTreeRankingTrainer, FastTreeRankingModelParameters>();
            cmd.ExecuteMamlCommand(environment);
        }
 
        [Benchmark]
        public void TrainTest_Ranking_MSLRWeb10K_RawNumericFeatures_LightGBMRanking()
        {
            string cmd = @"TrainTest test=" + _mslrWeb10kValidate +
                " eval=RankingEvaluator{t=10}" +
                " data=" + _mslrWeb10kTrain +
                " loader=TextLoader{col=Label:R4:0 col=GroupId:TX:1 col=Features:R4:2-138}" +
                " xf=HashTransform{col=GroupId}" +
                " xf=NAHandleTransform{col=Features}" +
                " tr=LightGBMRanking{}";
 
            var environment = EnvironmentFactory.CreateRankingEnvironment<RankingEvaluator, TextLoader, HashingTransformer, LightGbmMulticlassTrainer, OneVersusAllModelParameters>();
            cmd.ExecuteMamlCommand(environment);
        }
    }
 
    public class RankingTest : BenchmarkBase
    {
        private string _mslrWeb10kValidate;
        private string _mslrWeb10kTrain;
        private string _mslrWeb10kTest;
        private string _modelPathMslr;
 
        [GlobalSetup]
        public void SetupScoringSpeedTests()
        {
            _mslrWeb10kTest = GetBenchmarkDataPathAndEnsureData(TestDatasets.MSLRWeb.testFilename, TestDatasets.MSLRWeb.path);
            _mslrWeb10kValidate = GetBenchmarkDataPathAndEnsureData(TestDatasets.MSLRWeb.validFilename, TestDatasets.MSLRWeb.path);
            _mslrWeb10kTrain = GetBenchmarkDataPathAndEnsureData(TestDatasets.MSLRWeb.trainFilename, TestDatasets.MSLRWeb.path);
 
            if (!File.Exists(_mslrWeb10kTest))
                throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _mslrWeb10kTest));
 
            if (!File.Exists(_mslrWeb10kValidate))
                throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _mslrWeb10kValidate));
 
            if (!File.Exists(_mslrWeb10kTrain))
                throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _mslrWeb10kTrain));
 
            _modelPathMslr = Path.Combine(Path.GetDirectoryName(typeof(RankingTest).Assembly.Location), "FastTreeRankingModel.zip");
 
            string cmd = @"TrainTest test=" + _mslrWeb10kValidate +
                " eval=RankingEvaluator{t=10}" +
                " data=" + _mslrWeb10kTrain +
                " loader=TextLoader{col=Label:R4:0 col=GroupId:TX:1 col=Features:R4:2-138}" +
                " xf=HashTransform{col=GroupId}" +
                " xf=NAHandleTransform{col=Features}" +
                " tr=FastTreeRanking{}" +
                " out={" + _modelPathMslr + "}";
 
            var environment = EnvironmentFactory.CreateRankingEnvironment<RankingEvaluator, TextLoader, HashingTransformer, FastTreeRankingTrainer, FastTreeRankingModelParameters>();
            cmd.ExecuteMamlCommand(environment);
        }
 
        [Benchmark]
        public void Test_Ranking_MSLRWeb10K_RawNumericFeatures_FastTreeRanking()
        {
            // This benchmark is profiling bulk scoring speed and not training speed. 
            string cmd = @"Test data=" + _mslrWeb10kTest + " in=" + _modelPathMslr;
 
            var environment = EnvironmentFactory.CreateRankingEnvironment<RankingEvaluator, TextLoader, HashingTransformer, FastTreeRankingTrainer, FastTreeRankingModelParameters>();
            cmd.ExecuteMamlCommand(environment);
        }
    }
}