File: TrainerExtensions\TrainerExtensionUtil.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.LightGbm;
 
namespace Microsoft.ML.AutoML
{
    internal enum TrainerName
    {
        AveragedPerceptronBinary,
        AveragedPerceptronOva,
        FastForestBinary,
        FastForestOva,
        FastForestRegression,
        FastTreeBinary,
        FastTreeOva,
        FastTreeRegression,
        FastTreeTweedieRegression,
        LightGbmBinary,
        LightGbmMulti,
        LightGbmRegression,
        LinearSvmBinary,
        LinearSvmOva,
        LbfgsLogisticRegressionBinary,
        LbfgsLogisticRegressionOva,
        LbfgsMaximumEntropyMulti,
        OnlineGradientDescentRegression,
        OlsRegression,
        Ova,
        LbfgsPoissonRegression,
        SdcaLogisticRegressionBinary,
        SdcaMaximumEntropyMulti,
        SdcaRegression,
        SgdCalibratedBinary,
        SgdCalibratedOva,
        SymbolicSgdLogisticRegressionBinary,
        SymbolicSgdLogisticRegressionOva,
        MatrixFactorization,
        ImageClassification,
        LightGbmRanking,
        FastTreeRanking
    }
 
    internal static class TrainerExtensionUtil
    {
        private const string WeightColumn = "ExampleWeightColumnName";
        private const string LabelColumn = "LabelColumnName";
        private const string GroupColumn = "GroupColumnName";
 
        public static T CreateOptions<T>(IEnumerable<SweepableParam> sweepParams, string labelColumn) where T : TrainerInputBaseWithLabel
        {
            var options = Activator.CreateInstance<T>();
            options.LabelColumnName = labelColumn;
            if (sweepParams != null)
            {
                UpdateFields(options, sweepParams);
            }
            return options;
        }
 
        public static T CreateOptions<T>(IEnumerable<SweepableParam> sweepParams) where T : class
        {
            var options = Activator.CreateInstance<T>();
            if (sweepParams != null)
            {
                UpdateFields(options, sweepParams);
            }
            return options;
        }
 
        private static readonly string[] _lightGbmBoosterParamNames = new[] { "L2Regularization", "L1Regularization" };
        private const string LightGbmBoosterPropName = "Booster";
 
        public static TOptions CreateLightGbmOptions<TOptions, TOutput, TTransformer, TModel>(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
            where TOptions : LightGbmTrainerBase<TOptions, TOutput, TTransformer, TModel>.OptionsBase, new()
            where TTransformer : ISingleFeaturePredictionTransformer<TModel>
            where TModel : class
        {
            var options = new TOptions();
            options.LabelColumnName = columnInfo.LabelColumnName;
            options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;
            options.Booster = new GradientBooster.Options();
            if (sweepParams != null)
            {
                var boosterParams = sweepParams.Where(p => _lightGbmBoosterParamNames.Contains(p.Name));
                var parentArgParams = sweepParams.Except(boosterParams);
                UpdateFields(options, parentArgParams);
                UpdateFields(options.Booster, boosterParams);
            }
            return options;
        }
 
        public static PipelineNode BuildOvaPipelineNode(ITrainerExtension multiExtension, ITrainerExtension binaryExtension,
            IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var ovaNode = new PipelineNode()
            {
                Name = TrainerName.Ova.ToString(),
                NodeType = PipelineNodeType.Trainer,
                Properties = new Dictionary<string, object>()
                {
                    { LabelColumn, columnInfo.LabelColumnName }
                }
            };
            var binaryNode = binaryExtension.CreatePipelineNode(sweepParams, columnInfo);
            ovaNode.Properties["BinaryTrainer"] = binaryNode;
            return ovaNode;
        }
 
        public static PipelineNode BuildPipelineNode(TrainerName trainerName, IEnumerable<SweepableParam> sweepParams,
            string labelColumn, string weightColumn = null, IDictionary<string, object> additionalProperties = null)
        {
            var properties = BuildBasePipelineNodeProps(sweepParams, labelColumn, weightColumn);
 
            if (additionalProperties != null)
            {
                foreach (var property in additionalProperties)
                {
                    properties[property.Key] = property.Value;
                }
            }
 
            return new PipelineNode(trainerName.ToString(), PipelineNodeType.Trainer, DefaultColumnNames.Features,
                DefaultColumnNames.Score, properties);
        }
 
        public static PipelineNode BuildLightGbmPipelineNode(TrainerName trainerName, IEnumerable<SweepableParam> sweepParams,
            string labelColumn, string weightColumn, string groupColumn)
        {
            return new PipelineNode(trainerName.ToString(), PipelineNodeType.Trainer, DefaultColumnNames.Features,
                DefaultColumnNames.Score, BuildLightGbmPipelineNodeProps(sweepParams, labelColumn, weightColumn, groupColumn));
        }
 
        private static IDictionary<string, object> BuildBasePipelineNodeProps(IEnumerable<SweepableParam> sweepParams,
            string labelColumn, string weightColumn)
        {
            var props = new Dictionary<string, object>();
            if (sweepParams != null)
            {
                foreach (var sweepParam in sweepParams)
                {
                    props[sweepParam.Name] = sweepParam.ProcessedValue();
                }
            }
            props[LabelColumn] = labelColumn;
            if (weightColumn != null)
            {
                props[WeightColumn] = weightColumn;
            }
            return props;
        }
 
        private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnumerable<SweepableParam> sweepParams,
            string labelColumn, string weightColumn, string groupColumn)
        {
            Dictionary<string, object> props = null;
            if (sweepParams == null || !sweepParams.Any())
            {
                props = new Dictionary<string, object>();
            }
            else
            {
                var boosterParams = sweepParams.Where(p => _lightGbmBoosterParamNames.Contains(p.Name));
                var parentArgParams = sweepParams.Except(boosterParams);
 
                var boosterProps = boosterParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
                var boosterCustomProp = new CustomProperty("GradientBooster.Options", boosterProps);
 
                props = parentArgParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
                props[LightGbmBoosterPropName] = boosterCustomProp;
            }
 
            props[LabelColumn] = labelColumn;
            if (weightColumn != null)
            {
                props[WeightColumn] = weightColumn;
            }
            if (groupColumn != null)
            {
                props[GroupColumn] = groupColumn;
            }
 
            return props;
        }
 
