File: Training\Applications\GradientWrappers.cs
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// Trivial weights wrapper: returns the gradient (targets) unchanged and sets no target weights.
    /// </summary>
    internal class TrivialGradientWrapper : IGradientAdjuster
    {
        public string TargetWeightsSetName => "";
 
        public TrivialGradientWrapper() { }
 
        public virtual double[] AdjustTargetAndSetWeights(double[] gradient, ObjectiveFunctionBase objFunction, out double[] targetWeights)
        {
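            // Pass the gradient through unchanged; a null weight vector tells the caller
            // that the targets are unweighted.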
            targetWeights = null;
            return gradient;
        }
    }
 
    /// <summary>
    /// Provides weights used when the best regression step option is on.
    /// </summary>
    /// <remarks>
    /// Second derivatives are used as per-document weights in a leaf when a Newton-Raphson step is made
    /// (taken into account when the regression tree is trained).
    /// </remarks>
    internal class BestStepRegressionGradientWrapper : IGradientAdjuster
    {
        public BestStepRegressionGradientWrapper() { }
 
        public virtual double[] AdjustTargetAndSetWeights(double[] gradient, ObjectiveFunctionBase objFunction, out double[] targetWeights)
        {
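            // Use the objective's second derivatives as the target weights so the
            // Newton-Raphson leaf-value step can take them into account; the gradient
            // itself is returned unchanged.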
            targetWeights = objFunction.Weights;
            return gradient;
        }
    }
 
    /// <summary>
    /// Wraps targets with query weights. The regression tree is built on weighted data, and the weights
    /// are used for the mean calculation in the Newton-Raphson step.
    /// </summary>
    internal class QueryWeightsGradientWrapper : IGradientAdjuster
    {
        public QueryWeightsGradientWrapper()
        {
        }
 
        public virtual double[] AdjustTargetAndSetWeights(double[] gradient, ObjectiveFunctionBase objFunction, out double[] targetWeights)
        {
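            // Scale each gradient entry by its sample (query) weight so the regression tree
            // is fit to weighted targets; the sample weights are also returned for the
            // weighted mean computation in the Newton-Raphson step.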
            double[] weightedTargets = new double[gradient.Length];
            double[] sampleWeights = objFunction.Dataset.SampleWeights;
            for (int i = 0; i < gradient.Length; ++i)
                weightedTargets[i] = gradient[i] * sampleWeights[i];
            targetWeights = sampleWeights;
            return weightedTargets;
        }
    }
 
    /// <summary>
    /// Wraps targets when both the query weights and best regression step options are active.
    /// </summary>
    internal class QueryWeightsBestResressionStepGradientWrapper : IGradientAdjuster
    {
        public QueryWeightsBestResressionStepGradientWrapper()
        {
        }
 
        public virtual double[] AdjustTargetAndSetWeights(double[] gradient, ObjectiveFunctionBase objFunction, out double[] targetWeights)
        {
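            // Combine both adjustments: scale the gradient by the query weight, and multiply
            // the objective's second derivatives by the same weight to form the target weights
            // used in the Newton-Raphson step.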
            double[] weightedTargets = new double[gradient.Length];
            double[] weights = new double[gradient.Length];
            double[] sampleWeights = objFunction.Dataset.SampleWeights;
            for (int i = 0; i < gradient.Length; ++i)
            {
                double queryWeight = sampleWeights[i];
                weightedTargets[i] = gradient[i] * queryWeight;
                weights[i] = objFunction.Weights[i] * queryWeight;
            }
            targetWeights = weights;
            return weightedTargets;
        }
    }
}