File: Utils\RegressionTreeBaseUtils.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
 
namespace Microsoft.ML.FastTree.Utils
{
    internal class RegressionTreeBaseUtils
    {
        /// <summary>
        /// Utility method used to represent a tree ensemble as an <see cref="IDataView"/>.
        /// Every row in the <see cref="IDataView"/> corresponds to a node in the tree ensemble. The columns are the fields for each node.
        /// The column TreeID specifies which tree the node belongs to. The <see cref="QuantileRegressionTree"/> gets
        /// special treatment since it has some additional fields (<see cref="QuantileRegressionTree.GetLeafSamplesAt(int)"/>
        /// and <see cref="QuantileRegressionTree.GetLeafSampleWeightsAt(int)"/>).
        /// </summary>
        public static IDataView RegressionTreeEnsembleAsIDataView(IHost host, double bias, IReadOnlyList<double> treeWeights, IReadOnlyList<RegressionTreeBase> trees)
        {
            var builder = new ArrayDataViewBuilder(host);
            var numberOfRows = trees.Select(tree => tree.NumberOfNodes).Sum() + trees.Select(tree => tree.NumberOfLeaves).Sum();
 
            var treeWeightsList = new List<double>();
            var treeId = new List<int>();
            var isLeaf = new List<ReadOnlyMemory<char>>();
            var leftChild = new List<int>();
            var rightChild = new List<int>();
            var numericalSplitFeatureIndexes = new List<int>();
            var numericalSplitThresholds = new List<float>();
            var categoricalSplitFlags = new List<bool>();
            var leafValues = new List<double>();
            var splitGains = new List<double>();
            var categoricalSplitFeatures = new List<VBuffer<int>>();
            var categoricalCategoricalSplitFeatureRange = new List<VBuffer<int>>();
 
            for (int i = 0; i < trees.Count; i++)
            {
                // TreeWeights column. The TreeWeight value will be repeated for all the notes in the same tree in the IDataView.
                treeWeightsList.AddRange(Enumerable.Repeat(treeWeights[i], trees[i].NumberOfNodes + trees[i].NumberOfLeaves));
 
                // Tree id indicates which tree the node belongs to.
                treeId.AddRange(Enumerable.Repeat(i, trees[i].NumberOfNodes + trees[i].NumberOfLeaves));
 
                // IsLeaf column indicates if node is a leaf node.
                isLeaf.AddRange(Enumerable.Repeat(new ReadOnlyMemory<char>("Tree node".ToCharArray()), trees[i].NumberOfNodes));
                isLeaf.AddRange(Enumerable.Repeat(new ReadOnlyMemory<char>("Leaf node".ToCharArray()), trees[i].NumberOfLeaves));
 
                // LeftChild column.
                leftChild.AddRange(trees[i].LeftChild.AsEnumerable());
                leftChild.AddRange(Enumerable.Repeat(0, trees[i].NumberOfLeaves));
 
                // RightChild column.
                rightChild.AddRange(trees[i].RightChild.AsEnumerable());
                rightChild.AddRange(Enumerable.Repeat(0, trees[i].NumberOfLeaves));
 
                // NumericalSplitFeatureIndexes column.
                numericalSplitFeatureIndexes.AddRange(trees[i].NumericalSplitFeatureIndexes.AsEnumerable());
                numericalSplitFeatureIndexes.AddRange(Enumerable.Repeat(0, trees[i].NumberOfLeaves));
 
                // NumericalSplitThresholds column.
                numericalSplitThresholds.AddRange(trees[i].NumericalSplitThresholds.AsEnumerable());
                numericalSplitThresholds.AddRange(Enumerable.Repeat(0f, trees[i].NumberOfLeaves));
 
                // CategoricalSplitFlags column.
                categoricalSplitFlags.AddRange(trees[i].CategoricalSplitFlags.AsEnumerable());
                categoricalSplitFlags.AddRange(Enumerable.Repeat(false, trees[i].NumberOfLeaves));
 
                // LeafValues column.
                leafValues.AddRange(Enumerable.Repeat(0d, trees[i].NumberOfNodes));
                leafValues.AddRange(trees[i].LeafValues.AsEnumerable());
 
                // SplitGains column.
                splitGains.AddRange(trees[i].SplitGains.AsEnumerable());
                splitGains.AddRange(Enumerable.Repeat(0d, trees[i].NumberOfLeaves));
 
                for (int j = 0; j < trees[i].NumberOfNodes; j++)
                {
                    // CategoricalSplitFeatures column.
                    var categoricalSplitFeaturesArray = trees[i].GetCategoricalSplitFeaturesAt(j).ToArray();
                    categoricalSplitFeatures.Add(new VBuffer<int>(categoricalSplitFeaturesArray.Length, categoricalSplitFeaturesArray));
                    var len = trees[i].GetCategoricalSplitFeaturesAt(j).ToArray().Length;
 
                    // CategoricalCategoricalSplitFeatureRange column.
                    var categoricalCategoricalSplitFeatureRangeArray = trees[i].GetCategoricalCategoricalSplitFeatureRangeAt(j).ToArray();
                    categoricalCategoricalSplitFeatureRange.Add(new VBuffer<int>(categoricalCategoricalSplitFeatureRangeArray.Length, categoricalCategoricalSplitFeatureRangeArray));
                    len = trees[i].GetCategoricalCategoricalSplitFeatureRangeAt(j).ToArray().Length;
                }
 
                categoricalSplitFeatures.AddRange(Enumerable.Repeat(new VBuffer<int>(), trees[i].NumberOfLeaves));
                categoricalCategoricalSplitFeatureRange.AddRange(Enumerable.Repeat(new VBuffer<int>(), trees[i].NumberOfLeaves));
            }
 
            // Bias column. This will be a repeated value for all rows in the resulting IDataView.
            builder.AddColumn("Bias", NumberDataViewType.Double, Enumerable.Repeat(bias, numberOfRows).ToArray());
            builder.AddColumn("TreeWeights", NumberDataViewType.Double, treeWeightsList.ToArray());
            builder.AddColumn("TreeID", NumberDataViewType.Int32, treeId.ToArray());
            builder.AddColumn("IsLeaf", TextDataViewType.Instance, isLeaf.ToArray());
            builder.AddColumn(nameof(RegressionTreeBase.LeftChild), NumberDataViewType.Int32, leftChild.ToArray());
            builder.AddColumn(nameof(RegressionTreeBase.RightChild), NumberDataViewType.Int32, rightChild.ToArray());
            builder.AddColumn(nameof(RegressionTreeBase.NumericalSplitFeatureIndexes), NumberDataViewType.Int32, numericalSplitFeatureIndexes.ToArray());
            builder.AddColumn(nameof(RegressionTreeBase.NumericalSplitThresholds), NumberDataViewType.Single, numericalSplitThresholds.ToArray());
            builder.AddColumn(nameof(RegressionTreeBase.CategoricalSplitFlags), BooleanDataViewType.Instance, categoricalSplitFlags.ToArray());
            builder.AddColumn(nameof(RegressionTreeBase.LeafValues), NumberDataViewType.Double, leafValues.ToArray());
            builder.AddColumn(nameof(RegressionTreeBase.SplitGains), NumberDataViewType.Double, splitGains.ToArray());
            builder.AddColumn("CategoricalSplitFeatures", NumberDataViewType.Int32, categoricalSplitFeatures.ToArray());
            builder.AddColumn("CategoricalCategoricalSplitFeatureRange", NumberDataViewType.Int32, categoricalCategoricalSplitFeatureRange.ToArray());
 
            // If the input tree array is a quantile regression tree we need to add two more columns.
            var quantileTrees = trees as IReadOnlyList<QuantileRegressionTree>;
            if (quantileTrees != null)
            {
                // LeafSamples column.
                var leafSamples = new List<VBuffer<double>>();
 
                // LeafSampleWeights column.
                var leafSampleWeights = new List<VBuffer<double>>();
                for (int i = 0; i < quantileTrees.Count; i++)
                {
                    leafSamples.AddRange(Enumerable.Repeat(new VBuffer<double>(), quantileTrees[i].NumberOfNodes));
                    leafSampleWeights.AddRange(Enumerable.Repeat(new VBuffer<double>(), quantileTrees[i].NumberOfNodes));
                    for (int j = 0; j < quantileTrees[i].NumberOfLeaves; j++)
                    {
                        var leafSamplesArray = quantileTrees[i].GetLeafSamplesAt(j).ToArray();
                        leafSamples.Add(new VBuffer<double>(leafSamplesArray.Length, leafSamplesArray));
                        var len = quantileTrees[i].GetLeafSamplesAt(j).ToArray().Length;
 
                        var leafSampleWeightsArray = quantileTrees[i].GetLeafSampleWeightsAt(j).ToArray();
                        leafSampleWeights.Add(new VBuffer<double>(leafSampleWeightsArray.Length, leafSampleWeightsArray));
                        len = quantileTrees[i].GetLeafSampleWeightsAt(j).ToArray().Length;
                    }
                }
 
                builder.AddColumn("LeafSamples", NumberDataViewType.Double, leafSamples.ToArray());
                builder.AddColumn("LeafSampleWeights", NumberDataViewType.Double, leafSampleWeights.ToArray());
            }
 
            var data = builder.GetDataView();
            return data;
        }
 
    }
}