File: Training\Parallel\IParallelTraining.cs
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    using LeafSplitCandidates = LeastSquaresRegressionTreeLearner.LeafSplitCandidates;
    using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo;
 
#if USE_SINGLE_PRECISION
    using FloatType = System.Single;
#else
    using FloatType = System.Double;
#endif
 
    /// <summary>
    /// Signature of the parallel trainer.
    /// </summary>
    internal delegate void SignatureParallelTrainer();
 
    /// <summary>
    /// Delegate function, implemented in TLC and called by TLC++. It finds the best threshold
    /// from raw histogram data (countByBin, sumTargetsByBin, sumWeightsByBin, numDocsInLeaf, sumTargets, sumWeights).
    /// </summary>
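    /// <example>
    /// A caller-side sketch; the variable names simply mirror the delegate's parameters
    /// rather than any concrete FastTree call site:
    /// <code>
    /// FindBestThresholdFromRawArrayFun find = ...; // supplied by the tree learner
    /// find(leafSplitCandidates, feature, flock, subfeature,
    ///     countByBin, sumTargetsByBin, sumWeightsByBin,
    ///     numDocsInLeaf, sumTargets, sumWeights, varianceTargets,
    ///     out SplitInfo bestSplit);
    /// </code>
    /// </example>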
    internal delegate void FindBestThresholdFromRawArrayFun(LeafSplitCandidates leafSplitCandidates, int feature, int flock, int subfeature,
        int[] countByBin, FloatType[] sumTargetsByBin, FloatType[] sumWeightsByBin,
        int numDocsInLeaf, double sumTargets, double sumWeights, double varianceTargets, out SplitInfo bestSplit);
 
    /// <summary>
    /// Interface used for parallel training.
    /// It mainly consists of three parts:
    /// 1. Interaction with IO: <see cref="GetLocalBinConstructionFeatures" />, <see cref="SyncGlobalBoundary" />.
    ///    In Data parallel and Voting parallel, data is partitioned by rows.
    ///    To speed up the bin-finding process, different workers find bins for different features,
    ///    then a global sync-up is performed.
    ///    In Feature parallel, every machine holds all the data, so this step is unneeded.
    /// 2. Interaction with the TreeLearner: <see cref="InitIteration" />, <see cref="CacheHistogram" />, <see cref="IsNeedFindLocalBestSplit" />,
    ///        <see cref="IsSkipNonSplittableHistogram" />, <see cref="FindGlobalBestSplit" />, <see cref="GetGlobalDataCountInLeaf" />, <see cref="PerformGlobalSplit" />.
    ///    A full pass (see the example following this summary) is:
    ///        Use <see cref="InitIteration" /> to alter the locally active features.
    ///        Use <see cref="GetGlobalDataCountInLeaf" /> to check the smaller and the larger leaf.
    ///        Use <see cref="CacheHistogram" />, <see cref="IsNeedFindLocalBestSplit" /> and <see cref="IsSkipNonSplittableHistogram" /> to interact with the feature histograms.
    ///        Use <see cref="FindGlobalBestSplit" /> to sync up the global best split.
    ///        Use <see cref="PerformGlobalSplit" /> to record the global num_data in leaves.
    /// 3. Interaction with the Application: <see cref="GlobalMean" />.
    ///    The output of each leaf is calculated by a Newton step (-sum(first_order_gradients) / sum(second_order_gradients)).
    ///    If data is partitioned by row, these sums need a global sync-up,
    ///    so this method must be called to get the real output of the leaves.
    /// </summary>
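    /// <example>
    /// A per-tree sketch of part 2, assuming a hypothetical tree-learner loop; the loop
    /// shape and local variables (<c>parallel</c>, <c>stats</c>, <c>bestSplits</c>, ...)
    /// are illustrative, not the actual FastTree internals:
    /// <code>
    /// parallel.InitIteration(ref activeFeatures);              // mode may trim active features
    /// for (int split = 0; split &lt; maxLeaves - 1; split++)
    /// {
    ///     parallel.GetGlobalDataCountInLeaf(leafIdx, ref cnt); // global count picks smaller/larger leaf
    ///     // ... build local histograms, then cache them for global aggregation:
    ///     parallel.CacheHistogram(isSmallerLeaf, featureIdx, subfeature, stats, hasWeights);
    ///     if (parallel.IsNeedFindLocalBestSplit())
    ///     {
    ///         // find the local best split; skip non-splittable histograms only when allowed:
    ///         // if (!histogram.IsSplittable &amp;&amp; parallel.IsSkipNonSplittableHistogram()) continue;
    ///     }
    ///     parallel.FindGlobalBestSplit(smallerCandidates, largerCandidates, findFunction, bestSplits);
    ///     parallel.PerformGlobalSplit(leaf, lteChild, gtChild, bestSplits[0]);
    /// }
    /// parallel.FinalizeIteration();
    /// </code>
    /// </example>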
    internal interface IParallelTraining
    {
        /// <summary>
        /// Initialize the network connection.
        /// </summary>
        void InitEnvironment();
 
        /// <summary>
        /// Finalize the network connection.
        /// </summary>
        void FinalizeEnvironment();
 
        /// <summary>
        /// Called once when the tree learner is constructed.
        /// </summary>
        void InitTreeLearner(Dataset trainData, int maxNumLeaves, int maxCatSplitPoints, ref int minDocInLeaf);
 
        /// <summary>
        /// Called when the tree learner is freed.
        /// </summary>
        void FinalizeTreeLearner();
 
        /// <summary>
        /// Called every time before training a tree.
        /// In Feature parallel this alters activeFeatures, because each worker
        /// only needs to find thresholds for its own subset of the features.
        /// </summary>
        void InitIteration(ref bool[] activeFeatures);
 
        /// <summary>
        /// Called after one tree has been trained.
        /// </summary>
        void FinalizeIteration();
 
        /// <summary>
        /// Caches a histogram to be used later for global aggregation.
        /// Only used in Data parallel and Voting parallel.
        /// </summary>
        void CacheHistogram(bool isSmallerLeaf, int featureIdx, int subfeature, SufficientStatsBase sufficientStatsBase, bool hasWeights);
 
        /// <summary>
        /// Returns false only in Data parallel, which finds the best
        /// threshold after the global histograms have been merged.
        /// </summary>
        bool IsNeedFindLocalBestSplit();
 
        /// <summary>
        /// True if non-splittable histograms should be skipped.
        /// Returns false only in Voting parallel: workers there do not hold the
        /// global histograms, so the local non-splittable flag may be wrong
        /// and the histogram cannot be skipped.
        /// </summary>
        bool IsSkipNonSplittableHistogram();
 
        /// <summary>
        /// Finds the best split across machines and saves the result in bestSplits.
        /// </summary>
        void FindGlobalBestSplit(LeafSplitCandidates smallerChildSplitCandidates,
           LeafSplitCandidates largerChildSplitCandidates,
           FindBestThresholdFromRawArrayFun findFunction,
           SplitInfo[] bestSplits);
 
        /// <summary>
        /// Gets the global num_data for a specific leaf.
        /// </summary>
        void GetGlobalDataCountInLeaf(int leafIdx, ref int cnt);
 
        /// <summary>
        /// Used to record the global num_data on leaves.
        /// </summary>
        void PerformGlobalSplit(int leaf, int lteChild, int gtChild, SplitInfo splitInfo);
 
        /// <summary>
        /// Gets the global mean across machines when data is partitioned by row.
        /// Used for calculating the leaf output values.
        /// Returns an array containing the mean output of every leaf.
        /// </summary>
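        /// <example>
        /// A sketch of the row-partitioned case; <c>AllReduceSum</c> and the local sums are
        /// hypothetical stand-ins for the sync-up this method performs per leaf:
        /// <code>
        /// // per-leaf partial sums computed on this worker
        /// double g = localGradientSum, h = localHessianSum;
        /// (g, h) = AllReduceSum(g, h); // combine the sums across all workers (hypothetical)
        /// double leafOutput = -g / h;  // Newton step on the global sums
        /// </code>
        /// </example>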
        double[] GlobalMean(Dataset dataset, InternalRegressionTree tree, DocumentPartitioning partitioning, double[] weights, bool filterZeroLambdas);
 
        /// <summary>
        /// Gets the indices of the features whose bins should be found locally.
        /// After constructing the local boundaries, callers should use <see cref="SyncGlobalBoundary" />
        /// to get the boundaries for all features.
        /// </summary>
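        /// <example>
        /// A hypothetical two-phase flow on each worker; <c>FindBinBoundaries</c> stands in
        /// for whatever local bin construction the caller performs:
        /// <code>
        /// bool[] mine = parallel.GetLocalBinConstructionFeatures(numFeatures);
        /// for (int f = 0; f &lt; numFeatures; f++)
        ///     if (mine[f])
        ///         binUpperBounds[f] = FindBinBoundaries(f); // local work on assigned features only
        /// parallel.SyncGlobalBoundary(numFeatures, maxBin, binUpperBounds); // get all boundaries
        /// </code>
        /// </example>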
        bool[] GetLocalBinConstructionFeatures(int numFeatures);
 
        /// <summary>
        /// Syncs up the global feature buckets.
        /// Used in Data parallel and Voting parallel, where data is partitioned by row.
        /// To speed up the global bin-finding process, different workers construct
        /// bin boundaries for different features, then perform a global sync-up.
        /// </summary>
        void SyncGlobalBoundary(int numFeatures, int maxBin, double[][] binUpperBounds);
    }
 
    [TlcModule.ComponentKind("ParallelTraining")]
    internal interface ISupportParallelTraining : IComponentFactory<IParallelTraining>
    {
    }
}