// File: FastTreeClassification.cs
// Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
 
[assembly: LoadableClass(FastTreeBinaryTrainer.Summary, typeof(FastTreeBinaryTrainer), typeof(FastTreeBinaryTrainer.Options),
    new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
    FastTreeBinaryTrainer.UserNameValue,
    FastTreeBinaryTrainer.LoadNameValue,
    "FastTreeClassification",
    "FastTree",
    "ft",
    FastTreeBinaryTrainer.ShortName,
 
    // FastRank names
    "FastRankBinaryClassification",
    "FastRankBinaryClassificationWrapper",
    "FastRankClassification",
    "fr",
    "btc",
    "frc",
    "fastrank",
    "fastrankwrapper")]
 
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(FastTreeBinaryModelParameters), null, typeof(SignatureLoadModel),
    "FastTree Binary Executor",
    FastTreeBinaryModelParameters.LoaderSignature)]
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// Model parameters for <see cref="FastTreeBinaryTrainer"/>.
    /// Wraps a trained regression-tree ensemble whose summed raw score is used for binary classification.
    /// </summary>
    public sealed class FastTreeBinaryModelParameters :
        TreeEnsembleModelParametersBasedOnRegressionTree
    {
        // Signature the model loader uses to locate this type during deserialization.
        internal const string LoaderSignature = "FastTreeBinaryExec";
        // Name under which the predictor registers itself with the host environment.
        internal const string RegistrationName = "FastTreeBinaryPredictor";

        /// <summary>
        /// Describes the binary serialization format of this model. The commented-out
        /// entries record the historical format milestones; the corresponding version
        /// thresholds are exposed via the Ver* properties below.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "FTREE BC",
                // verWrittenCur: 0x00010001, // Initial
                // verWrittenCur: 0x00010002, // _numFeatures serialized
                // verWrittenCur: 0x00010003, // Ini content out of predictor
                // verWrittenCur: 0x00010004, // Add _defaultValueForMissing
                verWrittenCur: 0x00010005, // Categorical splits.
                verReadableCur: 0x00010005,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(FastTreeBinaryModelParameters).Assembly.FullName);
        }

        // First format version in which the feature count was serialized.
        private protected override uint VerNumFeaturesSerialized => 0x00010002;

        // First format version in which default values for missing features were serialized.
        private protected override uint VerDefaultValueSerialized => 0x00010004;

        // First format version in which categorical split information was serialized.
        private protected override uint VerCategoricalSplitSerialized => 0x00010005;

        internal FastTreeBinaryModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
            : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
        {
        }

        // Deserialization constructor; used only through Create below.
        private FastTreeBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, RegistrationName, ctx, GetVersionInfo())
        {
        }

        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());
        }

        /// <summary>
        /// Loads the model from <paramref name="ctx"/>. If a calibrator was saved alongside
        /// the predictor (under the "Calibrator" directory), the predictor is returned
        /// wrapped in a schema-bindable calibrated model; otherwise the bare predictor is returned.
        /// </summary>
        internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            var predictor = new FastTreeBinaryModelParameters(env, ctx);
            ICalibrator calibrator;
            ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
            if (calibrator == null)
                return predictor;
            return new SchemaBindableCalibratedModelParameters<FastTreeBinaryModelParameters, ICalibrator>(env, predictor, calibrator);
        }

        private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
    }
 
    /// <summary>
    /// The <see cref="IEstimator{TTransformer}"/> for training a decision tree binary classification model using FastTree.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    /// To create this trainer, use [FastTree](xref:Microsoft.ML.TreeExtensions.FastTree(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,System.String,System.String,System.String,System.Int32,System.Int32,System.Int32,System.Double))
    /// or [FastTree(Options)](xref:Microsoft.ML.TreeExtensions.FastTree(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer.Options)).
    ///
    /// [!include[io](~/../docs/samples/docs/api-reference/io-columns-binary-classification.md)]
    ///
    /// ### Trainer Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Machine learning task | Binary classification |
    /// | Is normalization required? | No |
    /// | Is caching required? | No |
    /// | Required NuGet in addition to Microsoft.ML | Microsoft.ML.FastTree |
    /// | Exportable to ONNX | Yes |
    ///
    /// [!include[algorithm](~/../docs/samples/docs/api-reference/algo-details-fasttree.md)]
    /// ]]>
    /// </format>
    /// </remarks>
    /// <seealso cref="TreeExtensions.FastTree(BinaryClassificationCatalog.BinaryClassificationTrainers, string, string, string, int, int, int, double)"/>
    /// <seealso cref="TreeExtensions.FastTree(BinaryClassificationCatalog.BinaryClassificationTrainers, FastTreeBinaryTrainer.Options)"/>
    /// <seealso cref="Options"/>
    public sealed partial class FastTreeBinaryTrainer :
        BoostingFastTreeTrainerBase<FastTreeBinaryTrainer.Options,
        BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>,
        CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>
    {
        /// <summary>
        /// The LoadName for the assembly containing the trainer.
        /// </summary>
        internal const string LoadNameValue = "FastTreeBinaryClassification";
        internal const string UserNameValue = "FastTree (Boosted Trees) Classification";
        internal const string Summary = "Uses a logit-boost boosted tree learner to perform binary classification.";
        internal const string ShortName = "ftc";

        // Boolean view of the training labels, materialized once in PrepareLabels and shared
        // by the objective function and the train/validation/test loss calculators.
        private bool[] _trainSetLabels;
        // Scale of the logistic loss. Both constructors set this to 2 * learning rate, and the
        // same value parameterizes the PlattCalibrator so output probabilities match the loss.
        private readonly double _sigmoidParameter;

        /// <summary>
        /// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/>
        /// </summary>
        /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="labelColumnName">The name of the label column.</param>
        /// <param name="featureColumnName">The name of the feature column.</param>
        /// <param name="exampleWeightColumnName">The name for the column containing the example weight.</param>
        /// <param name="learningRate">The learning rate.</param>
        /// <param name="minimumExampleCountPerLeaf">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
        /// <param name="numberOfLeaves">The max number of leaves in each regression tree.</param>
        /// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
        internal FastTreeBinaryTrainer(IHostEnvironment env,
            string labelColumnName = DefaultColumnNames.Label,
            string featureColumnName = DefaultColumnNames.Features,
            string exampleWeightColumnName = null,
            int numberOfLeaves = Defaults.NumberOfLeaves,
            int numberOfTrees = Defaults.NumberOfTrees,
            int minimumExampleCountPerLeaf = Defaults.MinimumExampleCountPerLeaf,
            double learningRate = Defaults.LearningRate)
            : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumnName), featureColumnName, exampleWeightColumnName, null, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf, learningRate)
        {
            // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
            _sigmoidParameter = 2.0 * FastTreeTrainerOptions.LearningRate;
        }

        /// <summary>
        /// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/> by using the <see cref="Options"/> class.
        /// </summary>
        /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="options">Algorithm advanced settings.</param>
        internal FastTreeBinaryTrainer(IHostEnvironment env, Options options)
            : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
        {
            // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
            _sigmoidParameter = 2.0 * FastTreeTrainerOptions.LearningRate;
        }

        private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

        /// <summary>
        /// Validates the training data, runs the boosting loop, and wraps the trained ensemble
        /// together with a fixed Platt calibrator whose slope matches the training loss scale.
        /// </summary>
        private protected override CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
            var trainData = context.TrainingSet;
            ValidData = context.ValidationSet;
            TestData = context.TestSet;

            using (var ch = Host.Start("Training"))
            {
                ch.CheckValue(trainData, nameof(trainData));
                trainData.CheckBinaryLabel();
                trainData.CheckFeatureFloatVector();
                trainData.CheckOptFloatWeight();
                FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount();
                ConvertData(trainData);
                TrainCore(ch);
            }

            // The FastTree binary classification boosting is naturally calibrated to
            // output probabilities when transformed using a scaled logistic function,
            // so transform the scores using that.

            var pred = new FastTreeBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions);
            // FastTree's binary classification boosting framework's natural probabilistic interpretation
            // is explained in "From RankNet to LambdaRank to LambdaMART: An Overview" by Chris Burges.
            // The correctness of this scaling depends upon the gradient calculation in
            // BinaryClassificationObjectiveFunction.GetGradientInOneQuery being consistent with the
            // description in section 6 of the paper.
            var cali = new PlattCalibrator(Host, -1 * _sigmoidParameter, 0);
            return new FeatureWeightsCalibratedModelParameters<FastTreeBinaryModelParameters, PlattCalibrator>(Host, pred, cali);
        }

        /// <summary>
        /// Builds the logistic-loss objective used by the boosting loop. Requires
        /// <see cref="PrepareLabels"/> to have populated <see cref="_trainSetLabels"/> first.
        /// </summary>
        private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
        {
            return new ObjectiveImpl(
                TrainSet,
                _trainSetLabels,
                FastTreeTrainerOptions.LearningRate,
                FastTreeTrainerOptions.Shrinkage,
                _sigmoidParameter,
                FastTreeTrainerOptions.UnbalancedSets,
                FastTreeTrainerOptions.MaximumTreeOutput,
                FastTreeTrainerOptions.GetDerivativesSampleRate,
                FastTreeTrainerOptions.BestStepRankingRegressionTrees,
                FastTreeTrainerOptions.Seed,
                ParallelTraining);
        }

        /// <summary>
        /// Builds the base optimization algorithm and, when line search is enabled, attaches a
        /// line-search step that adjusts tree outputs against the classification loss.
        /// </summary>
        private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
        {
            OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
            if (FastTreeTrainerOptions.UseLineSearch)
            {
                var lossCalculator = new BinaryClassificationTest(optimizationAlgorithm.TrainingScores, _trainSetLabels, _sigmoidParameter);
                // REVIEW: we should make the loss indices an enum in BinaryClassificationTest
                optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, FastTreeTrainerOptions.UnbalancedSets ? 3 /*Unbalanced  sets  loss*/ : 1 /*normal loss*/, FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps, FastTreeTrainerOptions.MinimumStepSize);
            }
            return optimizationAlgorithm;
        }

        /// <summary>
        /// Maps the dataset's numeric ratings to boolean classification labels (rating &gt;= 1 is positive).
        /// </summary>
        private IEnumerable<bool> GetClassificationLabelsFromRatings(Dataset set)
        {
            // REVIEW: Historically FastTree has this test as >= 1. TLC however
            // generally uses > 0. Consider changing FastTree to be consistent.
            return set.Ratings.Select(x => x >= 1);
        }

        private protected override void PrepareLabels(IChannel ch)
        {
            _trainSetLabels = GetClassificationLabelsFromRatings(TrainSet).ToArray(TrainSet.NumDocs);
            //Here we set regression labels to what is in bin file if the values were not overridden with floats
        }

        private protected override Test ConstructTestForTrainingData()
        {
            return new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter);
        }

        /// <summary>
        /// Registers loss tests for the training, validation, and test sets, and sets up the
        /// pruning test (simple or tolerant early stopping) when pruning is enabled.
        /// </summary>
        private protected override void InitializeTests()
        {
            //Always compute training L1/L2 errors
            TrainTest = new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter);
            Tests.Add(TrainTest);

            if (ValidSet != null)
            {
                ValidTest = new BinaryClassificationTest(ConstructScoreTracker(ValidSet),
                    GetClassificationLabelsFromRatings(ValidSet).ToArray(), _sigmoidParameter);
                Tests.Add(ValidTest);
            }

            //If external label is missing use Rating column for L1/L2 error
            //The values may not make much sense if regression value is not an actual label value
            if (TestSets != null)
            {
                for (int t = 0; t < TestSets.Length; ++t)
                {
                    bool[] labels = GetClassificationLabelsFromRatings(TestSets[t]).ToArray();
                    Tests.Add(new BinaryClassificationTest(ConstructScoreTracker(TestSets[t]), labels, _sigmoidParameter));
                }
            }

            // Pruning requires a validation set to monitor.
            if (FastTreeTrainerOptions.EnablePruning && ValidSet != null)
            {
                if (!FastTreeTrainerOptions.UseTolerantPruning)
                {
                    //use simple early stopping condition
                    PruningTest = new TestHistory(ValidTest, 0);
                }
                else
                {
                    //use tolerant stopping condition
                    PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
                }
            }
        }

        /// <summary>
        /// Wraps the calibrated model in a binary prediction transformer bound to the feature column.
        /// </summary>
        private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> MakeTransformer(
            CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator> model, DataViewSchema trainSchema)
            => new BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);

        /// <summary>
        /// Trains a <see cref="FastTreeBinaryTrainer"/> using both training and validation data, returns
        /// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
        /// </summary>
        public BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
            => TrainTransformer(trainData, validationData);

        /// <summary>
        /// Declares the output columns produced by this trainer: Score, Probability, and PredictedLabel.
        /// </summary>
        private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            return new[]
            {
                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation())),
                new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation(true))),
                new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()))
            };
        }

        /// <summary>
        /// Logistic-loss objective function for binary classification boosting, with optional
        /// reweighting for unbalanced class distributions.
        /// </summary>
        internal sealed class ObjectiveImpl : ObjectiveFunctionBase, IStepSearch
        {
            private readonly bool[] _labels;
            private readonly bool _unbalancedSets; //Should we use balanced or unbalanced loss function
            private readonly long _npos;
            private readonly long _nneg;
            private readonly IParallelTraining _parallelTraining;
            private readonly double _sigmoidParameter; // Parameter for scaling the loss

            public ObjectiveImpl(
                Dataset trainSet,
                bool[] trainSetLabels,
                double learningRate,
                double shrinkage,
                double sigmoidParameter,
                bool unbalancedSets,
                double maxTreeOutput,
                int getDerivativesSampleRate,
                bool bestStepRankingRegressionTrees,
                int rngSeed,
                IParallelTraining parallelTraining)
                : base(
                    trainSet,
                    learningRate,
                    shrinkage,
                    maxTreeOutput,
                    getDerivativesSampleRate,
                    bestStepRankingRegressionTrees,
                    rngSeed)
            {
                _sigmoidParameter = sigmoidParameter;
                _labels = trainSetLabels;
                _unbalancedSets = unbalancedSets;
                if (_unbalancedSets)
                {
                    // Per-class counts are needed to normalize the loss; training is
                    // impossible if either class is absent.
                    BinaryClassificationTest.ComputeExampleCounts(_labels, out _npos, out _nneg);
                    Contracts.Check(_nneg > 0 && _npos > 0, "Only one class in training set.");
                }
                _parallelTraining = parallelTraining;
            }

            /// <summary>
            /// Computes logistic-loss gradients (and second-order weights) for all documents in
            /// one query, writing into the shared Gradient and Weights buffers.
            /// </summary>
            protected override void GetGradientInOneQuery(int query, int threadIndex)
            {
                int begin = Dataset.Boundaries[query];
                int numDocuments = Dataset.Boundaries[query + 1] - Dataset.Boundaries[query];

                // With unbalanced-sets handling, each example's contribution is normalized
                // by the size of its class; otherwise both classes weigh equally.
                double recipNpos = 1.0;
                double recipNneg = 1.0;

                if (_unbalancedSets)
                {
                    recipNpos = 1.0 / _npos;
                    recipNneg = 1.0 / _nneg;
                }
                // See "From RankNet to LambdaRank to LambdaMART: An Overview" section 6 for a
                // description of these gradients.
                unsafe
                {
                    // Pin the arrays for raw pointer access inside the hot loop.
                    fixed (bool* pLabels = _labels)
                    fixed (double* pScores = Scores)
                    fixed (double* pLambdas = Gradient)
                    fixed (double* pWeights = Weights)
                    {
                        for (int i = begin; i < begin + numDocuments; ++i)
                        {
                            int label = pLabels[i] ? 1 : -1;
                            double recip = pLabels[i] ? recipNpos : recipNneg;
                            // Derivative of the sigmoid-scaled logistic loss w.r.t. the score.
                            double response = label * _sigmoidParameter / (1.0 + Math.Exp(label * _sigmoidParameter * pScores[i]));
                            double absResponse = Math.Abs(response);
                            pLambdas[i] = response * recip;
                            pWeights[i] = absResponse * (_sigmoidParameter - absResponse) * recip;
                        }
                    }
                }
            }

            /// <summary>
            /// Rescales each leaf's output by learning rate and shrinkage (optionally normalized
            /// by the leaf's global mean weight), then clamps to +/- MaxTreeOutput.
            /// </summary>
            public void AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree,
                DocumentPartitioning partitioning, ScoreTracker trainingScores)
            {
                // Smallest positive subnormal double; guards the division below against a zero mean.
                const double epsilon = 1.4e-45;
                double multiplier = LearningRate * Shrinkage;
                double[] means = null;
                if (!BestStepRankingRegressionTrees)
                    means = _parallelTraining.GlobalMean(Dataset, tree, partitioning, Weights, false);
                for (int l = 0; l < tree.NumLeaves; ++l)
                {
                    double output = tree.GetOutput(l);

                    if (BestStepRankingRegressionTrees)
                        output *= multiplier;
                    else
                        output = multiplier * (output + epsilon) / (means[l] + epsilon);

                    // Clamp to the configured maximum magnitude.
                    if (output > MaxTreeOutput)
                        output = MaxTreeOutput;
                    else if (output < -MaxTreeOutput)
                        output = -MaxTreeOutput;
                    tree.SetOutput(l, output);
                }
            }
        }
    }
 
    /// <summary>
    /// Entry-point wrapper exposing the FastTree binary classifier to the entry-point API.
    /// </summary>
    internal static partial class FastTree
    {
        [TlcModule.EntryPoint(Name = "Trainers.FastTreeBinaryClassifier",
            Desc = FastTreeBinaryTrainer.Summary,
            UserName = FastTreeBinaryTrainer.UserNameValue,
            ShortName = FastTreeBinaryTrainer.ShortName)]
        public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryTrainer.Options input)
        {
            // Validate the environment before registering a host that scopes all further checks.
            Contracts.CheckValue(env, nameof(env));
            var trainHost = env.Register("TrainFastTree");
            trainHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainHost, input);

            // Hand the training loop off to the shared helper. The column lookups are passed
            // as delegates so they are evaluated lazily against the training data's schema.
            return TrainerEntryPointsUtils.Train<FastTreeBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(
                trainHost,
                input,
                () => new FastTreeBinaryTrainer(trainHost, input),
                () => TrainerEntryPointsUtils.FindColumn(trainHost, input.TrainingData.Schema, input.LabelColumnName),
                () => TrainerEntryPointsUtils.FindColumn(trainHost, input.TrainingData.Schema, input.ExampleWeightColumnName),
                () => TrainerEntryPointsUtils.FindColumn(trainHost, input.TrainingData.Schema, input.RowGroupColumnName));
        }
    }
}