File: PipelineSuggesters\PipelineSuggester.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.AutoML
{
    internal static class PipelineSuggester
    {
        private const int TopKTrainers = 3;
 
        public static Pipeline GetNextPipeline(MLContext context,
            IEnumerable<PipelineScore> history,
            DatasetColumnInfo[] columns,
            TaskKind task,
            IChannel logger,
            bool isMaximizingMetric = true)
        {
            var inferredHistory = history.Select(r => SuggestedPipelineRunDetail.FromPipelineRunResult(context, r));
            var nextInferredPipeline = GetNextInferredPipeline(context, inferredHistory, columns, task, isMaximizingMetric, CacheBeforeTrainer.Auto, logger);
            return nextInferredPipeline?.ToPipeline();
        }
 
        public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
            IEnumerable<SuggestedPipelineRunDetail> history,
            DatasetColumnInfo[] columns,
            TaskKind task,
            bool isMaximizingMetric,
            CacheBeforeTrainer cacheBeforeTrainer,
            IChannel logger,
            IEnumerable<TrainerName> trainerAllowList = null)
        {
            var availableTrainers = RecipeInference.AllowedTrainers(context, task,
                ColumnInformationUtil.BuildColumnInfo(columns), trainerAllowList);
            var transforms = TransformInferenceApi.InferTransforms(context, task, columns).ToList();
            var transformsPostTrainer = TransformInferenceApi.InferTransformsPostTrainer(context, task, columns).ToList();
 
            // if we haven't run all pipelines once
            if (history.Count() < availableTrainers.Count())
            {
                return GetNextFirstStagePipeline(context, history, availableTrainers, transforms, transformsPostTrainer, cacheBeforeTrainer);
            }
 
            // get top trainers from stage 1 runs
            var topTrainers = GetTopTrainers(history, availableTrainers, isMaximizingMetric);
 
            // sort top trainers by # of times they've been run, from lowest to highest
            var orderedTopTrainers = OrderTrainersByNumTrials(history, topTrainers);
 
            // keep as hash set of previously visited pipelines
            var visitedPipelines = new HashSet<SuggestedPipeline>(history.Select(h => h.Pipeline));
 
            // iterate over top trainers (from least run to most run),
            // to find next pipeline
            foreach (var trainer in orderedTopTrainers)
            {
                var newTrainer = trainer.Clone();
 
                // repeat until passes or runs out of chances
                const int maxNumberAttempts = 10;
                var count = 0;
                do
                {
                    // sample new hyperparameters for the learner
                    if (!SampleHyperparameters(context, newTrainer, history, isMaximizingMetric, logger))
                    {
                        // if unable to sample new hyperparameters for the learner
                        // (ie SMAC returned 0 suggestions), break
                        break;
                    }
 
                    var suggestedPipeline = SuggestedPipelineBuilder.Build(context, transforms, transformsPostTrainer, newTrainer, cacheBeforeTrainer);
 
                    // make sure we have not seen pipeline before
                    if (!visitedPipelines.Contains(suggestedPipeline))
                    {
                        return suggestedPipeline;
                    }
                } while (++count <= maxNumberAttempts);
            }
 
            return null;
        }
 
        /// <summary>
        /// Get top trainers from first stage
        /// </summary>
        private static IEnumerable<SuggestedTrainer> GetTopTrainers(IEnumerable<SuggestedPipelineRunDetail> history,
            IEnumerable<SuggestedTrainer> availableTrainers,
            bool isMaximizingMetric)
        {
            // narrow history to first stage runs that succeeded
            history = history.Take(availableTrainers.Count()).Where(x => x.RunSucceeded);
 
            history = history.GroupBy(r => r.Pipeline.Trainer.TrainerName).Select(g => g.First());
            IEnumerable<SuggestedPipelineRunDetail> sortedHistory = history.OrderBy(r => r.Score);
            if (isMaximizingMetric)
            {
                sortedHistory = sortedHistory.Reverse();
            }
            var topTrainers = sortedHistory.Take(TopKTrainers).Select(r => r.Pipeline.Trainer);
            return topTrainers;
        }
 
        private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerable<SuggestedPipelineRunDetail> history,
            IEnumerable<SuggestedTrainer> selectedTrainers)
        {
            var selectedTrainerNames = new HashSet<TrainerName>(selectedTrainers.Select(t => t.TrainerName));
            return history.Where(h => selectedTrainerNames.Contains(h.Pipeline.Trainer.TrainerName))
                .GroupBy(h => h.Pipeline.Trainer.TrainerName)
                .OrderBy(x => x.Count())
                .Select(x => x.First().Pipeline.Trainer);
        }
 
        private static SuggestedPipeline GetNextFirstStagePipeline(MLContext context,
            IEnumerable<SuggestedPipelineRunDetail> history,
            IEnumerable<SuggestedTrainer> availableTrainers,
            ICollection<SuggestedTransform> transforms,
            ICollection<SuggestedTransform> transformsPostTrainer,
            CacheBeforeTrainer cacheBeforeTrainer)
        {
            var trainer = availableTrainers.ElementAt(history.Count());
            return SuggestedPipelineBuilder.Build(context, transforms, transformsPostTrainer, trainer, cacheBeforeTrainer);
        }
 
        private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableParam> hps)
        {
            var results = new IValueGenerator[hps.Count()];
 
            for (int i = 0; i < hps.Count(); i++)
            {
                switch (hps.ElementAt(i))
                {
                    case SweepableDiscreteParam dp:
                        var dpArgs = new DiscreteParamArguments()
                        {
                            Name = dp.Name,
                            Values = dp.Options.Select(o => o.ToString()).ToArray()
                        };
                        results[i] = new DiscreteValueGenerator(dpArgs);
                        break;
 
                    case SweepableFloatParam fp:
                        var fpArgs = new FloatParamArguments()
                        {
                            Name = fp.Name,
                            Min = fp.Min,
                            Max = fp.Max,
                            LogBase = fp.IsLogScale,
                        };
                        if (fp.NumSteps.HasValue)
                        {
                            fpArgs.NumSteps = fp.NumSteps.Value;
                        }
                        if (fp.StepSize.HasValue)
                        {
                            fpArgs.StepSize = fp.StepSize.Value;
                        }
                        results[i] = new FloatValueGenerator(fpArgs);
                        break;
 
                    case SweepableLongParam lp:
                        var lpArgs = new LongParamArguments()
                        {
                            Name = lp.Name,
                            Min = lp.Min,
                            Max = lp.Max,
                            LogBase = lp.IsLogScale
                        };
                        if (lp.NumSteps.HasValue)
                        {
                            lpArgs.NumSteps = lp.NumSteps.Value;
                        }
                        if (lp.StepSize.HasValue)
                        {
                            lpArgs.StepSize = lp.StepSize.Value;
                        }
                        results[i] = new LongValueGenerator(lpArgs);
                        break;
                }
            }
            return results;
        }
 
        /// <summary>
        /// Samples new hyperparameters for the trainer, and sets them.
        /// Returns true if success (new hyperparameters were suggested and set). Else, returns false.
        /// </summary>
        private static bool SampleHyperparameters(MLContext context, SuggestedTrainer trainer,
            IEnumerable<SuggestedPipelineRunDetail> history, bool isMaximizingMetric, IChannel logger)
        {
            try
            {
                var sps = ConvertToValueGenerators(trainer.SweepParams);
                var sweeper = new SmacSweeper(context,
                    new SmacSweeper.Arguments
                    {
                        SweptParameters = sps
                    });
 
                IEnumerable<SuggestedPipelineRunDetail> historyToUse = history
                    .Where(r => r.RunSucceeded && r.Pipeline.Trainer.TrainerName == trainer.TrainerName &&
                                r.Pipeline.Trainer.HyperParamSet != null &&
                                r.Pipeline.Trainer.HyperParamSet.Any() &&
                                FloatUtils.IsFinite(r.Score));
 
                // get new set of hyperparameter values
                var proposedParamSet = sweeper.ProposeSweeps(1, historyToUse.Select(h => h.ToRunResult(isMaximizingMetric))).FirstOrDefault();
                if (!proposedParamSet.Any())
                {
                    return false;
                }
 
                // associate proposed parameter set with trainer, so that smart hyperparameter
                // sweepers (like KDO) can map them back.
                trainer.SetHyperparamValues(proposedParamSet);
 
                return true;
            }
            catch (Exception ex)
            {
                logger.Error($"SampleHyperparameters failed with exception: {ex}");
                throw;
            }
        }
    }
}