File: AutoMLExperiment\IMonitor.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// instance for monitor, which is used by <see cref="AutoMLExperiment"/> to report training progress.
    /// </summary>
    public interface IMonitor
    {
        void ReportCompletedTrial(TrialResult result);
 
        void ReportBestTrial(TrialResult result);
 
        void ReportFailTrial(TrialSettings settings, Exception exception = null);
 
        void ReportRunningTrial(TrialSettings settings);
    }
 
    /// <summary>
    /// monitor that redirect output results to context.Log channel.
    /// </summary>
    internal class MLContextMonitor : IMonitor
    {
        private readonly IChannel _logger;
        private readonly List<TrialResult> _completedTrials;
        private readonly SweepablePipeline _pipeline;
 
        public MLContextMonitor(IChannel logger, SweepablePipeline pipeline)
        {
            _logger = logger;
            _completedTrials = new List<TrialResult>();
            _pipeline = pipeline;
        }
 
        public virtual void ReportBestTrial(TrialResult result)
        {
            _logger.Info($"Update Best Trial - Id: {result.TrialSettings.TrialId} - Metric: {result.Metric} - Pipeline: {_pipeline.ToString(result.TrialSettings.Parameter)}");
        }
 
        public virtual void ReportCompletedTrial(TrialResult result)
        {
            if (result is FairnessTrialResult fResult)
            {   //TODO: now we are assuming the higher the raw metric the better and the lower the fairness metric the better. If we have a raw metric that needs to be minimized then this should change
                _logger.Info($"Update Completed Trial - Id: {result.TrialSettings.TrialId} - Raw Metric: {result.Metric + fResult.FairnessMetric} - Fairness Metric: {-fResult.FairnessMetric} - Total Metric: {result.Metric} - Pipeline: {this._pipeline} - Duration: {result.DurationInMilliseconds}");
            }
            else
            {
                _logger.Info($"Update Completed Trial - Id: {result.TrialSettings.TrialId} - Metric: {result.Metric} - Pipeline: {this._pipeline} - Duration: {result.DurationInMilliseconds}");
            }
            _completedTrials.Add(result);
        }
 
        public virtual void ReportFailTrial(TrialSettings settings, Exception exception = null)
        {
            _logger.Trace(exception.Message + exception.StackTrace);
            _logger.Info($"Update Failed Trial - Id: {settings.TrialId} - Pipeline: {_pipeline.ToString(settings.Parameter)}");
        }
 
        public virtual void ReportRunningTrial(TrialSettings setting)
        {
            _logger.Info($"Update Running Trial - Id: {setting.TrialId}");
        }
    }
 
    internal class TrialResultMonitor<TMetrics> : MLContextMonitor
        where TMetrics : class
    {
        public TrialResultMonitor(IChannel channel, SweepablePipeline pipeline)
            : base(channel, pipeline)
        {
            this.RunDetails = new List<TrialResult<TMetrics>>();
        }
 
        public event EventHandler<TrialResult<TMetrics>> OnTrialCompleted;
 
        public List<TrialResult<TMetrics>> RunDetails { get; }
 
        public TrialResult<TMetrics> BestRun { get; private set; }
 
        public override void ReportBestTrial(TrialResult result)
        {
            base.ReportBestTrial(result);
            if (result is TrialResult<TMetrics> binaryClassificationResult)
            {
                BestRun = binaryClassificationResult;
            }
            else
            {
                throw new ArgumentException($"result must be of type {typeof(TrialResult<TMetrics>)}");
            }
        }
 
        public override void ReportCompletedTrial(TrialResult result)
        {
            base.ReportCompletedTrial(result);
            if (result is TrialResult<TMetrics> metricResult)
            {
                RunDetails.Add(metricResult);
                OnTrialCompleted?.Invoke(this, metricResult);
            }
            else
            {
                throw new ArgumentException($"result must be of type {typeof(TrialResult<TMetrics>)}");
            }
        }
 
        public override void ReportFailTrial(TrialSettings settings, Exception exp)
        {
            base.ReportFailTrial(settings, exp);
 
            var result = new TrialResult<TMetrics>
            {
                TrialSettings = settings,
                Exception = exp,
            };
 
            RunDetails.Add(result);
        }
    }
}