|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
#nullable enable
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.AutoML
{
/// <summary>
/// Trial runner that builds an ML.NET pipeline from a <see cref="SweepablePipeline"/> using the
/// parameters of a trial, fits it on the data supplied by the configured <see cref="IDatasetManager"/>
/// and scores it through the configured <see cref="IEvaluateMetricManager"/>.
/// </summary>
internal class SweepablePipelineRunner : ITrialRunner
{
    // Set to null by Dispose(); every use must tolerate that.
    private MLContext? _mLContext;
    private readonly IEvaluateMetricManager _metricManager;
    private readonly IDatasetManager _datasetManager;
    private readonly SweepablePipeline _pipeline;
    private readonly IChannel? _logger;

    /// <summary>
    /// Creates a runner bound to the given context, pipeline, metric manager and dataset manager.
    /// </summary>
    /// <param name="context">Context used to build, train and evaluate the pipeline.</param>
    /// <param name="pipeline">Sweepable pipeline that is instantiated for each trial.</param>
    /// <param name="metricManager">Computes the metric and tells whether it should be maximized.</param>
    /// <param name="datasetManager">Source of the training and validation data.</param>
    /// <param name="logger">Optional channel; stored but not used by the code in this class.</param>
    public SweepablePipelineRunner(MLContext context, SweepablePipeline pipeline, IEvaluateMetricManager metricManager, IDatasetManager datasetManager, IChannel? logger = null)
    {
        _mLContext = context;
        _metricManager = metricManager;
        _pipeline = pipeline;
        _datasetManager = datasetManager;
        _logger = logger;
    }

    /// <summary>
    /// Runs a single trial synchronously.
    /// </summary>
    /// <param name="settings">Trial settings; the pipeline parameters are read from
    /// <c>settings.Parameter[AutoMLExperiment.PipelineSearchspaceName]</c>.</param>
    /// <returns>
    /// A <see cref="TrialResult"/> holding the metric, the loss (the negated metric when the metric
    /// is maximized, otherwise the metric itself), the trained model and the elapsed time.
    /// For cross validation the metric is the average over all folds and the model is the one
    /// trained on the first fold.
    /// </returns>
    /// <exception cref="ObjectDisposedException">The runner has already been disposed.</exception>
    /// <exception cref="ArgumentException">The dataset manager implements neither
    /// <see cref="ICrossValidateDatasetManager"/> nor <see cref="ITrainValidateDatasetManager"/>.</exception>
    public TrialResult Run(TrialSettings settings)
    {
        // Read the field once: a later Dispose() nulls the field but not this local,
        // so the trial never dereferences null half-way through.
        var context = _mLContext ?? throw new ObjectDisposedException(nameof(SweepablePipelineRunner));

        var stopWatch = Stopwatch.StartNew();
        var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];

        if (_datasetManager is ICrossValidateDatasetManager crossValidateDatasetManager)
        {
            var datasetSplit = context.Data.CrossValidationSplit(
                crossValidateDatasetManager.Dataset,
                crossValidateDatasetManager.Fold,
                crossValidateDatasetManager.SamplingKeyColumnName);
            var metrics = new List<double>();
            var models = new List<ITransformer>();

            foreach (var split in datasetSplit)
            {
                // Work-around for https://github.com/dotnet/machinelearning-modelbuilder/issues/2718:
                // the shape of a deep learning model is fixed the first time it is trained, so the
                // same pipeline instance can't be retrained. Build a fresh pipeline for every fold.
                var foldPipeline = _pipeline.BuildFromOption(context, parameter);
                var model = foldPipeline.Fit(split.TrainSet);
                var eval = model.Transform(split.TestSet);
                metrics.Add(_metricManager.Evaluate(context, eval));
                models.Add(model);
            }

            stopWatch.Stop();
            var metric = metrics.Average();

            return new TrialResult
            {
                Metric = metric,
                Model = models.First(),
                DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
                TrialSettings = settings,
                Loss = _metricManager.IsMaximize ? -metric : metric,
            };
        }

        if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
        {
            var mlnetPipeline = _pipeline.BuildFromOption(context, parameter);
            var model = mlnetPipeline.Fit(trainTestDatasetManager.LoadTrainDataset(context, settings));
            var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(context, settings));
            var metric = _metricManager.Evaluate(context, eval);
            stopWatch.Stop();

            return new TrialResult
            {
                Loss = _metricManager.IsMaximize ? -metric : metric,
                Metric = metric,
                Model = model,
                DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
                TrialSettings = settings,
            };
        }

        throw new ArgumentException($"IDatasetManager must be either {nameof(ITrainValidateDatasetManager)} or {nameof(ICrossValidateDatasetManager)}");
    }

    /// <summary>
    /// Runs a single trial and returns the result wrapped in a completed task. The work itself is
    /// executed synchronously by <see cref="Run"/>; when <paramref name="ct"/> is cancelled,
    /// <c>CancelExecution</c> is invoked on the context.
    /// </summary>
    /// <param name="settings">Trial settings forwarded to <see cref="Run"/>.</param>
    /// <param name="ct">Token whose cancellation triggers <c>CancelExecution</c> on the context.</param>
    /// <returns>A completed task holding the trial result.</returns>
    /// <exception cref="OperationCanceledException">The trial failed after cancellation was requested;
    /// the original failure is available as the inner exception.</exception>
    public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
    {
        try
        {
            // The registration is disposed as soon as the trial ends, so a late cancellation
            // no longer reaches the context through this callback.
            using (ct.Register(() => _mLContext?.CancelExecution()))
            {
                return Task.FromResult(Run(settings));
            }
        }
        catch (Exception ex) when (ct.IsCancellationRequested)
        {
            // Report any failure that happens after cancellation as a cancellation, keeping the
            // original exception (and its stack trace) as the inner exception.
            throw new OperationCanceledException(ex.Message, ex);
        }
    }

    /// <summary>
    /// Calls <c>CancelExecution</c> on the context and drops the reference to it.
    /// Calling this method more than once is a no-op.
    /// </summary>
    public void Dispose()
    {
        _mLContext?.CancelExecution();
        _mLContext = null;
    }
}
}
|