File: RegressionTree.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// A container base class for exposing <see cref="InternalRegressionTree"/>'s and
    /// <see cref="InternalQuantileRegressionTree"/>'s attributes to users.
    /// This class should not be mutable, so it contains a lot of read-only members.
    /// </summary>
    public abstract class RegressionTreeBase
    {
        /// <summary>
        /// Shared empty result returned by <see cref="GetCategoricalSplitFeaturesAt(int)"/> and
        /// <see cref="GetCategoricalCategoricalSplitFeatureRangeAt(int)"/> when the underlying tree stores nothing
        /// for the requested node. It is boxed once here so that those calls do not allocate.
        /// </summary>
        private static readonly IReadOnlyList<int> _emptyIndexes = ImmutableArray<int>.Empty;

        /// <summary>
        /// <see cref="RegressionTreeBase"/> is an immutable wrapper over <see cref="_tree"/> for exposing some of the
        /// tree's attributes to users.
        /// </summary>
        private readonly InternalRegressionTree _tree;

        /// <summary>
        /// See <see cref="LeftChild"/>.
        /// </summary>
        private readonly ImmutableArray<int> _lteChild;
        /// <summary>
        /// See <see cref="RightChild"/>.
        /// </summary>
        private readonly ImmutableArray<int> _gtChild;
        /// <summary>
        /// See <see cref="NumericalSplitFeatureIndexes"/>.
        /// </summary>
        private readonly ImmutableArray<int> _numericalSplitFeatureIndexes;
        /// <summary>
        /// See <see cref="NumericalSplitThresholds"/>.
        /// </summary>
        private readonly ImmutableArray<float> _numericalSplitThresholds;
        /// <summary>
        /// See <see cref="CategoricalSplitFlags"/>.
        /// </summary>
        private readonly ImmutableArray<bool> _categoricalSplitFlags;
        /// <summary>
        /// See <see cref="LeafValues"/>.
        /// </summary>
        private readonly ImmutableArray<double> _leafValues;
        /// <summary>
        /// See <see cref="SplitGains"/>.
        /// </summary>
        private readonly ImmutableArray<double> _splitGains;

        /// <summary>
        /// <see cref="LeftChild"/>[i] is the i-th node's child index used when
        /// (1) the numerical feature indexed by <see cref="NumericalSplitFeatureIndexes"/>[i] is less than or equal
        /// to the threshold <see cref="NumericalSplitThresholds"/>[i], or
        /// (2) the categorical features indexed by <see cref="GetCategoricalCategoricalSplitFeatureRangeAt(int)"/>'s
        /// returned value with nodeIndex=i is NOT a sub-set of <see cref="GetCategoricalSplitFeaturesAt(int)"/> with
        /// nodeIndex=i.
        /// Note that the case (1) happens only when <see cref="CategoricalSplitFlags"/>[i] is false and otherwise (2)
        /// occurs. A non-negative returned value means a node (i.e., not a leaf); for example, 2 means the 3rd node in
        /// the underlying <see cref="_tree"/>. A negative returned value means a leaf; for example, -1 stands for the
        /// <see langword="~"/>(-1)-th leaf in the underlying <see cref="_tree"/>. Note that <see langword="~"/> is the
        /// bitwise complement operator in C#; for details, see
        /// https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/operators/bitwise-complement-operator.
        /// </summary>
        public IReadOnlyList<int> LeftChild => _lteChild;

        /// <summary>
        /// <see cref="RightChild"/>[i] is the i-th node's child index used when the two conditions, (1) and (2),
        /// described in <see cref="LeftChild"/>'s document are not true. Its return value follows the format
        /// used in <see cref="LeftChild"/>.
        /// </summary>
        public IReadOnlyList<int> RightChild => _gtChild;

        /// <summary>
        /// <see cref="NumericalSplitFeatureIndexes"/>[i] is the feature index used by the splitting function of the
        /// i-th node. This value is valid only if <see cref="CategoricalSplitFlags"/>[i] is false.
        /// </summary>
        public IReadOnlyList<int> NumericalSplitFeatureIndexes => _numericalSplitFeatureIndexes;

        /// <summary>
        /// <see cref="NumericalSplitThresholds"/>[i] is the threshold on feature indexed by
        /// <see cref="NumericalSplitFeatureIndexes"/>[i], where i is the i-th node's index
        /// (for example, i is 1 for the 2nd node in <see cref="_tree"/>).
        /// </summary>
        public IReadOnlyList<float> NumericalSplitThresholds => _numericalSplitThresholds;

        /// <summary>
        /// Determine the types of splitting function. If <see cref="CategoricalSplitFlags"/>[i] is true, the i-th
        /// node uses a categorical splitting function. Otherwise, a traditional numerical split is used.
        /// </summary>
        public IReadOnlyList<bool> CategoricalSplitFlags => _categoricalSplitFlags;

        /// <summary>
        /// <see cref="LeafValues"/>[i] is the learned value at the i-th leaf.
        /// </summary>
        public IReadOnlyList<double> LeafValues => _leafValues;

        /// <summary>
        /// Return categorical thresholds used at node indexed by nodeIndex. If the considered input feature does NOT
        /// match any of the values returned by <see cref="GetCategoricalSplitFeaturesAt(int)"/>, we call it a
        /// less-than-threshold event and therefore <see cref="LeftChild"/>[nodeIndex] is the child node that input
        /// should go next. The returned value is valid only if <see cref="CategoricalSplitFlags"/>[nodeIndex] is true.
        /// </summary>
        /// <param name="nodeIndex">Index of the node; must be in [0, <see cref="NumberOfNodes"/>).</param>
        /// <returns>
        /// The entry stored by the underlying tree for the node (returned as-is, without copying), or an empty list
        /// when the tree stores no categorical split features for that node.
        /// </returns>
        public IReadOnlyList<int> GetCategoricalSplitFeaturesAt(int nodeIndex)
        {
            CheckNodeIndex(nodeIndex);

            if (_tree.CategoricalSplitFeatures == null || _tree.CategoricalSplitFeatures[nodeIndex] == null)
                return _emptyIndexes; // Zero-length vector.

            return _tree.CategoricalSplitFeatures[nodeIndex];
        }

        /// <summary>
        /// Return categorical thresholds' range used at node indexed by nodeIndex. A categorical split at node indexed
        /// by nodeIndex can consider multiple consecutive input features at one time; their range is specified by
        /// <see cref="GetCategoricalCategoricalSplitFeatureRangeAt(int)"/>. For a categorical split the returned value
        /// is expected to be a 2-element array; its 1st element is the starting index and its 2nd element is the
        /// ending index of a feature segment.
        /// The returned value is valid only if <see cref="CategoricalSplitFlags"/>[nodeIndex] is true.
        /// </summary>
        /// <param name="nodeIndex">Index of the node; must be in [0, <see cref="NumberOfNodes"/>).</param>
        /// <returns>
        /// The entry stored by the underlying tree for the node (returned as-is, without copying), or an empty list
        /// when the tree stores no categorical feature range for that node.
        /// </returns>
        public IReadOnlyList<int> GetCategoricalCategoricalSplitFeatureRangeAt(int nodeIndex)
        {
            CheckNodeIndex(nodeIndex);

            if (_tree.CategoricalSplitFeatureRanges == null || _tree.CategoricalSplitFeatureRanges[nodeIndex] == null)
                return _emptyIndexes; // Zero-length vector.

            return _tree.CategoricalSplitFeatureRanges[nodeIndex];
        }

        /// <summary>
        /// The gains obtained by splitting data at nodes. Its i-th value is computed from the split at the i-th node.
        /// </summary>
        public IReadOnlyList<double> SplitGains => _splitGains;

        /// <summary>
        /// Number of leaves in the tree. Note that <see cref="NumberOfLeaves"/> does not take non-leaf nodes into account.
        /// </summary>
        public int NumberOfLeaves => _tree.NumLeaves;

        /// <summary>
        /// Number of nodes in the tree. This doesn't include any leaves. For example, a tree with node0->node1,
        /// node0->leaf3, node1->leaf1, node1->leaf2, <see cref="NumberOfNodes"/> and <see cref="NumberOfLeaves"/> should
        /// be 2 and 3, respectively.
        /// </summary>
        // A visualization of the example mentioned in this doc string.
        //         node0
        //         /  \
        //     node1 leaf3
        //     /  \
        // leaf1 leaf2
        // The index of leaf starts with 1 because internally we use "-1" as the 1st leaf's index, "-2" for the 2nd leaf's index, and so on.
        public int NumberOfNodes => _tree.NumNodes;

        /// <summary>
        /// Wrap <paramref name="tree"/> and copy its per-node and per-leaf arrays into immutable arrays.
        /// Only the first <c>NumNodes</c> entries of the per-node arrays and the first <c>NumLeaves</c> entries of
        /// the leaf values are copied.
        /// </summary>
        /// <param name="tree">The internal tree whose attributes are exposed by this wrapper.</param>
        internal RegressionTreeBase(InternalRegressionTree tree)
        {
            _tree = tree;

            _lteChild = ImmutableArray.Create(_tree.LteChild, 0, _tree.NumNodes);
            _gtChild = ImmutableArray.Create(_tree.GtChild, 0, _tree.NumNodes);

            _numericalSplitFeatureIndexes = ImmutableArray.Create(_tree.SplitFeatures, 0, _tree.NumNodes);
            _numericalSplitThresholds = ImmutableArray.Create(_tree.RawThresholds, 0, _tree.NumNodes);
            _categoricalSplitFlags = ImmutableArray.Create(_tree.CategoricalSplit, 0, _tree.NumNodes);
            _leafValues = ImmutableArray.Create(_tree.LeafValues, 0, _tree.NumLeaves);
            _splitGains = ImmutableArray.Create(_tree.SplitGains, 0, _tree.NumNodes);
        }

        /// <summary>
        /// Throw if <paramref name="nodeIndex"/> is outside [0, <see cref="NumberOfNodes"/>). Shared by the
        /// per-node accessors so that they validate and report invalid indexes identically.
        /// </summary>
        private void CheckNodeIndex(int nodeIndex)
        {
            if (nodeIndex < 0 || nodeIndex >= NumberOfNodes)
                throw Contracts.Except($"The input node index, {nodeIndex}, is invalid. Its valid range is from 0 (inclusive) to {NumberOfNodes} (exclusive).");
        }
    }
 
    /// <summary>
    /// The user-facing, read-only view of an <see cref="InternalRegressionTree"/>.
    /// This type adds no members of its own: everything it exposes is inherited from
    /// <see cref="RegressionTreeBase"/>. It exists as a separate sealed type alongside
    /// <see cref="QuantileRegressionTree"/>, which extends the same base with extra attributes.
    /// </summary>
    public sealed class RegressionTree : RegressionTreeBase
    {
        /// <summary>
        /// Wrap <paramref name="tree"/>; all initialization is done by the base class constructor.
        /// </summary>
        /// <param name="tree">The internal tree to expose.</param>
        internal RegressionTree(InternalRegressionTree tree)
            : base(tree)
        {
        }
    }
 
    /// <summary>
    /// The user-facing, read-only view of an <see cref="InternalQuantileRegressionTree"/>.
    /// On top of what <see cref="RegressionTreeBase"/> provides, it offers
    /// <see cref="GetLeafSamplesAt(int)"/> and <see cref="GetLeafSampleWeightsAt(int)"/>, which expose the
    /// (sub-sampled) training labels that fell into a given leaf together with their weights.
    /// </summary>
    public sealed class QuantileRegressionTree : RegressionTreeBase
    {
        /// <summary>
        /// Labels sampled from training data, grouped by leaf: <see cref="_leafSamples"/>[i] holds the labels
        /// that fell into the i-th leaf.
        /// </summary>
        private readonly double[][] _leafSamples;

        /// <summary>
        /// Weights of the sampled labels, grouped by leaf: <see cref="_leafSampleWeights"/>[i][j] is the weight
        /// of <see cref="_leafSamples"/>[i][j].
        /// </summary>
        private readonly double[][] _leafSampleWeights;

        /// <summary>
        /// Wrap <paramref name="tree"/> and fetch its per-leaf samples and weights via
        /// <c>ExtractLeafSamplesAndTheirWeights</c>.
        /// </summary>
        /// <param name="tree">The internal quantile regression tree to expose.</param>
        internal QuantileRegressionTree(InternalQuantileRegressionTree tree)
            : base(tree)
        {
            tree.ExtractLeafSamplesAndTheirWeights(out _leafSamples, out _leafSampleWeights);
        }

        /// <summary>
        /// Get the training labels that fell into the given leaf.
        /// </summary>
        /// <param name="leafIndex">Index of the leaf; must be in [0, NumberOfLeaves).</param>
        /// <returns>Training labels</returns>
        public IReadOnlyList<double> GetLeafSamplesAt(int leafIndex)
        {
            EnsureLeafIndexInRange(leafIndex);

            // The per-leaf arrays were populated in the constructor.
            return _leafSamples[leafIndex];
        }

        /// <summary>
        /// Get the weights of the training labels that fell into the given leaf. For the same
        /// <paramref name="leafIndex"/>, the i-th value returned here is the weight of the i-th label returned by
        /// <see cref="GetLeafSamplesAt"/>.
        /// </summary>
        /// <param name="leafIndex">Index of the leaf; must be in [0, NumberOfLeaves).</param>
        /// <returns>Training labels' weights</returns>
        public IReadOnlyList<double> GetLeafSampleWeightsAt(int leafIndex)
        {
            EnsureLeafIndexInRange(leafIndex);

            // The per-leaf arrays were populated in the constructor.
            return _leafSampleWeights[leafIndex];
        }

        /// <summary>
        /// Throw unless <paramref name="leafIndex"/> lies in [0, NumberOfLeaves).
        /// </summary>
        private void EnsureLeafIndexInRange(int leafIndex)
        {
            bool isValid = leafIndex >= 0 && leafIndex < NumberOfLeaves;
            if (!isValid)
                throw Contracts.Except($"The input leaf index, {leafIndex}, is invalid. Its valid range is from 0 (inclusive) to {NumberOfLeaves} (exclusive).");
        }
    }
}