File: Experiment\SuggestedTrainer.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Trainers;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// A trainer candidate: wraps an <see cref="ITrainerExtension"/> together with the sweepable
    /// hyperparameters it exposes and an optional concrete <see cref="ParameterSet"/> of values for them.
    /// </summary>
    internal class SuggestedTrainer
    {
        /// <summary>
        /// Sweepable hyperparameters obtained from the trainer extension at construction time.
        /// Members of this class tolerate this being null and tolerate null elements inside it.
        /// </summary>
        public IEnumerable<SweepableParam> SweepParams { get; }

        /// <summary>Name of the trainer, as resolved by <see cref="TrainerExtensionCatalog"/>.</summary>
        public TrainerName TrainerName { get; }

        /// <summary>
        /// Concrete hyperparameter values for this trainer, or null. When null, <see cref="BuildTrainer"/>
        /// passes no sweep params to the extension. Assigning a value copies every entry whose name matches
        /// a sweep param into <see cref="SweepParams"/>, so the two always stay consistent.
        /// </summary>
        public ParameterSet HyperParamSet
        {
            get => _hyperParamSet;
            set
            {
                _hyperParamSet = value;
                PropagateParamSetValues();
            }
        }

        private readonly MLContext _mlContext;
        private readonly ITrainerExtension _trainerExtension;
        private readonly ColumnInformation _columnInfo;
        private ParameterSet _hyperParamSet;

        /// <param name="mlContext">Context handed to the extension when the trainer is built.</param>
        /// <param name="trainerExtension">Extension that supplies sweep ranges and creates the trainer.</param>
        /// <param name="columnInfo">Column information handed to the extension.</param>
        /// <param name="hyperParamSet">Optional initial hyperparameter values; may be null.</param>
        internal SuggestedTrainer(MLContext mlContext, ITrainerExtension trainerExtension,
            ColumnInformation columnInfo,
            ParameterSet hyperParamSet = null)
        {
            _mlContext = mlContext;
            _trainerExtension = trainerExtension;
            _columnInfo = columnInfo;
            // SweepParams must be populated before the hyperparameter set is applied, because applying
            // the set writes its values into these sweep params.
            SweepParams = _trainerExtension.GetHyperparamSweepRanges();
            TrainerName = TrainerExtensionCatalog.GetTrainerName(_trainerExtension);
            SetHyperparamValues(hyperParamSet);
        }

        /// <summary>
        /// Replaces <see cref="HyperParamSet"/> and pushes its values into <see cref="SweepParams"/>.
        /// </summary>
        /// <param name="hyperParamSet">New hyperparameter values; may be null.</param>
        public void SetHyperparamValues(ParameterSet hyperParamSet)
        {
            HyperParamSet = hyperParamSet;
        }

        /// <summary>
        /// Creates a new <see cref="SuggestedTrainer"/> for the same extension, context and columns,
        /// with a clone of the current hyperparameter set (or null if there is none).
        /// </summary>
        public SuggestedTrainer Clone()
        {
            return new SuggestedTrainer(_mlContext, _trainerExtension, _columnInfo, HyperParamSet?.Clone());
        }

        /// <summary>
        /// Asks the trainer extension to create the trainer estimator.
        /// </summary>
        /// <param name="validationSet">Optional validation data forwarded to the extension.</param>
        /// <returns>The estimator created by the extension.</returns>
        public ITrainerEstimator<IPredictionTransformer<object>, object> BuildTrainer(IDataView validationSet = null)
        {
            // The sweep params are only forwarded once a concrete hyperparameter set has been applied.
            IEnumerable<SweepableParam> sweepParams = HyperParamSet != null ? SweepParams : null;
            return _trainerExtension.CreateInstance(_mlContext, sweepParams, _columnInfo, validationSet);
        }

        /// <summary>
        /// Formats as <c>TrainerName{name:value, ...}</c>, listing only sweep params that have a value.
        /// </summary>
        public override string ToString()
        {
            var paramsStr = SweepParams == null
                ? string.Empty
                : string.Join(", ", SweepParams.Where(p => p != null && p.RawValue != null).Select(p => $"{p.Name}:{p.ProcessedValue()}"));
            return $"{TrainerName}{{{paramsStr}}}";
        }

        /// <summary>
        /// Builds a pipeline node for this trainer from the sweep params that currently have a value.
        /// </summary>
        public PipelineNode ToPipelineNode()
        {
            // Skip null entries, consistent with ToString, so a null element cannot throw here.
            var sweepParams = SweepParams?.Where(p => p != null && p.RawValue != null);
            return _trainerExtension.CreatePipelineNode(sweepParams, _columnInfo);
        }

        /// <summary>
        /// Makes sweep params and the hyperparameter set consistent: for each entry in
        /// <see cref="HyperParamSet"/> whose name matches a sweep param, copies the entry's value text
        /// into that sweep param. Entries with no matching sweep param are ignored.
        /// </summary>
        private void PropagateParamSetValues()
        {
            if (HyperParamSet == null || SweepParams == null)
            {
                return;
            }

            // ToDictionary throws if two sweep params share a name, the same as before.
            var spMap = SweepParams.Where(sp => sp != null).ToDictionary(sp => sp.Name);

            foreach (var hp in HyperParamSet)
            {
                if (spMap.TryGetValue(hp.Name, out var sp))
                {
                    sp.SetUsingValueText(hp.ValueText);
                }
            }
        }
    }
}