        public static ParameterSet BuildParameterSet(TrainerName trainerName, IDictionary<string, object> props)
        {
            props = props.Where(p => p.Key != LabelColumn && p.Key != WeightColumn)
                .ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
 
            if (trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
                trainerName == TrainerName.LightGbmRegression || trainerName == TrainerName.LightGbmRanking)
            {
                return BuildLightGbmParameterSet(props);
            }
 
            var paramVals = props.Select(p => new StringParameterValue(p.Key, p.Value.ToString()));
            return new ParameterSet(paramVals);
        }
 
        public static ColumnInformation BuildColumnInfo(IDictionary<string, object> props)
        {
            var columnInfo = new ColumnInformation();
 
            columnInfo.LabelColumnName = props[LabelColumn] as string;
 
            props.TryGetValue(WeightColumn, out var weightColumn);
            columnInfo.ExampleWeightColumnName = weightColumn as string;
 
            return columnInfo;
        }
 
        private static ParameterSet BuildLightGbmParameterSet(IDictionary<string, object> props)
        {
            IEnumerable<IParameterValue> parameters;
            if (props == null || !props.Any())
            {
                parameters = new List<IParameterValue>();
            }
            else
            {
                var parentProps = props.Where(p => p.Key != LightGbmBoosterPropName);
                var treeProps = ((CustomProperty)props[LightGbmBoosterPropName]).Properties;
                var allProps = parentProps.Union(treeProps);
                parameters = allProps.Select(p => new StringParameterValue(p.Key, p.Value.ToString()));
            }
            return new ParameterSet(parameters);
        }
 
        private static void SetValue(FieldInfo fi, IComparable value, object obj, Type propertyType)
        {
            if (propertyType == value?.GetType())
                fi.SetValue(obj, value);
            else if (propertyType == typeof(double) && value is float)
                fi.SetValue(obj, Convert.ToDouble(value));
            else if (propertyType == typeof(int) && value is long)
                fi.SetValue(obj, Convert.ToInt32(value));
            else if (propertyType == typeof(long) && value is int)
                fi.SetValue(obj, Convert.ToInt64(value));
        }
 
        /// <summary>
        /// Updates properties of object instance based on the values in sweepParams
        /// </summary>
        public static void UpdateFields(object obj, IEnumerable<SweepableParam> sweepParams)
        {
            foreach (var param in sweepParams)
            {
                try
                {
                    // Only updates property if param.value isn't null and
                    // param has a name of property.
                    if (param.RawValue == null)
                    {
                        continue;
                    }
                    var fi = obj.GetType().GetField(param.Name);
                    var propType = Nullable.GetUnderlyingType(fi.FieldType) ?? fi.FieldType;
 
                    if (param is SweepableDiscreteParam dp)
                    {
                        var optIndex = (int)dp.RawValue;
                        //Contracts.Assert(0 <= optIndex && optIndex < dp.Options.Length, $"Options index out of range: {optIndex}");
                        var option = dp.Options[optIndex].ToString().ToLower();
 
                        // Handle <Auto> string values in sweep params
                        if (option == "auto" || option == "<auto>" || option == "< auto >")
                        {
                            //Check if nullable type, in which case 'null' is the auto value.
                            if (Nullable.GetUnderlyingType(fi.FieldType) != null)
                                fi.SetValue(obj, null);
                            else if (fi.FieldType.IsEnum)
                            {
                                // Check if there is an enum option named Auto
                                var enumDict = fi.FieldType.GetEnumValues().Cast<int>()
                                    .ToDictionary(v => Enum.GetName(fi.FieldType, v), v => v);
                                if (enumDict.ContainsKey("Auto"))
                                    fi.SetValue(obj, enumDict["Auto"]);
                            }
                        }
                        else
                            SetValue(fi, (IComparable)dp.Options[optIndex], obj, propType);
                    }
                    else
                        SetValue(fi, param.RawValue, obj, propType);
                }
                catch (Exception)
                {
                    throw new InvalidOperationException($"Cannot set parameter {param.Name} for {obj.GetType()}");
                }
            }
        }
 
