File: FastTreeRanking.cs
Project: src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
 
// REVIEW: Do we really need all these names?
[assembly: LoadableClass(FastTreeRankingTrainer.Summary, typeof(FastTreeRankingTrainer), typeof(FastTreeRankingTrainer.Options),
    new[] { typeof(SignatureRankerTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
    FastTreeRankingTrainer.UserNameValue,
    FastTreeRankingTrainer.LoadNameValue,
    FastTreeRankingTrainer.ShortName,
 
    // FastRank names
    "FastRankRanking",
    "FastRankRankingWrapper",
    "rank",
    "frrank",
    "btrank")]
 
[assembly: LoadableClass(typeof(FastTreeRankingModelParameters), null, typeof(SignatureLoadModel),
    "FastTree Ranking Executor",
    FastTreeRankingModelParameters.LoaderSignature)]
 
[assembly: LoadableClass(typeof(void), typeof(FastTree), null, typeof(SignatureEntryPointModule), "FastTree")]
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// The <see cref="IEstimator{TTransformer}"/> for training a decision tree ranking model using FastTree.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    /// To create this trainer, use [FastTree](xref:Microsoft.ML.TreeExtensions.FastTree(Microsoft.ML.RankingCatalog.RankingTrainers,System.String,System.String,System.String,System.String,System.Int32,System.Int32,System.Int32,System.Double))
    /// or [FastTree(Options)](xref:Microsoft.ML.TreeExtensions.FastTree(Microsoft.ML.RankingCatalog.RankingTrainers,Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer.Options)).
    ///
    /// [!include[io](~/../docs/samples/docs/api-reference/io-columns-ranking.md)]
    ///
    /// ### Trainer Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Machine learning task | Ranking |
    /// | Is normalization required? | No |
    /// | Is caching required? | No |
    /// | Required NuGet in addition to Microsoft.ML | Microsoft.ML.FastTree |
    /// | Exportable to ONNX | No |
    ///
    /// [!include[algorithm](~/../docs/samples/docs/api-reference/algo-details-fasttree.md)]
    /// ]]>
    /// </format>
    /// </remarks>
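    /// <example>
    /// <format type="text/markdown"><![CDATA[
    /// A minimal usage sketch (the pipeline and column names below are illustrative):
    /// ```csharp
    /// var mlContext = new MLContext();
    /// // "Label", "GroupId" and "Features" are the default column names;
    /// // substitute the names used by your data view.
    /// var trainer = mlContext.Ranking.Trainers.FastTree(
    ///     labelColumnName: "Label",
    ///     featureColumnName: "Features",
    ///     rowGroupColumnName: "GroupId",
    ///     numberOfTrees: 100);
    /// var model = trainer.Fit(trainDataView);
    /// ```
    /// ]]>
    /// </format>
    /// </example>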
    /// <seealso cref="TreeExtensions.FastTree(RankingCatalog.RankingTrainers, string, string, string, string, int, int, int, double)"/>
    /// <seealso cref="TreeExtensions.FastTree(RegressionCatalog.RegressionTrainers, FastTreeRegressionTrainer.Options)"/>
    /// <seealso cref="Options"/>
    public sealed partial class FastTreeRankingTrainer
        : BoostingFastTreeTrainerBase<FastTreeRankingTrainer.Options, RankingPredictionTransformer<FastTreeRankingModelParameters>, FastTreeRankingModelParameters>
    {
        internal const string LoadNameValue = "FastTreeRanking";
        internal const string UserNameValue = "FastTree (Boosted Trees) Ranking";
        internal const string Summary = "Trains gradient boosted decision trees to the LambdaRank quasi-gradient.";
        internal const string ShortName = "ftrank";
 
        private IEnsembleCompressor<short> _ensembleCompressor;
        private Test _specialTrainSetTest;
        private TestHistory _firstTestSetHistory;
 
        /// <summary>
        /// The prediction kind for this trainer.
        /// </summary>
        private protected override PredictionKind PredictionKind => PredictionKind.Ranking;
 
        /// <summary>
        /// Initializes a new instance of <see cref="FastTreeRankingTrainer"/>
        /// </summary>
        /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="labelColumnName">The name of the label column.</param>
        /// <param name="featureColumnName">The name of the feature column.</param>
        /// <param name="rowGroupColumnName">The name for the column containing the group ID. </param>
        /// <param name="exampleWeightColumnName">The name for the column containing the example weight.</param>
        /// <param name="numberOfLeaves">The max number of leaves in each regression tree.</param>
        /// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
        /// <param name="minimumExampleCountPerLeaf">The minimal number of examples allowed in a leaf of a regression tree, out of the subsampled data.</param>
        /// <param name="learningRate">The learning rate.</param>
        internal FastTreeRankingTrainer(IHostEnvironment env,
            string labelColumnName = DefaultColumnNames.Label,
            string featureColumnName = DefaultColumnNames.Features,
            string rowGroupColumnName = DefaultColumnNames.GroupId,
            string exampleWeightColumnName = null,
            int numberOfLeaves = Defaults.NumberOfLeaves,
            int numberOfTrees = Defaults.NumberOfTrees,
            int minimumExampleCountPerLeaf = Defaults.MinimumExampleCountPerLeaf,
            double learningRate = Defaults.LearningRate)
            : base(env, TrainerUtils.MakeR4ScalarColumn(labelColumnName), featureColumnName, exampleWeightColumnName, rowGroupColumnName, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf, learningRate)
        {
            Host.CheckNonEmpty(rowGroupColumnName, nameof(rowGroupColumnName));
        }
 
        /// <summary>
        /// Initializes a new instance of <see cref="FastTreeRankingTrainer"/> by using the <see cref="Options"/> class.
        /// </summary>
        /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="options">Algorithm advanced settings.</param>
        internal FastTreeRankingTrainer(IHostEnvironment env, Options options)
        : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName))
        {
        }
 
        private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            Contracts.Assert(labelCol.IsValid);
 
            Action error =
                () => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "Single or Key", labelCol.GetTypeString());
 
            if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
                error();
            if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single)
                error();
        }
 
        private protected override float GetMaxLabel()
        {
            return GetLabelGains().Length - 1;
        }
 
        private protected override FastTreeRankingModelParameters TrainModelCore(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
            var trainData = context.TrainingSet;
            ValidData = context.ValidationSet;
            TestData = context.TestSet;
 
            using (var ch = Host.Start("Training"))
            {
                var maxLabel = GetLabelGains().Length - 1;
                ConvertData(trainData);
                TrainCore(ch);
                FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount();
            }
            return new FastTreeRankingModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions);
        }
 
        private Double[] GetLabelGains()
        {
            try
            {
                Host.AssertValue(FastTreeTrainerOptions.CustomGains);
                return FastTreeTrainerOptions.CustomGains;
            }
            catch (Exception ex)
            {
                if (ex is FormatException || ex is OverflowException)
                    throw Host.Except(ex, "Error in the format of custom gains. Inner exception is {0}", ex.Message);
                throw;
            }
        }
 
        private protected override void CheckOptions(IChannel ch)
        {
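            // Illustrative sketch of options that satisfy the checks below; the gain array
            // shown matches FastTree's default CustomGains:
            //
            //     var options = new FastTreeRankingTrainer.Options
            //     {
            //         CustomGains = new double[] { 0, 3, 7, 15, 31 }, // at least 5 gain levels
            //         EarlyStoppingMetrics = 1,                       // 1 (NDCG@1) or 3 (NDCG@3)
            //     };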
            if (FastTreeTrainerOptions.CustomGains != null)
            {
                var gains = FastTreeTrainerOptions.CustomGains;
                if (gains.Length < 5)
                {
                    throw ch.ExceptUserArg(nameof(FastTreeTrainerOptions.CustomGains),
                        "Has {0} gain levels. We require at least 5 elements.",
                        gains.Length);
                }
                DcgCalculator.LabelGainMap = gains;
                Dataset.DatasetSkeleton.LabelGainMap = gains;
            }
 
            bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
                FastTreeTrainerOptions.EnablePruning;
 
            if (doEarlyStop)
                ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3,
                    nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 or 3.");
 
            base.CheckOptions(ch);
        }
 
        private protected override void Initialize(IChannel ch)
        {
            base.Initialize(ch);
            if (FastTreeTrainerOptions.CompressEnsemble)
            {
                _ensembleCompressor = new LassoBasedEnsembleCompressor();
                _ensembleCompressor.Initialize(FastTreeTrainerOptions.NumberOfTrees, TrainSet, TrainSet.Ratings, FastTreeTrainerOptions.Seed);
            }
        }
 
        private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
        {
            return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, FastTreeTrainerOptions, ParallelTraining);
        }
 
        private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
        {
            OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
            if (FastTreeTrainerOptions.UseLineSearch)
            {
                _specialTrainSetTest = new FastNdcgTest(optimizationAlgorithm.TrainingScores, TrainSet.Ratings, FastTreeTrainerOptions.SortingAlgorithm, FastTreeTrainerOptions.EarlyStoppingMetrics);
                optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(_specialTrainSetTest, 0, FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps, FastTreeTrainerOptions.MinimumStepSize);
            }
            return optimizationAlgorithm;
        }
 
        private protected override BaggingProvider CreateBaggingProvider()
        {
            Host.Assert(FastTreeTrainerOptions.BaggingSize > 0);
            return new RankingBaggingProvider(TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.BaggingExampleFraction);
        }
 
        private protected override void PrepareLabels(IChannel ch)
        {
        }
 
        private protected override Test ConstructTestForTrainingData()
        {
            return new NdcgTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, FastTreeTrainerOptions.SortingAlgorithm);
        }
 
        private protected override void InitializeTests()
        {
            if (FastTreeTrainerOptions.TestFrequency != int.MaxValue)
            {
                AddFullTests();
            }
 
            if (FastTreeTrainerOptions.PrintTestGraph)
            {
                // If _firstTestSetHistory is null (meaning the tests were not initialized because /tf == infinity),
                // we need to initialize the first test set for graph printing.
                // Adding it to Tests would cause the results to be printed after the final iteration.
                if (_firstTestSetHistory == null)
                {
                    var firstTestSetTest = CreateFirstTestSetTest();
                    _firstTestSetHistory = new TestHistory(firstTestSetTest, 0);
                }
            }
 
            // Tests for early stopping.
            TrainTest = CreateSpecialTrainSetTest();
            if (ValidSet != null)
                ValidTest = CreateSpecialValidSetTest();
 
            if (FastTreeTrainerOptions.PrintTrainValidGraph && FastTreeTrainerOptions.EnablePruning && _specialTrainSetTest == null)
            {
                _specialTrainSetTest = CreateSpecialTrainSetTest();
            }
 
            if (FastTreeTrainerOptions.EnablePruning && ValidTest != null)
            {
                if (!FastTreeTrainerOptions.UseTolerantPruning)
                {
                    // Use the simple early stopping condition.
                    PruningTest = new TestHistory(ValidTest, 0);
                }
                else
                {
                    // Use the tolerant stopping condition.
                    PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
                }
            }
        }
 
        private void AddFullTests()
        {
            Tests.Add(CreateStandardTest(TrainSet));
 
            if (ValidSet != null)
            {
                Test test = CreateStandardTest(ValidSet);
                Tests.Add(test);
            }
 
            for (int t = 0; TestSets != null && t < TestSets.Length; ++t)
            {
                Test test = CreateStandardTest(TestSets[t]);
                if (t == 0)
                {
                    _firstTestSetHistory = new TestHistory(test, 0);
                }
 
                Tests.Add(test);
            }
        }
 
        private protected override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
        {
            // REVIEW: Shift to using progress channels to report this information.
#if OLD_TRACE
            // This needs to be executed every iteration.
            if (PruningTest != null)
            {
                if (PruningTest is TestWindowWithTolerance)
                {
                    if (PruningTest.BestIteration != -1)
                        ch.Info("Iteration {0} \t(Best tolerated validation moving average NDCG@{1} {2}:{3:00.00}~{4:00.00})",
                                Ensemble.NumTrees,
                                _args.earlyStoppingMetrics,
                                PruningTest.BestIteration,
                                100 * (PruningTest as TestWindowWithTolerance).BestAverageValue,
                                100 * (PruningTest as TestWindowWithTolerance).CurrentAverageValue);
                    else
                        ch.Info("Iteration {0}", Ensemble.NumTrees);
                }
                else
                {
                    ch.Info("Iteration {0} \t(best validation NDCG@{1} {2}:{3:00.00}>{4:00.00})",
                            Ensemble.NumTrees,
                            _args.earlyStoppingMetrics,
                            PruningTest.BestIteration,
                            100 * PruningTest.BestResult.FinalValue,
                            100 * PruningTest.ComputeTests().First().FinalValue);
                }
            }
            else
                base.PrintIterationMessage(ch, pch);
#else
            base.PrintIterationMessage(ch, pch);
#endif
        }
 
        private protected override void ComputeTests()
        {
            if (_firstTestSetHistory != null)
                _firstTestSetHistory.ComputeTests();
 
            if (_specialTrainSetTest != null)
                _specialTrainSetTest.ComputeTests();
 
            if (PruningTest != null)
                PruningTest.ComputeTests();
        }
 
        private protected override string GetTestGraphLine()
        {
            StringBuilder lineBuilder = new StringBuilder();
 
            lineBuilder.AppendFormat("Eval:\tnet.{0:D8}.ini", Ensemble.NumTrees - 1);
 
            foreach (var r in _firstTestSetHistory.ComputeTests())
            {
                lineBuilder.AppendFormat("\t{0:0.0000}", r.FinalValue);
            }
 
            double trainTestResult = 0.0;
            double validTestResult = 0.0;
 
            // We only print a non-zero train&valid graph if earlyStoppingTruncation != 0.
            // In case /es is not set, we print 0 for the train and valid graph NDCG.
            // Let's keep this behavior for backward compatibility with the previous FR version.
            // Ideally /graphtv should enforce a non-zero /es in the command-line validation.
            if (_specialTrainSetTest != null)
            {
                trainTestResult = _specialTrainSetTest.ComputeTests().First().FinalValue;
            }
 
            if (PruningTest != null)
            {
                validTestResult = PruningTest.ComputeTests().First().FinalValue;
            }
 
            lineBuilder.AppendFormat("\t{0:0.0000}\t{1:0.0000}", trainTestResult, validTestResult);
 
            return lineBuilder.ToString();
        }
 
        private protected override void Train(IChannel ch)
        {
            base.Train(ch);
            // Print the final iteration.
            // Note that the trainNDCG printed in the graph is copied from the previous iteration's
            // value and will differ slightly from the proper final value computed by FullTest.
            // We cannot compute the final NDCG here because FastNDCGTestForTrainSet computes NDCG
            // from the label sort saved during gradient computation (and we don't have gradients
            // for the n+1 iteration). Keeping this in sync with the original FR code.
            PrintTestGraph(ch);
        }
 
        private protected override void CustomizedTrainingIteration(InternalRegressionTree tree)
        {
            Contracts.AssertValueOrNull(tree);
            if (tree != null && FastTreeTrainerOptions.CompressEnsemble)
            {
                double[] trainOutputs = Ensemble.GetTreeAt(Ensemble.NumTrees - 1).GetOutputs(TrainSet);
                _ensembleCompressor.SetTreeScores(Ensemble.NumTrees - 1, trainOutputs);
            }
        }
 
        /// <summary>
        /// Create standard test for dataset.
        /// </summary>
        /// <param name="dataset">dataset used for testing</param>
        /// <returns>standard test for the dataset</returns>
        private Test CreateStandardTest(Dataset dataset)
        {
            if (Utils.Size(dataset.MaxDcg) == 0)
                dataset.Skeleton.RecomputeMaxDcg(10);
 
            return new NdcgTest(
                ConstructScoreTracker(dataset),
                dataset.Ratings,
                FastTreeTrainerOptions.SortingAlgorithm);
        }
 
        /// <summary>
        /// Create the special test for train set.
        /// </summary>
        /// <returns>test for train set</returns>
        private Test CreateSpecialTrainSetTest()
        {
            return new FastNdcgTestForTrainSet(
                OptimizationAlgorithm.TrainingScores,
                OptimizationAlgorithm.ObjectiveFunction as LambdaRankObjectiveFunction,
                TrainSet.Ratings,
                FastTreeTrainerOptions.SortingAlgorithm,
                FastTreeTrainerOptions.EarlyStoppingMetrics);
        }
 
        /// <summary>
        /// Create the special test for valid set.
        /// </summary>
        /// <returns>test for valid set</returns>
        private Test CreateSpecialValidSetTest()
        {
            return new FastNdcgTest(
                ConstructScoreTracker(ValidSet),
                ValidSet.Ratings,
                FastTreeTrainerOptions.SortingAlgorithm,
                FastTreeTrainerOptions.EarlyStoppingMetrics);
        }
 
        /// <summary>
        /// Create the test for the first test set.
        /// </summary>
        /// <returns>test for the first test set</returns>
        private Test CreateFirstTestSetTest()
        {
            return CreateStandardTest(TestSets[0]);
        }
 
        /// <summary>
        /// Get the header of test graph
        /// </summary>
        /// <returns>Test graph header</returns>
        private protected override string GetTestGraphHeader()
        {
            StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10");
 
            if (FastTreeTrainerOptions.PrintTrainValidGraph)
            {
                headerBuilder.Append("\tNDCG@20\tNDCG@40");
                headerBuilder.AppendFormat(
                    "\nNote: Printing train NDCG@{0} as NDCG@20 and validation NDCG@{0} as NDCG@40..\n",
                    FastTreeTrainerOptions.EarlyStoppingMetrics);
            }
 
            return headerBuilder.ToString();
        }
 
        private protected override RankingPredictionTransformer<FastTreeRankingModelParameters> MakeTransformer(FastTreeRankingModelParameters model, DataViewSchema trainSchema)
        => new RankingPredictionTransformer<FastTreeRankingModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
 
        /// <summary>
        /// Trains a <see cref="FastTreeRankingTrainer"/> using both training and validation data, returns
        /// a <see cref="RankingPredictionTransformer{FastTreeRankingModelParameters}"/>.
        /// </summary>
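        /// <example>
        /// A minimal sketch (the data views are illustrative):
        /// <code>
        /// var model = trainer.Fit(trainDataView, validationDataView);
        /// </code>
        /// </example>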
        public RankingPredictionTransformer<FastTreeRankingModelParameters> Fit(IDataView trainData, IDataView validationData)
            => TrainTransformer(trainData, validationData);
 
        private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            return new[]
            {
                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()))
            };
        }
 
        internal sealed class LambdaRankObjectiveFunction : ObjectiveFunctionBase, IStepSearch
        {
            private readonly short[] _labels;
 
            private enum DupeIdInfo
            {
                NoInformation = 0,
                Unique = 1,
                FormatNotSupported = 1000000,
                Code404 = 1000001
            };
 
            // precomputed arrays
            private readonly double[] _inverseMaxDcgt;
            private readonly double[] _discount;
            private readonly int[] _oneTwoThree;
 
            private readonly int[][] _labelCounts;
 
            // reusable memory, technical stuff
            private readonly int[][] _permutationBuffers;
            private readonly DcgPermutationComparer[] _comparers;
 
            // gains
            private readonly double[] _gain;
            private double[] _gainLabels;
 
            // parameters
            private readonly int _maxDcgTruncationLevel;
            private readonly bool _useDcg;
            // A lookup table for the sigmoid used in the lambda calculation
            // Note: Is built for a specific sigmoid parameter, so assumes this will be constant throughout computation
            private double[] _sigmoidTable;
            private double _minScore;       // Computed: range of scores covered in table
            private double _maxScore;
            private double _minSigmoid;
            private double _maxSigmoid;
            private double _scoreToSigmoidTableFactor;
            private const double _expAsymptote = -50;     // exp( x < expAsymptote ) is assumed to be 0
            private const int _sigmoidBins = 1000000;         // Number of bins in the lookup table
 
            // Secondary gains, currently not used in any way.
#pragma warning disable 0649
            private readonly double _secondaryMetricShare;
            private readonly double[] _secondaryInverseMaxDcgt;
            private readonly double[] _secondaryGains;
#pragma warning restore 0649
 
            // Baseline risk.
            private static int _iteration = 0; // Static member shared across all instances; tracks the current boosting iteration.
            private double _baselineAlphaCurrent;
 
            // These reusable buffers are used for
            // 1. preprocessing the scores for continuous cost function
            // 2. shifted NDCG
            // 3. max DCG per query
            private readonly double[] _scoresCopy;
            private readonly short[] _labelsCopy;
            private readonly short[] _groupIdToTopLabel;
 
            // parameters
            private readonly double _sigmoidParam;
            private readonly char _costFunctionParam;
            private readonly bool _filterZeroLambdas;
 
            private readonly bool _distanceWeight2;
            private readonly bool _normalizeQueryLambdas;
            private readonly bool _useShiftedNdcg;
            private readonly IParallelTraining _parallelTraining;
 
            // Used for training NDCG calculation
            // Keeps track of labels of top 3 documents per query
            public short[][] TrainQueriesTopLabels;
 
            public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options options, IParallelTraining parallelTraining)
                : base(trainset,
                    options.LearningRate,
                    options.Shrinkage,
                    options.MaximumTreeOutput,
                    options.GetDerivativesSampleRate,
                    options.BestStepRankingRegressionTrees,
                    options.Seed)
            {
                _labels = labels;
                TrainQueriesTopLabels = new short[Dataset.NumQueries][];
                for (int q = 0; q < Dataset.NumQueries; ++q)
                    TrainQueriesTopLabels[q] = new short[3];
 
                _labelCounts = new int[Dataset.NumQueries][];
                int relevancyLevel = DcgCalculator.LabelGainMap.Length;
                for (int q = 0; q < Dataset.NumQueries; ++q)
                    _labelCounts[q] = new int[relevancyLevel];
 
                // precomputed arrays
                _maxDcgTruncationLevel = options.NdcgTruncationLevel;
                _useDcg = options.UseDcg;
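                // When optimizing plain DCG the per-query normalizer is 1; for NDCG it is
                // 1 / maxDCG of the query, which makes delta-NDCG values comparable across queries.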
                if (_useDcg)
                {
                    _inverseMaxDcgt = new double[Dataset.NumQueries];
                    for (int q = 0; q < Dataset.NumQueries; ++q)
                        _inverseMaxDcgt[q] = 1.0;
                }
                else
                {
                    _inverseMaxDcgt = DcgCalculator.MaxDcg(_labels, Dataset.Boundaries, _maxDcgTruncationLevel, _labelCounts);
                    for (int q = 0; q < Dataset.NumQueries; ++q)
                        _inverseMaxDcgt[q] = 1.0 / _inverseMaxDcgt[q];
                }
 
                _discount = new double[Dataset.MaxDocsPerQuery];
                FillDiscounts(options.PositionDiscountFreeform);
 
                _oneTwoThree = new int[Dataset.MaxDocsPerQuery];
                for (int d = 0; d < Dataset.MaxDocsPerQuery; ++d)
                    _oneTwoThree[d] = d;
 
                // reusable resources
                int numThreads = BlockingThreadPool.NumThreads;
                _comparers = new DcgPermutationComparer[numThreads];
                for (int i = 0; i < numThreads; ++i)
                    _comparers[i] = DcgPermutationComparerFactory.GetDcgPermutationFactory(options.SortingAlgorithm);
 
                _permutationBuffers = new int[numThreads][];
                for (int i = 0; i < numThreads; ++i)
                    _permutationBuffers[i] = new int[Dataset.MaxDocsPerQuery];
 
                _gain = Dataset.DatasetSkeleton.LabelGainMap;
                FillGainLabels();
 
                #region parameters
                _sigmoidParam = options.LearningRate;
                _costFunctionParam = options.CostFunctionParam;
                _distanceWeight2 = options.DistanceWeight2;
                _normalizeQueryLambdas = options.NormalizeQueryLambdas;
 
                _useShiftedNdcg = options.ShiftedNdcg;
                _filterZeroLambdas = options.FilterZeroLambdas;
                #endregion
 
                _scoresCopy = new double[Dataset.NumDocs];
                _labelsCopy = new short[Dataset.NumDocs];
                _groupIdToTopLabel = new short[Dataset.NumDocs];
 
                FillSigmoidTable(_sigmoidParam);
#if OLD_DATALOAD
            SetupSecondaryGains(cmd);
#endif
                _parallelTraining = parallelTraining;
            }
 
#if OLD_DATALOAD
        private void SetupSecondaryGains(Arguments args)
        {
            _secondaryGains = null;
            _secondaryMetricShare = args.secondaryMetricShare;
            _secondaryIsolabelExclusive = args.secondaryIsolabelExclusive;
            if (_secondaryMetricShare != 0.0)
            {
                _secondaryGains = Dataset.Skeleton.GetData<double>("SecondaryGains");
                if (_secondaryGains == null)
                {
                    _secondaryMetricShare = 0.0;
                    return;
                }
                _secondaryInverseMaxDCGT = DCGCalculator.MaxDCG(_secondaryGains, Dataset.Boundaries,
                    new int[] { args.lambdaMartMaxTruncation })[0].Select(d => 1.0 / d).ToArray();
            }
        }
#endif
            private void FillSigmoidTable(double sigmoidParam)
            {
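                // Builds a lookup table for sigmoid(score) = 2 / (1 + exp(2 * sigmoidParam * score))
                // over [_minScore, _maxScore]. Consumers index it roughly as follows
                // (a sketch mirroring GetGradientInOneQuery):
                //
                //     lambda = _sigmoidTable[(int)((scoreDiff - _minScore) * _scoreToSigmoidTableFactor)];
                //
                // with score differences outside the range clamped to _minSigmoid/_maxSigmoid.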
                // minScore is such that 2*sigmoidParam*score is < expAsymptote if score < minScore
                _minScore = _expAsymptote / sigmoidParam / 2;
                _maxScore = -_minScore;
 
                _sigmoidTable = new double[_sigmoidBins];
                for (int i = 0; i < _sigmoidBins; i++)
                {
                    double score = (_maxScore - _minScore) / _sigmoidBins * i + _minScore;
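                    // Both branches evaluate the same function, 2 / (1 + exp(2 * sigmoidParam * score));
                    // the positive branch uses the algebraically equivalent form in exp(-2 * sigmoidParam * score)
                    // so the exponential never overflows for large |score|.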
                    if (score > 0.0)
                        _sigmoidTable[i] = 2.0 - 2.0 / (1.0 + Math.Exp(-2.0 * sigmoidParam * score));
                    else
                        _sigmoidTable[i] = 2.0 / (1.0 + Math.Exp(2.0 * sigmoidParam * score));
                }
                _scoreToSigmoidTableFactor = _sigmoidBins / (_maxScore - _minScore);
                _minSigmoid = _sigmoidTable[0];
                _maxSigmoid = _sigmoidTable.Last();
            }
 
            private void IgnoreNonBestDuplicates(short[] labels, double[] scores, int[] order, UInt32[] dupeIds, int begin, int numDocuments)
            {
                if (dupeIds == null || dupeIds.Length == 0)
                {
                    return;
                }
 
                // Reset top label for all groups
                for (int i = begin; i < begin + numDocuments; ++i)
                {
                    _groupIdToTopLabel[i] = -1;
                }
 
                for (int i = 0; i < numDocuments; ++i)
                {
                    Contracts.Check(0 <= order[i] && order[i] < numDocuments, "document index out of range");
 
                    int index = begin + order[i];
 
                    UInt32 group = dupeIds[index];
                    if (group == (UInt32)DupeIdInfo.Code404 || group == (UInt32)DupeIdInfo.FormatNotSupported ||
                        group == (UInt32)DupeIdInfo.Unique || group == (UInt32)DupeIdInfo.NoInformation)
                    {
                        continue;
                    }
 
                    // group starts from 2 (since 0 is unknown and 1 is unique)
                    Contracts.Check(2 <= group && group < numDocuments + 2, "dupeId group exceeds range");
 
                    UInt32 groupIndex = (UInt32)begin + group - 2;
 
                    if (_groupIdToTopLabel[groupIndex] != -1)
                    {
                        // This is the second or later occurrence of a result in the same
                        // duplicate group, so disregard it when applying the cost function.
                        //
                        // Only do this if the rating of this dupe is worse or equal;
                        // otherwise we want this dupe pushed to the top, so we keep it.
                        if (labels[index] <= _groupIdToTopLabel[groupIndex])
                        {
                            labels[index] = 0;
                            scores[index] = double.MinValue;
                        }
                    }
                    else
                    {
                        _groupIdToTopLabel[groupIndex] = labels[index];
                    }
                }
            }
 
            public override double[] GetGradient(IChannel ch, double[] scores)
            {
                _baselineAlphaCurrent = 0.0;
                double[] grads = base.GetGradient(ch, scores);
                _iteration++;
                return grads;
            }
 
            protected override void GetGradientInOneQuery(int query, int threadIndex)
            {
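                // Computes the LambdaRank pseudo-gradient for one query: for each document pair
                // (high, low) where "high" carries the larger label, a pairwise lambda of
                // sigmoid(scoreHigh - scoreLow) scaled by |deltaNDCG| (the NDCG change from
                // swapping the pair) is accumulated into Gradient, with the matching
                // second-order term accumulated into Weights.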
                int begin = Dataset.Boundaries[query];
                int numDocuments = Dataset.Boundaries[query + 1] - Dataset.Boundaries[query];
 
                Array.Clear(Gradient, begin, numDocuments);
                Array.Clear(Weights, begin, numDocuments);
 
                double inverseMaxDcg = _inverseMaxDcgt[query];
                double secondaryInverseMaxDcg = _secondaryMetricShare == 0 ? 0.0 : _secondaryInverseMaxDcgt[query];
 
                int[] permutation = _permutationBuffers[threadIndex];
 
                short[] labels = _labels;
                double[] scoresToUse = Scores;
 
                if (_useShiftedNdcg)
                {
                    // Copy the labels for this query
                    Array.Copy(_labels, begin, _labelsCopy, begin, numDocuments);
                    labels = _labelsCopy;
                }
 
                if (_costFunctionParam == 'c' || _useShiftedNdcg)
                {
                    // Copy the scores for this query
                    Array.Copy(Scores, begin, _scoresCopy, begin, numDocuments);
                    scoresToUse = _scoresCopy;
                }
 
                // Keep track of top 3 labels for later use
                //GetTopQueryLabels(query, permutation, false);
 
                double lambdaSum = 0;
 
                unsafe
                {
                    fixed (int* pPermutation = permutation)
                    fixed (short* pLabels = labels)
                    fixed (double* pScores = scoresToUse)
                    fixed (double* pLambdas = Gradient)
                    fixed (double* pWeights = Weights)
                    fixed (double* pDiscount = _discount)
                    fixed (double* pGain = _gain)
                    fixed (double* pGainLabels = _gainLabels)
                    fixed (double* pSigmoidTable = _sigmoidTable)
                    fixed (double* pSecondaryGains = _secondaryGains)
                    fixed (int* pOneTwoThree = _oneTwoThree)
                    {
                        // calculates the permutation that orders "scores" in descending order, without modifying "scores"
                        Array.Copy(_oneTwoThree, permutation, numDocuments);
 
                        if (IntArray.UseFastTreeNative)
                        {
                            PermutationSort(permutation, scoresToUse, labels, numDocuments, begin);
                            // Get how far above the baseline our current DCG is.
                            double baselineDcgGap = 0.0;
                            // Keep track of top 3 labels for later use
                            GetTopQueryLabels(query, permutation, true);
 
                            if (_useShiftedNdcg)
                            {
                                // Set non-best (rank-wise) duplicates to be ignored. Set Score to MinValue, Label to 0
                                IgnoreNonBestDuplicates(labels, scoresToUse, permutation, Dataset.DupeIds, begin, numDocuments);
                            }
 
                            int numActualResults = numDocuments;
 
                            // If the cost function is ContinuousWeightedRanknet, update the output scores.
                            if (_costFunctionParam == 'c')
                            {
                                for (int i = begin; i < begin + numDocuments; ++i)
                                {
                                    if (pScores[i] == double.MinValue)
                                    {
                                        numActualResults--;
                                    }
                                    else
                                    {
                                        pScores[i] = pScores[i] * (1.0 - pLabels[i] * 1.0 / (20.0 * Dataset.DatasetSkeleton.LabelGainMap.Length));
                                    }
                                }
                            }
 
                            // Continuous cost function and shifted NDCG require a re-sort and recomputation of maxDCG
                            // (Change of scores in the former and scores and labels in the latter)
                            if (!_useDcg && (_costFunctionParam == 'c' || _useShiftedNdcg))
                            {
                                PermutationSort(permutation, scoresToUse, labels, numDocuments, begin);
                                inverseMaxDcg = 1.0 / DcgCalculator.MaxDcgQuery(labels, begin, numDocuments, numDocuments, _labelCounts[query]);
                            }
                            // A constant related to secondary labels, which does not exist in the current codebase.
                            const bool secondaryIsolabelExclusive = false;
                            GetDerivatives(numDocuments, begin, pPermutation, pLabels,
                                    pScores, pLambdas, pWeights, pDiscount,
                                    inverseMaxDcg, pGainLabels,
                                    _secondaryMetricShare, secondaryIsolabelExclusive, secondaryInverseMaxDcg, pSecondaryGains,
                                    pSigmoidTable, _minScore, _maxScore, _sigmoidTable.Length, _scoreToSigmoidTableFactor,
                                    _costFunctionParam, _distanceWeight2, numActualResults, &lambdaSum, double.MinValue,
                                    _baselineAlphaCurrent, baselineDcgGap);
                        }
                        else
                        {
                            if (_useShiftedNdcg || _costFunctionParam == 'c' || _distanceWeight2 || _normalizeQueryLambdas)
                            {
                                throw new Exception("Shifted NDCG / ContinuousWeightedRanknet / distanceWeight2 / normalized lambdas are only supported by unmanaged code");
                            }
 
                            var comparer = _comparers[threadIndex];
                            comparer.Scores = scoresToUse;
                            comparer.Labels = labels;
                            comparer.ScoresOffset = begin;
                            comparer.LabelsOffset = begin;
                            Array.Sort(permutation, 0, numDocuments, comparer);
 
                            // go over all pairs
                            double scoreHighMinusLow;
                            double lambdaP;
                            double weightP;
                            double deltaNdcgP;
                            for (int i = 0; i < numDocuments; ++i)
                            {
                                int high = begin + pPermutation[i];
                                if (pLabels[high] == 0)
                                    continue;
                                double deltaLambdasHigh = 0;
                                double deltaWeightsHigh = 0;
 
                                for (int j = 0; j < numDocuments; ++j)
                                {
                                    // only consider pairs with different labels, where "high" has a higher label than "low"
                                    if (i == j)
                                        continue;
                                    int low = begin + pPermutation[j];
                                    if (pLabels[high] <= pLabels[low])
                                        continue;
 
                                    // calculate the lambdaP for this pair
                                    scoreHighMinusLow = pScores[high] - pScores[low];
 
                                    if (scoreHighMinusLow <= _minScore)
                                        lambdaP = _minSigmoid;
                                    else if (scoreHighMinusLow >= _maxScore)
                                        lambdaP = _maxSigmoid;
                                    else
                                        lambdaP = _sigmoidTable[(int)((scoreHighMinusLow - _minScore) * _scoreToSigmoidTableFactor)];
 
                                    weightP = lambdaP * (2.0 - lambdaP);
 
                                    // calculate the deltaNDCGP for this pair
                                    deltaNdcgP =
                                        (pGain[pLabels[high]] - pGain[pLabels[low]]) *
                                        Math.Abs((pDiscount[i] - pDiscount[j])) *
                                        inverseMaxDcg;
 
                                    // update lambdas and weights
                                    deltaLambdasHigh += lambdaP * deltaNdcgP;
                                    pLambdas[low] -= lambdaP * deltaNdcgP;
                                    deltaWeightsHigh += weightP * deltaNdcgP;
                                    pWeights[low] += weightP * deltaNdcgP;
                                }
                                pLambdas[high] += deltaLambdasHigh;
                                pWeights[high] += deltaWeightsHigh;
                            }
                        }
 
                        if (_normalizeQueryLambdas)
                        {
                            if (lambdaSum > 0)
                            {
                                double normFactor = (10 * Math.Log(1 + lambdaSum)) / lambdaSum;
 
                                for (int i = begin; i < begin + numDocuments; ++i)
                                {
                                    pLambdas[i] = pLambdas[i] * normFactor;
                                    pWeights[i] = pWeights[i] * normFactor;
                                }
                            }
                        }
                    }
                }
            }
 
            void IStepSearch.AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree, DocumentPartitioning partitioning,
                                            ScoreTracker trainingScores)
            {
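                // Unless best-step regression trees are in use, divide each leaf output by
                // (approximately) twice the mean per-document weight of that leaf; GlobalMean
                // lets distributed workers agree on those means. Every output is then clamped
                // to [-MaxTreeOutput, MaxTreeOutput].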
                const double epsilon = 1.4e-45;
                double[] means = null;
                if (!BestStepRankingRegressionTrees)
                    means = _parallelTraining.GlobalMean(Dataset, tree, partitioning, Weights, _filterZeroLambdas);
                for (int l = 0; l < tree.NumLeaves; ++l)
                {
                    double output = tree.LeafValue(l);
                    if (!BestStepRankingRegressionTrees)
                        output = (output + epsilon) / (2.0 * means[l] + epsilon);
 
                    if (output > MaxTreeOutput)
                        output = MaxTreeOutput;
                    else if (output < -MaxTreeOutput)
                        output = -MaxTreeOutput;
 
                    tree.SetLeafValue(l, output);
                }
            }
 
            private void FillDiscounts(string positionDiscountFreeform)
            {
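                // Default LambdaRank position discount: 1 / ln(d + 2) for the document at
                // zero-based rank d. A non-null position-discount freeform is not applied here.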
                if (positionDiscountFreeform == null)
                {
                    for (int d = 0; d < Dataset.MaxDocsPerQuery; ++d)
                        _discount[d] = 1.0 / Math.Log(2.0 + d);
                }
            }
 
            private void FillGainLabels()
            {
                _gainLabels = new double[Dataset.NumDocs];
                for (int i = 0; i < Dataset.NumDocs; i++)
                {
                    _gainLabels[i] = _gain[_labels[i]];
                }
            }
 
            // Keep track of top 3 labels for later use.
            private void GetTopQueryLabels(int query, int[] permutation, bool bAlreadySorted)
            {
                int numDocuments = Dataset.Boundaries[query + 1] - Dataset.Boundaries[query];
                int begin = Dataset.Boundaries[query];
 
                if (!bAlreadySorted)
                {
                    // calculates the permutation that orders "scores" in descending order, without modifying "scores"
                    Array.Copy(_oneTwoThree, permutation, numDocuments);
                    PermutationSort(permutation, Scores, _labels, numDocuments, begin);
                }
 
                for (int i = 0; i < 3 && i < numDocuments; ++i)
                    TrainQueriesTopLabels[query][i] = _labels[begin + permutation[i]];
            }
 
            private static void PermutationSort(int[] permutation, double[] scores, short[] labels, int numDocs, int shift)
            {
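                // Orders "permutation" so that it enumerates one query's documents by descending
                // score; ties are broken by ascending label and then by original index, keeping
                // the sort deterministic.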
                Contracts.AssertValue(permutation);
                Contracts.AssertValue(scores);
                Contracts.AssertValue(labels);
                Contracts.Assert(numDocs > 0);
                Contracts.Assert(shift >= 0);
                Contracts.Assert(scores.Length - numDocs >= shift);
                Contracts.Assert(labels.Length - numDocs >= shift);
 
                Array.Sort(permutation, 0, numDocs,
                    Comparer<int>.Create((x, y) =>
                    {
                        if (scores[shift + x] > scores[shift + y])
                            return -1;
                        if (scores[shift + x] < scores[shift + y])
                            return 1;
                        if (labels[shift + x] < labels[shift + y])
                            return -1;
                        if (labels[shift + x] > labels[shift + y])
                            return 1;
                        return x - y;
                    }));
            }
 
            [DllImport("FastTreeNative", EntryPoint = "C_GetDerivatives", CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
            private static extern unsafe void GetDerivatives(
                int numDocuments, int begin, int* pPermutation, short* pLabels,
                double* pScores, double* pLambdas, double* pWeights, double* pDiscount,
                double inverseMaxDcg, double* pGainLabels,
                double secondaryMetricShare, [MarshalAs(UnmanagedType.U1)] bool secondaryExclusive, double secondaryInverseMaxDcg, double* pSecondaryGains,
                double* lambdaTable, double minScore, double maxScore,
                int lambdaTableLength, double scoreToLambdaTableFactor,
                char costFunctionParam, [MarshalAs(UnmanagedType.U1)] bool distanceWeight2, int numActualDocuments,
                double* pLambdaSum, double doubleMinValue, double alphaRisk, double baselineVersusCurrentDcg);
 
        }
    }
 
    /// <summary>
    /// Model parameters for <see cref="FastTreeRankingTrainer"/>.
    /// </summary>
    public sealed class FastTreeRankingModelParameters : TreeEnsembleModelParametersBasedOnRegressionTree
    {
        internal const string LoaderSignature = "FastTreeRankerExec";
        internal const string RegistrationName = "FastTreeRankingPredictor";
 
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "FTREE RA",
                // verWrittenCur: 0x00010001, // Initial
                // verWrittenCur: 0x00010002, // _numFeatures serialized
                // verWrittenCur: 0x00010003, // Ini content out of predictor
                // verWrittenCur: 0x00010004, // Add _defaultValueForMissing
                verWrittenCur: 0x00010005, // Categorical splits.
                verReadableCur: 0x00010004,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(FastTreeRankingModelParameters).Assembly.FullName);
        }
 
        private protected override uint VerNumFeaturesSerialized => 0x00010002;
 
        private protected override uint VerDefaultValueSerialized => 0x00010004;
 
        private protected override uint VerCategoricalSplitSerialized => 0x00010005;
 
        internal FastTreeRankingModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
            : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
        {
        }
 
        private FastTreeRankingModelParameters(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, RegistrationName, ctx, GetVersionInfo())
        {
        }
 
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());
        }
 
        internal static FastTreeRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            return new FastTreeRankingModelParameters(env, ctx);
        }
 
        private protected override PredictionKind PredictionKind => PredictionKind.Ranking;
    }
 
    internal static partial class FastTree
    {
        [TlcModule.EntryPoint(Name = "Trainers.FastTreeRanker",
            Desc = FastTreeRankingTrainer.Summary,
            UserName = FastTreeRankingTrainer.UserNameValue,
            ShortName = FastTreeRankingTrainer.ShortName)]
        public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, FastTreeRankingTrainer.Options input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("TrainFastTree");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
 
            return TrainerEntryPointsUtils.Train<FastTreeRankingTrainer.Options, CommonOutputs.RankingOutput>(host, input,
                () => new FastTreeRankingTrainer(host, input),
                () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
                () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName),
                () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.RowGroupColumnName));
        }
    }
}