File: Tuner\CostFrugalTuner.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using Microsoft.ML.AutoMLService;
using Microsoft.ML.SearchSpace;
 
namespace Microsoft.ML.AutoML
{
    // an implemetation of "Frugal Optimization for Cost-related Hyperparameters" : https://www.aaai.org/AAAI21Papers/AAAI-10128.WuQ.pdf
    internal class CostFrugalTuner : ITuner
    {
        private readonly RandomNumberGenerator _rng = new RandomNumberGenerator();
        private readonly SearchSpace.SearchSpace _searchSpace;
        private readonly Flow2 _localSearch;
        private readonly Dictionary<int, SearchThread> _searchThreadPool = new Dictionary<int, SearchThread>();
        private int _currentThreadId;
        private readonly Dictionary<int, int> _trialProposedBy = new Dictionary<int, int>();
 
        private readonly double[] _lsBoundMax;
        private readonly double[] _lsBoundMin;
        private bool _initUsed = false;
        private double _bestLoss;
 
        public CostFrugalTuner(AutoMLExperiment.AutoMLExperimentSettings settings, ITrialResultManager trialResultManager = null)
            : this(settings.SearchSpace, settings.SearchSpace.SampleFromFeatureSpace(settings.SearchSpace.Default), trialResultManager?.GetAllTrialResults(), settings.Seed)
        {
        }
 
        public CostFrugalTuner(SearchSpace.SearchSpace searchSpace, Parameter initValue = null, IEnumerable<TrialResult> trialResults = null, int? seed = null)
        {
            _searchSpace = searchSpace;
            _currentThreadId = 0;
            _lsBoundMin = _searchSpace.MappingToFeatureSpace(initValue);
            _lsBoundMax = _searchSpace.MappingToFeatureSpace(initValue);
            _initUsed = false;
            _bestLoss = double.MaxValue;
            if (seed is int s)
            {
                _rng = new RandomNumberGenerator(s);
            }
            _localSearch = new Flow2(searchSpace, initValue, true, rng: _rng);
            if (trialResults != null)
            {
                foreach (var trial in trialResults)
                {
                    Update(trial);
                }
            }
        }
 
        public Parameter Propose(TrialSettings settings)
        {
            var trialId = settings.TrialId;
            Parameter param;
            if (_initUsed)
            {
                var searchThread = _searchThreadPool[_currentThreadId];
                param = searchThread.Suggest(trialId);
                _trialProposedBy[trialId] = _currentThreadId;
            }
            else
            {
                param = _searchSpace.SampleFromFeatureSpace(CreateInitConfigFromAdmissibleRegion());
                _trialProposedBy[trialId] = _currentThreadId;
            }
 
            return param;
        }
 
        public Parameter BestConfig { get; set; }
 
        public void Update(TrialResult result)
        {
            var trialId = result.TrialSettings.TrialId;
            var parameter = result.TrialSettings.Parameter;
            var loss = result.Loss;
 
            if (loss < _bestLoss)
            {
                BestConfig = parameter;
                _bestLoss = loss;
            }
 
            var cost = result.DurationInMilliseconds;
            int threadId = _trialProposedBy.ContainsKey(trialId) ? _trialProposedBy[trialId] : _currentThreadId;
            if (_searchThreadPool.Count == 0)
            {
                _searchThreadPool[_currentThreadId] = _localSearch.CreateSearchThread(parameter, loss, cost);
                _initUsed = true;
                UpdateAdmissibleRegion(_searchSpace.MappingToFeatureSpace(parameter));
            }
            else
            {
                _searchThreadPool[threadId].OnTrialComplete(parameter, loss, cost);
                if (_searchThreadPool[threadId].IsConverged)
                {
                    _searchThreadPool.Remove(threadId);
                    _currentThreadId += 1;
                    _initUsed = false;
                }
            }
        }
 
        private void UpdateAdmissibleRegion(double[] config)
        {
            for (int i = 0; i != config.Length; ++i)
            {
                if (config[i] < _lsBoundMin[i])
                {
                    _lsBoundMin[i] = config[i];
                    continue;
                }
 
                if (config[i] > _lsBoundMax[i])
                {
                    _lsBoundMax[i] = config[i];
                    continue;
                }
            }
        }
 
        private double[] CreateInitConfigFromAdmissibleRegion()
        {
            var res = new double[_lsBoundMax.Length];
            for (int i = 0; i != _lsBoundMax.Length; ++i)
            {
                res[i] = _rng.Uniform(_lsBoundMin[i], _lsBoundMax[i]);
            }
 
            return res;
        }
    }
}