// File: Training\Parallel\SingleTrainer.cs (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
 
[assembly: LoadableClass(typeof(SingleTrainer),
    null, typeof(SignatureParallelTrainer), "single")]
 
[assembly: EntryPointModule(typeof(SingleTrainerFactory))]
 
namespace Microsoft.ML.Trainers.FastTree
{
    using LeafSplitCandidates = LeastSquaresRegressionTreeLearner.LeafSplitCandidates;
    using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo;
 
    /// <summary>
    /// Single-node implementation of <see cref="IParallelTraining"/>: with no remote workers to
    /// coordinate, the distributed hooks are no-ops and every "global" quantity is just the local one.
    /// </summary>
    internal sealed class SingleTrainer : IParallelTraining
    {
        void IParallelTraining.CacheHistogram(bool isSmallerLeaf, int featureIdx, int subfeature, SufficientStatsBase sufficientStatsBase, bool hasWeights)
        {
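            // No-op: in single-node training there are no peer workers to exchange cached histograms with.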
        }
 
        bool IParallelTraining.IsNeedFindLocalBestSplit()
        {
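            // The tree learner on this (only) node always finds its own best split.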
            return true;
        }
 
        void IParallelTraining.FindGlobalBestSplit(LeafSplitCandidates smallerChildSplitCandidates,
            LeafSplitCandidates largerChildSplitCandidates,
            FindBestThresholdFromRawArrayFun findFunction,
            SplitInfo[] bestSplits)
        {
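            // No-op: the locally found best splits are already the global best splits.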
        }
 
        void IParallelTraining.GetGlobalDataCountInLeaf(int leafIdx, ref int cnt)
        {
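            // No-op: the locally computed count passed in by ref is already the global count.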
        }
 
        bool[] IParallelTraining.GetLocalBinConstructionFeatures(int numFeatures)
        {
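            // Every feature participates in local bin construction, so return an all-true array.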
            return Utils.CreateArray<bool>(numFeatures, true);
        }
 
        double[] IParallelTraining.GlobalMean(Dataset dataset, InternalRegressionTree tree, DocumentPartitioning partitioning, double[] weights, bool filterZeroLambdas)
        {
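            // On a single node the per-leaf means computed from the local partitioning are the global means.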
            double[] means = new double[tree.NumLeaves];
            for (int l = 0; l < tree.NumLeaves; ++l)
            {
                means[l] = partitioning.Mean(weights, dataset.SampleWeights, l, filterZeroLambdas);
            }
            return means;
        }
 
        void IParallelTraining.PerformGlobalSplit(int leaf, int lteChild, int gtChild, SplitInfo splitInfo)
        {
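            // No-op: there are no peer workers to notify about the split.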
        }
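
        // The initialization, synchronization, and finalization hooks below are all no-ops:
        // single-node training has no distributed environment to set up, synchronize, or tear down.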
 
        void IParallelTraining.InitIteration(ref bool[] activeFeatures)
        {
        }
 
        void IParallelTraining.InitEnvironment()
        {
        }
 
        void IParallelTraining.InitTreeLearner(Dataset trainData, int maxNumLeaves, int maxCatSplitPoints, ref int minDocInLeaf)
        {
        }
 
        void IParallelTraining.SyncGlobalBoundary(int numFeatures, int maxBin, double[][] binUpperBounds)
        {
        }
 
        void IParallelTraining.FinalizeEnvironment()
        {
        }
 
        void IParallelTraining.FinalizeTreeLearner()
        {
        }
 
        void IParallelTraining.FinalizeIteration()
        {
        }
 
        bool IParallelTraining.IsSkipNonSplittableHistogram()
        {
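            // Histograms for non-splittable leaves can safely be skipped, since no other worker could need them.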
            return true;
        }
    }
 
    /// <summary>Entry-point factory for the single-node <see cref="IParallelTraining"/> component.</summary>
    [TlcModule.Component(Name = "Single", Desc = "Single node machine learning process.")]
    internal sealed class SingleTrainerFactory : ISupportParallelTraining
    {
        public IParallelTraining CreateComponent(IHostEnvironment env) => new SingleTrainer();
    }
}