File: WrappedLightGbmBooster.cs
Web Access
Project: src\src\Microsoft.ML.LightGbm\Microsoft.ML.LightGbm.csproj (Microsoft.ML.LightGbm)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
 
namespace Microsoft.ML.Trainers.LightGbm
{
    /// <summary>
    /// Wrapper of Booster object of LightGBM.
    /// </summary>
    internal sealed class Booster : IDisposable
    {
        private readonly bool _hasValid;
        private readonly bool _hasMetric;
 
        public WrappedLightGbmInterface.SafeBoosterHandle Handle { get; private set; }
        public int BestIteration { get; set; }
 
        public Booster(Dictionary<string, object> parameters, Dataset trainset, Dataset validset = null)
        {
            var param = LightGbmInterfaceUtils.JoinParameters(parameters);
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterCreate(trainset.Handle, param, out var handle));
            Handle = handle;
            if (validset != null)
            {
                LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterAddValidData(Handle, validset.Handle));
                _hasValid = true;
            }
 
            int numEval = 0;
            BestIteration = -1;
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterGetEvalCounts(Handle, ref numEval));
            // At most one metric in ML.NET.
            Contracts.Assert(numEval <= 1);
            if (numEval == 1)
                _hasMetric = true;
        }
 
        public bool Update()
        {
            int isFinished = 0;
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterUpdateOneIter(Handle, ref isFinished));
            return isFinished == 1;
        }
 
        public double EvalTrain()
        {
            return Eval(0);
        }
 
        public double EvalValid()
        {
            if (_hasValid)
                return Eval(1);
            else
                return double.NaN;
        }
 
        private unsafe double Eval(int dataIdx)
        {
            if (!_hasMetric)
                return double.NaN;
            int outLen = 0;
            double[] res = new double[1];
            fixed (double* ptr = res)
                LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterGetEval(Handle, dataIdx, ref outLen, ptr));
            return res[0];
        }
 
        [BestFriend]
        internal unsafe string GetModelString()
        {
            int bufLen = 2 << 15;
            byte[] buffer = new byte[bufLen];
            int size = 0;
            fixed (byte* ptr = buffer)
                LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterSaveModelToString(Handle, 0, BestIteration, 0, bufLen, ref size, ptr));
            // If buffer size is not enough, reallocate buffer and get again.
            if (size > bufLen)
            {
                bufLen = size;
                buffer = new byte[bufLen];
                fixed (byte* ptr = buffer)
                    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterSaveModelToString(Handle, 0, BestIteration, 0, bufLen, ref size, ptr));
            }
            byte[] content = new byte[size];
            Array.Copy(buffer, content, size);
            fixed (byte* ptr = content)
                return LightGbmInterfaceUtils.GetString((IntPtr)ptr);
        }
 
        private static double[] Str2DoubleArray(string str, char delimiter)
        {
            var values = new List<double>();
            foreach (var token in str.Split(delimiter))
            {
                var trimmed = token.Trim().ToLowerInvariant();
 
                if (trimmed.Contains("-inf"))
                    values.Add(double.NegativeInfinity);
                else if (trimmed.Contains("inf"))
                    values.Add(double.PositiveInfinity);
                else if (trimmed.Contains("nan"))
                    values.Add(double.NaN);
                else
                    // The value carried in the trimmed string is not inf, -inf, or nan.
                    // Therefore, double.Parse should be able to generate a valid number from it.
                    // If parsing fails, an exception will be thrown.
                    values.Add(double.Parse(trimmed, CultureInfo.InvariantCulture));
            }
            return values.ToArray();
        }
 
        private static int[] Str2IntArray(string str, char delimiter)
        {
            return str.Split(delimiter).Select(x => int.Parse(x, CultureInfo.InvariantCulture)).ToArray();
        }
 
        private static UInt32[] Str2UIntArray(string str, char delimiter)
        {
            return str.Split(delimiter).Select(x => UInt32.Parse(x, CultureInfo.InvariantCulture)).ToArray();
        }
 
        private static bool GetIsDefaultLeft(UInt32 decisionType)
        {
            // The second bit.
            return (decisionType & 2) > 0;
        }
 
        private static bool GetIsCategoricalSplit(UInt32 decisionType)
        {
            // The first bit.
            return (decisionType & 1) > 0;
        }
 
        private static bool GetHasMissing(UInt32 decisionType)
        {
            // The 3rd and 4th bits.
            return ((decisionType >> 2) & 3) > 0;
        }
 
        private static double[] GetDefalutValue(double[] threshold, UInt32[] decisionType)
        {
            double[] ret = new double[threshold.Length];
            for (int i = 0; i < threshold.Length; ++i)
            {
                if (GetHasMissing(decisionType[i]) && !GetIsCategoricalSplit(decisionType[i]))
                {
                    if (GetIsDefaultLeft(decisionType[i]))
                        ret[i] = threshold[i];
                    else
                        ret[i] = threshold[i] + 1;
                }
            }
            return ret;
        }
 
        private static bool FindInBitset(UInt32[] bits, int start, int end, int pos)
        {
            int i1 = pos / 32;
            if (start + i1 >= end)
                return false;
            int i2 = pos % 32;
            return ((bits[start + i1] >> i2) & 1) > 0;
        }
 
        private static int[] GetCatThresholds(UInt32[] catThreshold, int lowerBound, int upperBound)
        {
            List<int> cats = new List<int>();
            for (int j = lowerBound; j < upperBound; ++j)
            {
                // 32 bits.
                for (int k = 0; k < 32; ++k)
                {
                    int cat = (j - lowerBound) * 32 + k;
                    if (FindInBitset(catThreshold, lowerBound, upperBound, cat))
                        cats.Add(cat);
                }
            }
            return cats.ToArray();
        }
 
        public static int GetNumFeatures(string modelString)
        {
            string[] lines = modelString.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
            // Jump to the "max_feature_idx" value in the file. It's at the beginning.
            int i = 0;
            while (!lines[i].StartsWith("max_feature_idx"))
                i++;
 
            // Stored 0 based in the file, need the actual count so adding 1.
            return int.Parse(lines[i].Split('=')[1]) + 1;
        }
 
        public static Dictionary<string, string> GetParameters(string modelString)
        {
            Dictionary<string, string> parameters = new Dictionary<string, string>();
            string[] lines = modelString.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
 
            // Jump to the parameters section in the file. It's at the very end.
            int i = 0;
            while (!lines[i].StartsWith("parameters"))
                i++;
 
            // Increment once more to get to the first parameter
            i++;
 
            while (i <= lines.Length && !lines[i].StartsWith("end of parameters"))
            {
                var param = lines[i].Substring(1, lines[i].Length - 2).Split(':');
                parameters[param[0]] = param[1].Trim();
                i++;
            }
 
            return parameters;
        }
 
        public static InternalTreeEnsemble GetModel(int[] categoricalFeatureBoundaries, string modelString)
        {
            InternalTreeEnsemble res = new InternalTreeEnsemble();
            string[] lines = modelString.Split('\n');
            int i = 0;
            for (; i < lines.Length;)
            {
                if (lines[i].StartsWith("Tree="))
                {
                    Dictionary<string, string> kvPairs = new Dictionary<string, string>();
                    ++i;
                    while (!lines[i].StartsWith("Tree=") && lines[i].Trim().Length != 0)
                    {
                        string[] kv = lines[i].Split('=');
                        Contracts.Check(kv.Length == 2);
                        kvPairs[kv[0].Trim()] = kv[1].Trim();
                        ++i;
                    }
                    int numberOfLeaves = int.Parse(kvPairs["num_leaves"], CultureInfo.InvariantCulture);
                    int numCat = int.Parse(kvPairs["num_cat"], CultureInfo.InvariantCulture);
                    if (numberOfLeaves > 1)
                    {
                        var leftChild = Str2IntArray(kvPairs["left_child"], ' ');
                        var rightChild = Str2IntArray(kvPairs["right_child"], ' ');
                        var splitFeature = Str2IntArray(kvPairs["split_feature"], ' ');
                        var threshold = Str2DoubleArray(kvPairs["threshold"], ' ');
                        var splitGain = Str2DoubleArray(kvPairs["split_gain"], ' ');
                        var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        var decisionType = Str2UIntArray(kvPairs["decision_type"], ' ');
                        var defaultValue = GetDefalutValue(threshold, decisionType);
                        var categoricalSplitFeatures = new int[numberOfLeaves - 1][];
                        var categoricalSplit = new bool[numberOfLeaves - 1];
                        if (categoricalFeatureBoundaries != null)
                        {
                            // Add offsets to split features.
                            for (int node = 0; node < numberOfLeaves - 1; ++node)
                                splitFeature[node] = categoricalFeatureBoundaries[splitFeature[node]];
                        }
 
                        if (numCat > 0)
                        {
                            var catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], ' ');
                            var catThreshold = Str2UIntArray(kvPairs["cat_threshold"], ' ');
                            for (int node = 0; node < numberOfLeaves - 1; ++node)
                            {
                                if (GetIsCategoricalSplit(decisionType[node]))
                                {
                                    int catIdx = (int)threshold[node];
                                    var cats = GetCatThresholds(catThreshold, catBoundaries[catIdx], catBoundaries[catIdx + 1]);
                                    categoricalSplitFeatures[node] = new int[cats.Length];
                                    // Convert Cat thresholds to feature indices.
                                    for (int j = 0; j < cats.Length; ++j)
                                        categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j];
 
                                    splitFeature[node] = -1;
                                    categoricalSplit[node] = true;
                                    // Swap left and right child.
                                    int t = leftChild[node];
                                    leftChild[node] = rightChild[node];
                                    rightChild[node] = t;
                                }
                                else
                                {
                                    categoricalSplit[node] = false;
                                }
                            }
                        }
                        InternalRegressionTree tree = InternalRegressionTree.Create(numberOfLeaves, splitFeature, splitGain,
                            threshold.Select(x => (float)(x)).ToArray(), defaultValue.Select(x => (float)(x)).ToArray(), leftChild, rightChild, leafOutput,
                            categoricalSplitFeatures, categoricalSplit);
                        res.AddTree(tree);
                    }
                    else
                    {
                        InternalRegressionTree tree = new InternalRegressionTree(2);
                        var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        if (leafOutput[0] != 0)
                        {
                            // Convert Constant tree to Two-leaf tree, avoid being filter by TLC.
                            var categoricalSplitFeatures = new int[1][];
                            var categoricalSplit = new bool[1];
                            tree = InternalRegressionTree.Create(2, new int[] { 0 }, new double[] { 0 },
                                new float[] { 0 }, new float[] { 0 }, new int[] { -1 }, new int[] { -2 }, new double[] { leafOutput[0], leafOutput[0] },
                                categoricalSplitFeatures, categoricalSplit);
                        }
                        res.AddTree(tree);
                    }
                }
                else
                    ++i;
            }
            return res;
        }
        #region IDisposable Support
        public void Dispose()
        {
            Handle?.Dispose();
            Handle = null;
        }
        #endregion
    }
}