File: AutoMLExperiment\IStopTrainingManager.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using Microsoft.ML.Runtime;
 
#nullable enable
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// Contract for a component that decides when training should stop.
    /// </summary>
    internal interface IStopTrainingManager
    {
        /// <summary>
        /// Reports whether this manager's stop condition currently holds.
        /// </summary>
        /// <returns><see langword="true"/> when training should stop; otherwise <see langword="false"/>.</returns>
        bool IsStopTrainingRequested();

        /// <summary>
        /// Feeds a finished trial into the manager so it can refresh its internal state.
        /// </summary>
        /// <param name="result">The trial result being reported.</param>
        void Update(TrialResult result);

        /// <summary>
        /// Raised by an implementation to signal that training should stop.
        /// </summary>
        event EventHandler OnStopTraining;
    }
 
    /// <summary>
    /// Stop manager driven by a <see cref="CancellationToken"/>: a stop is requested
    /// once the token has been cancelled.
    /// </summary>
    internal class CancellationTokenStopTrainingManager : IStopTrainingManager
    {
        private readonly CancellationToken _token;
        private readonly IChannel? _channel;

        /// <summary>
        /// Raised from the token's cancellation callback.
        /// </summary>
        public event EventHandler? OnStopTraining;

        /// <summary>
        /// Stores the token and channel and registers a cancellation callback on the token.
        /// </summary>
        /// <param name="ct">Token whose cancellation signals that training should stop.</param>
        /// <param name="channel">Optional channel that receives an info message on cancellation.</param>
        public CancellationTokenStopTrainingManager(CancellationToken ct, IChannel? channel)
        {
            // Fields are assigned before registering, because the callback reads _channel.
            _token = ct;
            _channel = channel;
            _token.Register(HandleCancellation);
        }

        /// <summary>
        /// Reports whether the underlying token has been cancelled.
        /// </summary>
        /// <returns><see langword="true"/> when cancellation has been requested on the token.</returns>
        public bool IsStopTrainingRequested() => _token.IsCancellationRequested;

        /// <summary>
        /// No-op: this manager keeps no state derived from trial results.
        /// </summary>
        /// <param name="result">Ignored.</param>
        public void Update(TrialResult result)
        {
        }

        /// <summary>
        /// Cancellation callback: logs to the channel (if any) and raises <see cref="OnStopTraining"/>.
        /// </summary>
        private void HandleCancellation()
        {
            _channel?.Info("cancel training because cancellation token is invoked...");
            OnStopTraining?.Invoke(this, EventArgs.Empty);
        }
    }
 
    /// <summary>
    /// Stop manager that requests a stop once a fixed amount of time has elapsed.
    /// It wraps a <see cref="CancellationTokenStopTrainingManager"/> bound to a
    /// <see cref="CancellationTokenSource"/> scheduled to cancel after the timeout.
    /// </summary>
    internal class TimeoutTrainingStopManager : IStopTrainingManager
    {
        private readonly CancellationTokenSource _timeoutSource;
        private readonly CancellationTokenStopTrainingManager _tokenManager;

        /// <summary>
        /// Raised when the wrapped token-based manager raises its own stop event.
        /// </summary>
        public event EventHandler? OnStopTraining;

        /// <summary>
        /// Starts the timeout and wires the inner manager's stop event to this instance.
        /// </summary>
        /// <param name="timeoutInSeconds">
        /// Delay before cancellation. Despite the name, this is a <see cref="TimeSpan"/>
        /// that is handed to <see cref="CancellationTokenSource.CancelAfter(TimeSpan)"/> unchanged.
        /// </param>
        /// <param name="channel">Optional channel forwarded to the inner token-based manager.</param>
        public TimeoutTrainingStopManager(TimeSpan timeoutInSeconds, IChannel? channel)
        {
            _timeoutSource = new CancellationTokenSource();
            _timeoutSource.CancelAfter(timeoutInSeconds);

            _tokenManager = new CancellationTokenStopTrainingManager(_timeoutSource.Token, channel);
            _tokenManager.OnStopTraining += RelayStop;
        }

        /// <summary>
        /// Delegates to the inner token-based manager.
        /// </summary>
        /// <returns><see langword="true"/> when the inner manager reports a stop request.</returns>
        public bool IsStopTrainingRequested() => _tokenManager.IsStopTrainingRequested();

        /// <summary>
        /// No-op: this manager keeps no state derived from trial results.
        /// </summary>
        /// <param name="result">Ignored.</param>
        public void Update(TrialResult result)
        {
        }

        /// <summary>
        /// Re-raises the inner manager's stop event with this instance as the sender.
        /// </summary>
        private void RelayStop(object? sender, EventArgs e)
        {
            OnStopTraining?.Invoke(this, e);
        }
    }
 
    /// <summary>
    /// Stop manager that requests a stop once a maximum number of trial results
    /// has been reported through <see cref="Update"/>.
    /// </summary>
    internal class MaxModelStopManager : IStopTrainingManager
    {
        private readonly int _maxModel;
        private int _exploredModel = 0;

        /// <summary>
        /// Raised from <see cref="Update"/> whenever the explored-model count has reached the limit.
        /// </summary>
        public event EventHandler? OnStopTraining;

        /// <summary>
        /// Creates a manager with the given model limit.
        /// </summary>
        /// <param name="maxModel">Number of reported trial results at which a stop is requested.</param>
        /// <param name="channel">Not used by this manager.</param>
        public MaxModelStopManager(int maxModel, IChannel? channel)
        {
            _maxModel = maxModel;
        }

        /// <summary>
        /// Reports whether the number of explored models has reached the configured limit.
        /// </summary>
        /// <returns><see langword="true"/> when the explored count is greater than or equal to the limit.</returns>
        public bool IsStopTrainingRequested()
        {
            return _exploredModel >= _maxModel;
        }

        /// <summary>
        /// Counts one more explored model and raises <see cref="OnStopTraining"/> if the limit is reached.
        /// </summary>
        /// <param name="result">The finished trial; only its arrival is counted, its content is not inspected.</param>
        public void Update(TrialResult result)
        {
            _exploredModel++;

            // Use the same predicate as IsStopTrainingRequested so the event fires on the
            // update that reaches the limit, not one trial later.
            if (IsStopTrainingRequested())
            {
                OnStopTraining?.Invoke(this, EventArgs.Empty);
            }
        }
    }
 
    /// <summary>
    /// Stops training when any of its child training stop managers requests a stop.
    /// </summary>
    internal class AggregateTrainingStopManager : IStopTrainingManager
    {
        private readonly List<IStopTrainingManager> _managers;

        /// <summary>
        /// Raised when a child manager that is still registered raises its own stop event.
        /// </summary>
        public event EventHandler? OnStopTraining;

        /// <summary>
        /// Creates an aggregate over the given child managers.
        /// </summary>
        /// <param name="channel">Not used by this manager.</param>
        /// <param name="managers">Initial child managers.</param>
        public AggregateTrainingStopManager(IChannel? channel, params IStopTrainingManager[] managers)
        {
            _managers = new List<IStopTrainingManager>(managers.Length);

            // Register through the same path as later additions so that every child,
            // including the initial ones, stops propagating events once it is removed.
            foreach (var manager in managers)
            {
                AddTrainingStopManager(manager);
            }
        }

        /// <summary>
        /// Reports whether any registered child manager requests a stop.
        /// </summary>
        /// <returns><see langword="true"/> when at least one child reports a stop request.</returns>
        public bool IsStopTrainingRequested()
        {
            return _managers.Any(m => m.IsStopTrainingRequested());
        }

        /// <summary>
        /// Registers a child manager and forwards its stop event while it remains registered.
        /// </summary>
        /// <param name="manager">The child manager to add.</param>
        public void AddTrainingStopManager(IStopTrainingManager manager)
        {
            _managers.Add(manager);
            manager.OnStopTraining += (o, e) =>
            {
                // The handler is never unsubscribed, so check membership to ignore
                // events from a manager that has since been removed.
                if (_managers.Contains(manager))
                {
                    OnStopTraining?.Invoke(this, e);
                }
            };
        }

        /// <summary>
        /// Removes every registration of the given child manager; does nothing if it is not registered.
        /// </summary>
        /// <param name="manager">The child manager to remove.</param>
        public void RemoveTrainingStopManagerIfExist(IStopTrainingManager manager)
        {
            // RemoveAll is already a no-op when nothing matches.
            _managers.RemoveAll(manager.Equals);
        }

        /// <summary>
        /// Passes the trial result on to every registered child manager.
        /// </summary>
        /// <param name="result">The finished trial to report.</param>
        public void Update(TrialResult result)
        {
            foreach (var manager in _managers)
            {
                manager.Update(result);
            }
        }
    }
}