        public static TrainerName GetTrainerName(BinaryClassificationTrainer binaryTrainer)
        {
            switch (binaryTrainer)
            {
                case BinaryClassificationTrainer.FastForest:
                    return TrainerName.FastForestBinary;
                case BinaryClassificationTrainer.FastTree:
                    return TrainerName.FastTreeBinary;
                case BinaryClassificationTrainer.LightGbm:
                    return TrainerName.LightGbmBinary;
                case BinaryClassificationTrainer.LbfgsLogisticRegression:
                    return TrainerName.LbfgsLogisticRegressionBinary;
                case BinaryClassificationTrainer.SdcaLogisticRegression:
                    return TrainerName.SdcaLogisticRegressionBinary;
            }
 
            // never expected to reach here
            throw new NotSupportedException($"{binaryTrainer} not supported");
        }
 
        public static TrainerName GetTrainerName(MulticlassClassificationTrainer multiTrainer)
        {
            switch (multiTrainer)
            {
                case MulticlassClassificationTrainer.FastForestOva:
                    return TrainerName.FastForestOva;
                case MulticlassClassificationTrainer.FastTreeOva:
                    return TrainerName.FastTreeOva;
                case MulticlassClassificationTrainer.LightGbm:
                    return TrainerName.LightGbmMulti;
                case MulticlassClassificationTrainer.LbfgsMaximumEntropy:
                    return TrainerName.LbfgsMaximumEntropyMulti;
                case MulticlassClassificationTrainer.LbfgsLogisticRegressionOva:
                    return TrainerName.LbfgsLogisticRegressionOva;
                case MulticlassClassificationTrainer.SdcaMaximumEntropy:
                    return TrainerName.SdcaMaximumEntropyMulti;
            }
 
            // never expected to reach here
            throw new NotSupportedException($"{multiTrainer} not supported");
        }
 
        public static TrainerName GetTrainerName(RegressionTrainer regressionTrainer)
        {
            switch (regressionTrainer)
            {
                case RegressionTrainer.FastForest:
                    return TrainerName.FastForestRegression;
                case RegressionTrainer.FastTree:
                    return TrainerName.FastTreeRegression;
                case RegressionTrainer.FastTreeTweedie:
                    return TrainerName.FastTreeTweedieRegression;
                case RegressionTrainer.LightGbm:
                    return TrainerName.LightGbmRegression;
                case RegressionTrainer.LbfgsPoissonRegression:
                    return TrainerName.LbfgsPoissonRegression;
                case RegressionTrainer.StochasticDualCoordinateAscent:
                    return TrainerName.SdcaRegression;
            }
 
            // never expected to reach here
            throw new NotSupportedException($"{regressionTrainer} not supported");
        }
 
        public static TrainerName GetTrainerName(RankingTrainer rankingTrainer)
        {
            switch (rankingTrainer)
            {
                case RankingTrainer.FastTreeRanking:
                    return TrainerName.FastTreeRanking;
                case RankingTrainer.LightGbmRanking:
                    return TrainerName.LightGbmRanking;
            }
 
            // never expected to reach here
            throw new NotSupportedException($"{rankingTrainer} not supported");
        }
 
        public static TrainerName GetTrainerName(RecommendationTrainer recommendationTrainer)
        {
            switch (recommendationTrainer)
            {
                case RecommendationTrainer.MatrixFactorization:
                    return TrainerName.MatrixFactorization;
            }
 
            // never expected to reach here
            throw new NotSupportedException($"{recommendationTrainer} not supported");
        }
 
        public static IEnumerable<TrainerName> GetTrainerNames(IEnumerable<BinaryClassificationTrainer> binaryTrainers)
        {
            return binaryTrainers?.Select(t => GetTrainerName(t));
        }
 
        public static IEnumerable<TrainerName> GetTrainerNames(IEnumerable<MulticlassClassificationTrainer> multiTrainers)
        {
            return multiTrainers?.Select(t => GetTrainerName(t));
        }
 
        public static IEnumerable<TrainerName> GetTrainerNames(IEnumerable<RegressionTrainer> regressionTrainers)
        {
            return regressionTrainers?.Select(t => GetTrainerName(t));
        }
 
        public static IEnumerable<TrainerName> GetTrainerNames(IEnumerable<RecommendationTrainer> recommendationTrainers)
        {
            return recommendationTrainers?.Select(t => GetTrainerName(t));
        }
 
        public static IEnumerable<TrainerName> GetTrainerNames(IEnumerable<RankingTrainer> rankingTrainers)
        {
            return rankingTrainers?.Select(t => GetTrainerName(t));
        }
    }
}