File: Training\RegressionTreeNodeDocuments.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
 
namespace Microsoft.ML.Trainers.FastTree
{
    // RegressionTreeNodeDocuments represents an association between a node in a regression
    // tree and documents belonging to that node.
    // In other word it is <RegressionTree,DocumentPartitioning, NodeIndex> tuple along with operations
    // defined for the tuple
    // It also hides the fact that node can be either leaf on an interior node behind the hood.
    // The caller can treat interior node (entire subtree) in a same fashion as a leaf node.
    internal class RegressionTreeNodeDocuments
    {
        internal readonly InternalRegressionTree Tree;
        internal readonly DocumentPartitioning Partitioning;
        public readonly int NodeIndex; //Index to a node or leaf within the tree
        private int _documentCount;
        public bool IsLeaf => NodeIndex < 0;
 
        internal RegressionTreeNodeDocuments(InternalRegressionTree tree, DocumentPartitioning partitioning, int nodeIndex)
        {
            Tree = tree;
            Partitioning = partitioning;
            NodeIndex = nodeIndex;
            _documentCount = -1;
        }
        public int GetDocumentCount()
        {
            if (_documentCount == -1)
                _documentCount = Tree.GetNodesLeaves(NodeIndex).Sum(leaf => Partitioning.NumDocsInLeaf(leaf));
            return _documentCount;
        }
 
        // Adds delta to a node output (if node is a subtree all leafs are modified)
        public void UpdateOutputsWithDelta(double delta)
        {
            if (IsLeaf)
                Tree.UpdateOutputWithDelta(~NodeIndex, delta);
            else
                foreach (var node in GetLeaves())
                {
                    node.UpdateOutputsWithDelta(delta);
                }
        }
 
        // Returns collection of leaves of a interior node (or itslef if it is a leaf)
        public IEnumerable<RegressionTreeNodeDocuments> GetLeaves()
        {
            if (IsLeaf)
                return Enumerable.Repeat(this, 1);
            else
                return Tree.GetNodesLeaves(NodeIndex).Select(l => new RegressionTreeNodeDocuments(Tree, Partitioning, ~l));
        }
 
        // Returns a Hashset of leaf indexes of this node
        public HashSet<int> GetDocuments()
        {
            HashSet<int> documentsInNode;
            if (NodeIndex < 0) //It is a leaf
            {
                documentsInNode = new HashSet<int>(Partitioning.DocumentsInLeaf(~NodeIndex));
            }
            else
            {
                documentsInNode = new HashSet<int>();
                foreach (var leaf in Tree.GetNodesLeaves(NodeIndex))
                {
                    var allDocsInLeaf = Partitioning.DocumentsInLeaf(leaf);
                    documentsInNode.UnionWith(allDocsInLeaf);
                }
 
            }
            return documentsInNode;
        }
    }
 
    //RecursiveRegressionTree captures a recursive representation of a tree
    //and inherits from RegressionTreeNodeDocuments (a non-recursive node with documents)
    //The class in most cases would be constructed with node index of 0 and would create
    //entire structure of a full tree accessible with LTENode and GTNode
    //
    //Curently only used for smoothing and defines operations defined in recursive fashion
    //GetWeightedOutput and SmoothLeafOutputs used for smoothing the trees
    internal class RecursiveRegressionTree : RegressionTreeNodeDocuments
    {
        //Left and right children on a regression tree
        public readonly RecursiveRegressionTree LteNode;
        public readonly RecursiveRegressionTree GtNode;
        private double _weightedOutput;
        private int _nodeCount;
 
        internal RecursiveRegressionTree(InternalRegressionTree t, DocumentPartitioning p, int n)
            : base(t, p, n)
        {
            _weightedOutput = double.NaN;
            _nodeCount = int.MaxValue;
            if (!IsLeaf)
            {
                LteNode = new RecursiveRegressionTree(Tree, Partitioning, Tree.GetLteChildForNode(NodeIndex));
                GtNode = new RecursiveRegressionTree(Tree, Partitioning, Tree.GetGtChildForNode(NodeIndex));
            }
        }
 
        // Smoothing of leafs outputs:
        // 0 - no smoothing
        // 1 - maximal smoothing
        public void SmoothLeafOutputs(double parentOutput, double smoothing)
        {
            double myOutput = (1 - smoothing) * GetWeightedOutput() + smoothing * parentOutput;
            if (IsLeaf)
            {
                Tree.SetOutput(~NodeIndex, myOutput);
            }
            else
            {
                LteNode.SmoothLeafOutputs(myOutput, smoothing);
                GtNode.SmoothLeafOutputs(myOutput, smoothing);
            }
        }
 
        // Implementation for cached computation of weighted output
        // (used by smoothing)
        public double GetWeightedOutput()
        {
            if (!double.IsNaN(_weightedOutput))
                return _weightedOutput;
            if (NodeIndex < 0)
                return Tree.GetOutput(~NodeIndex);
 
            int lteDocCount = LteNode.GetDocumentCount();
            int gtCount = GtNode.GetDocumentCount();
 
            _weightedOutput = (lteDocCount * LteNode.GetWeightedOutput() + gtCount * GtNode.GetWeightedOutput())
                 / (lteDocCount + gtCount);
            return _weightedOutput;
        }
 
        // Smoothing of leafs outputs:
        // 0 - no smoothing
        // 1 - maximal smoothing
        public void SmoothLeafOutputs(double parentOutput, double smoothing, int[] documentCount)
        {
            int nodeCount = int.MaxValue;
            double myOutput = (1 - smoothing) * GetWeightedOutput(documentCount, out nodeCount) + smoothing * parentOutput;
            if (IsLeaf)
            {
                Tree.SetOutput(~NodeIndex, myOutput);
            }
            else
            {
                LteNode.SmoothLeafOutputs(myOutput, smoothing);
                GtNode.SmoothLeafOutputs(myOutput, smoothing);
            }
        }
 
        // set document count outside instead of calculated from local partition
        public double GetWeightedOutput(int[] documentCount, out int nodeCount)
        {
            if (!double.IsNaN(_weightedOutput) && _nodeCount != int.MaxValue)
            {
                nodeCount = _nodeCount;
                return _weightedOutput;
            }
 
            if (NodeIndex < 0)
            {
                nodeCount = documentCount[~NodeIndex];
                return Tree.GetOutput(~NodeIndex);
            }
 
            int lteDocCount = int.MaxValue;
            double lteWeight = LteNode.GetWeightedOutput(documentCount, out lteDocCount);
            int gtDocCount = int.MaxValue;
            double gtWeight = GtNode.GetWeightedOutput(documentCount, out gtDocCount);
 
            _weightedOutput = (lteDocCount * lteWeight + gtDocCount * gtWeight) / (lteDocCount + gtDocCount);
            _nodeCount = lteDocCount + gtDocCount;
 
            nodeCount = _nodeCount;
            return _weightedOutput;
        }
 
    }
 
}