File: TreeEnsemble.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// An ensemble of trees whose type derives from <see cref="RegressionTreeBase"/>. To compute the output value of a
    /// <see cref="TreeEnsemble{T}"/>, compute the output value of every tree in <see cref="Trees"/>, multiply the i-th
    /// output by the i-th entry of <see cref="TreeWeights"/>, and add <see cref="Bias"/> to the sum of those weighted values.
    /// </summary>
    /// <typeparam name="T">The concrete tree type stored in this ensemble.</typeparam>
    public abstract class TreeEnsemble<T> where T : RegressionTreeBase
    {
        /// <summary>
        /// When doing prediction, this is a value added to the weighted sum of all <see cref="Trees"/>' outputs.
        /// </summary>
        public double Bias { get; }

        /// <summary>
        /// <see cref="TreeWeights"/>[i] is the weight applied to the output of <see cref="Trees"/>[i].
        /// It always contains exactly as many elements as <see cref="Trees"/>.
        /// </summary>
        public IReadOnlyList<double> TreeWeights { get; }

        /// <summary>
        /// The trees that make up this ensemble. <see cref="Trees"/>[i] is scaled by <see cref="TreeWeights"/>[i].
        /// </summary>
        public IReadOnlyList<T> Trees { get; }

        /// <summary>
        /// Snapshots <paramref name="trees"/> and <paramref name="treeWeights"/> into read-only lists.
        /// </summary>
        /// <param name="trees">The trees of the ensemble. The sequence is enumerated once.</param>
        /// <param name="treeWeights">
        /// One weight per tree, in the same order as <paramref name="trees"/>. The sequence is enumerated once.
        /// </param>
        /// <param name="bias">The value added to the weighted sum of the tree outputs.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="trees"/> or <paramref name="treeWeights"/> is <see langword="null"/>.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="treeWeights"/> does not contain exactly one weight per tree.
        /// </exception>
        private protected TreeEnsemble(IEnumerable<T> trees, IEnumerable<double> treeWeights, double bias)
        {
            if (trees is null)
            {
                throw new ArgumentNullException(nameof(trees));
            }
            if (treeWeights is null)
            {
                throw new ArgumentNullException(nameof(treeWeights));
            }

            var weightList = treeWeights.ToList();
            var treeList = trees.ToList();

            // The ensemble output pairs TreeWeights[i] with Trees[i]; a length mismatch would
            // silently drop trees or weights, so reject it up front.
            if (weightList.Count != treeList.Count)
            {
                throw new ArgumentException(
                    $"Expected one weight per tree, but got {weightList.Count} weights for {treeList.Count} trees.",
                    nameof(treeWeights));
            }

            Bias = bias;
            // Wrap the lists so consumers cannot downcast the exposed IReadOnlyList<T> back to
            // List<T> and mutate the ensemble after construction.
            TreeWeights = weightList.AsReadOnly();
            Trees = treeList.AsReadOnly();
        }
    }
 
    /// <summary>
    /// A <see cref="TreeEnsemble{T}"/> whose members are <see cref="RegressionTree"/> instances.
    /// </summary>
    public sealed class RegressionTreeEnsemble : TreeEnsemble<RegressionTree>
    {
        /// <summary>
        /// Builds the ensemble by passing every argument unchanged to the <see cref="TreeEnsemble{T}"/> constructor.
        /// </summary>
        /// <param name="trees">The trees of the ensemble.</param>
        /// <param name="treeWeights">The weight of each tree, in the same order as <paramref name="trees"/>.</param>
        /// <param name="bias">The value added to the weighted sum of the tree outputs.</param>
        internal RegressionTreeEnsemble(
            IEnumerable<RegressionTree> trees,
            IEnumerable<double> treeWeights,
            double bias)
            : base(trees, treeWeights, bias) { }
    }
 
    /// <summary>
    /// A <see cref="TreeEnsemble{T}"/> whose members are <see cref="QuantileRegressionTree"/> instances.
    /// </summary>
    public sealed class QuantileRegressionTreeEnsemble : TreeEnsemble<QuantileRegressionTree>
    {
        /// <summary>
        /// Builds the ensemble by passing every argument unchanged to the <see cref="TreeEnsemble{T}"/> constructor.
        /// </summary>
        /// <param name="trees">The trees of the ensemble.</param>
        /// <param name="treeWeights">The weight of each tree, in the same order as <paramref name="trees"/>.</param>
        /// <param name="bias">The value added to the weighted sum of the tree outputs.</param>
        internal QuantileRegressionTreeEnsemble(
            IEnumerable<QuantileRegressionTree> trees,
            IEnumerable<double> treeWeights,
            double bias)
            : base(trees, treeWeights, bias) { }
    }
}