File: Training\Applications\ObjectiveFunction.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    internal abstract class ObjectiveFunctionBase
    {
        // Buffers holding the per-document gradients and model scores; per-document
        // weights are exposed through the Weights property below.
        protected double[] Gradient;
        protected double[] Scores;
 
        // parameters
        protected double LearningRate;
        protected double Shrinkage;
        protected int GradSamplingRate;
        protected bool BestStepRankingRegressionTrees;
 
        protected double MaxTreeOutput;
        // random number generator
        private readonly Random _rnd;
 
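        // Number of queries handled by each parallel work item in GetGradient.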
        protected const int QueryThreadChunkSize = 100;
 
        internal readonly Dataset Dataset;
 
        public double[] Weights { get; protected set; }
 
        public ObjectiveFunctionBase(
            Dataset dataset,
            double learningRate,
            double shrinkage,
            double maxTreeOutput,
            int gradSamplingRate,
            bool useBestStepRankingRegressionTree,
            int randomNumberGeneratorSeed)
        {
            Dataset = dataset;
            LearningRate = learningRate;
            Shrinkage = shrinkage;
            MaxTreeOutput = maxTreeOutput;
            GradSamplingRate = gradSamplingRate;
            BestStepRankingRegressionTrees = useBestStepRankingRegressionTree;
            _rnd = new Random(randomNumberGeneratorSeed);
            Gradient = new double[Dataset.NumDocs];
            Weights = new double[Dataset.NumDocs];
        }
 
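        // Computes the per-document gradients (and weights) for the given scores.
        // Queries are processed in parallel chunks of QueryThreadChunkSize; when
        // GradSamplingRate > 1, only a randomly chosen subset of queries is updated.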
        public virtual double[] GetGradient(IChannel ch, double[] scores)
        {
            Scores = scores;
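            // Choose which residue class of queries to update this iteration; with
            // GradSamplingRate == 1, sampleIndex is 0 and every query is processed.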
            int sampleIndex = _rnd.Next(GradSamplingRate);
            using (Timer.Time(TimerEvent.ObjectiveFunctionGetDerivatives))
            {
                // REVIEW: This partitioning doesn't look optimal.
                // It probably makes sense to investigate better ways of splitting the data.
                var actions = new Action[(int)Math.Ceiling((double)Dataset.NumQueries / QueryThreadChunkSize)];
                var actionIndex = 0;
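                // The queue holds the pool of free thread indices; each action leases
                // one so derived classes can safely use per-thread scratch buffers.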
                var queue = new ConcurrentQueue<int>(Enumerable.Range(0, BlockingThreadPool.NumThreads));
                // fill the vectors with their correct values, query-by-query
                for (int q = 0; q < Dataset.NumQueries; q += QueryThreadChunkSize)
                {
                    int start = q;
                    actions[actionIndex++] = () =>
                      {
                          Contracts.Check(queue.TryDequeue(out int threadIndex));
                          GetGradientChunk(start, start + Math.Min(QueryThreadChunkSize, Dataset.NumQueries - start), GradSamplingRate, sampleIndex, threadIndex);
                          queue.Enqueue(threadIndex);
                      };
                }
 
                Parallel.Invoke(new ParallelOptions() { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions);
            }
            return Gradient;
        }
 
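        // Computes gradients for the queries in [startQuery, endQuery) whose index
        // falls into the sampled residue class (i % sampleRate == sampleIndex).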
        protected void GetGradientChunk(int startQuery, int endQuery, int sampleRate, int sampleIndex, int threadIndex)
        {
            for (int i = startQuery; i < endQuery; i++)
            {
                if (i % sampleRate == sampleIndex)
                {
                    GetGradientInOneQuery(i, threadIndex);
                }
            }
        }
 
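        // Derived classes fill Gradient (and, where applicable, Weights) for every
        // document of the given query; threadIndex selects per-thread state.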
        protected abstract void GetGradientInOneQuery(int query, int threadIndex);
    }
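
    // Illustrative only (not part of the original file): a minimal sketch of how a
    // concrete objective might derive from ObjectiveFunctionBase, assuming a
    // squared-error loss and a caller-supplied label array. Dataset.Boundaries
    // (mapping query -> document range) is assumed here; the real FastTree
    // subclasses are more involved.
    internal sealed class ExampleSquaredLossObjective : ObjectiveFunctionBase
    {
        private readonly double[] _labels;

        public ExampleSquaredLossObjective(Dataset dataset, double[] labels, double learningRate, int seed)
            : base(dataset, learningRate, shrinkage: 1.0, maxTreeOutput: double.PositiveInfinity,
                  gradSamplingRate: 1, useBestStepRankingRegressionTree: false, randomNumberGeneratorSeed: seed)
        {
            _labels = labels;
        }

        protected override void GetGradientInOneQuery(int query, int threadIndex)
        {
            // For squared error, the negative gradient is simply (label - score).
            for (int d = Dataset.Boundaries[query]; d < Dataset.Boundaries[query + 1]; d++)
                Gradient[d] = _labels[d] - Scores[d];
        }
    }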
}