File: AsyncSweeper.cs
Web Access
Project: src\src\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj (Microsoft.ML.Sweeper)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Sweeper;
 
[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(SweeperBase.OptionsBase), typeof(SignatureAsyncSweeper),
    "Asynchronous Uniform Random Sweeper", "UniformRandomSweeper", "UniformRandom")]
[assembly: LoadableClass(typeof(SimpleAsyncSweeper), typeof(RandomGridSweeper.Options), typeof(SignatureAsyncSweeper),
    "Asynchronous Random Grid Sweeper", "RandomGridSweeper", "RandomGrid")]
[assembly: LoadableClass(typeof(DeterministicSweeperAsync), typeof(DeterministicSweeperAsync.Options), typeof(SignatureAsyncSweeper),
    "Asynchronous and Deterministic Sweeper", "DeterministicSweeper", "Deterministic")]
 
namespace Microsoft.ML.Sweeper
{
    public delegate void SignatureAsyncSweeper();
 
    public sealed class ParameterSetWithId
    {
        public readonly int Id;
        public readonly ParameterSet ParameterSet;
 
        public ParameterSetWithId(int id, ParameterSet param)
        {
            Contracts.CheckParam(id >= 0, nameof(id));
            Contracts.CheckValueOrNull(param);
            ParameterSet = param;
            Id = id;
        }
    }
 
    /// <summary>
    /// An interface for sweeper with asynchornous update and proposal.
    /// </summary>
    public interface IAsyncSweeper
    {
        /// <summary>
        /// Propose a <see cref="ParameterSet"/>.
        /// </summary>
        /// <returns>A future <see cref="ParameterSet"/> and its id. Null if unavailable or cancelled.</returns>
        Task<ParameterSetWithId> ProposeAsync();
 
        /// <summary>
        /// Notify the sweeper of a finished run.
        /// </summary>
        /// <param name="id">Id of the run.</param>
        /// <param name="result">Result of the run. Null if not available.</param>
        void Update(int id, IRunResult result);
 
        /// <summary>
        /// Request the sweeper to stop generating and dispensing new parameters.
        /// </summary>
        void Cancel();
    }
 
    /// <summary>
    /// Expose existing <see cref="ISweeper"/>s as <see cref="IAsyncSweeper"/> with no synchronization over the past runs.
    /// Nelder-Mead requires synchronization so is not compatible with SimpleAsyncSweeperBase.
    /// </summary>
    public partial class SimpleAsyncSweeper : IAsyncSweeper, IDisposable
    {
        private readonly List<IRunResult> _results;
        private readonly object _lock;
        private readonly ISweeper _baseSweeper;
 
        private volatile bool _canceled;
        private bool _disposed;
 
        // The number of ParameterSets generated so far. Used for indexing.
        private int _numGenerated;
 
        private SimpleAsyncSweeper(ISweeper baseSweeper)
        {
            Contracts.CheckValue(baseSweeper, nameof(baseSweeper));
            Contracts.CheckParam(!(baseSweeper is NelderMeadSweeper), nameof(baseSweeper), "baseSweeper cannot be Nelder-Mead");
 
            _baseSweeper = baseSweeper;
            _lock = new object();
            _results = new List<IRunResult>();
        }
 
        public SimpleAsyncSweeper(IHostEnvironment env, UniformRandomSweeper.OptionsBase options)
            : this(new UniformRandomSweeper(env, options))
        {
        }
 
        public SimpleAsyncSweeper(IHostEnvironment env, RandomGridSweeper.Options options)
            : this(new UniformRandomSweeper(env, options))
        {
        }
 
        public void Update(int id, IRunResult result)
        {
            Contracts.CheckParam(0 <= id && id < _numGenerated, nameof(id), "Invalid run id");
            if (!_canceled && result != null)
            {
                lock (_lock)
                    _results.Add(result);
            }
        }
 
        public Task<ParameterSetWithId> ProposeAsync()
        {
            if (_canceled)
                return Task.FromResult<ParameterSetWithId>(null);
            lock (_lock)
            {
                if (_disposed)
                    throw Contracts.Except("Calling Propose after the sweeper is disposed");
                var paramSets = _baseSweeper.ProposeSweeps(1, _results);
                if (Utils.Size(paramSets) > 0)
                    return Task.FromResult(new ParameterSetWithId(_numGenerated++, paramSets[0]));
            }
            return Task.FromResult<ParameterSetWithId>(null);
        }
 
        public void Cancel()
        {
            _canceled = true;
        }
 
        public void Dispose()
        {
            lock (_lock)
            {
                if (!_disposed)
                {
                    (_baseSweeper as IDisposable)?.Dispose();
                    _disposed = true;
                    Cancel();
                }
            }
        }
    }
 
    /// <summary>
    /// An wrapper around <see cref="ISweeper"/> which enforces determinism by imposing synchronization over past runs.
    /// Suppose n <see cref="ParameterSet"/>s are generated up to this point. The sweeper will refrain from making a decision
    /// until the runs with indices in [0, n - relaxation) have all finished. A new batch of <see cref="ParameterSet"/>s will be
    /// generated based on the first n - relaxation runs.
    /// </summary>
    public sealed class DeterministicSweeperAsync : IAsyncSweeper, IDisposable
    {
        public sealed class Options
        {
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Base sweeper", ShortName = "sweeper", SignatureType = typeof(SignatureSweeper))]
            public IComponentFactory<ISweeper> Sweeper;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Sweep batch size", ShortName = "batchsize")]
            public int BatchSize = 5;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Synchronization relaxation", ShortName = "relaxation")]
            public int Relaxation = 0;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Random seed", ShortName = "seed")]
            public int? RandomSeed;
        }
 
        private readonly object _lock;
        private readonly CancellationTokenSource _cts;
 
        private readonly Channel<ParameterSetWithId> _paramChannel;
        private readonly int _relaxation;
        private readonly ISweeper _baseSweeper;
        private readonly IHost _host;
        private readonly int _batchSize;
        private bool _disposed;
 
        // Minimum id of unfinished runs. All runs with id < _minUnfinishedId are finished.
        private int _minUnfinishedId;
 
        // The ith element of _results corresponds to the result of the ith run.
        private readonly List<IRunResult> _results;
 
        // The indices of the runs with null IRunResult. We have to keep track of both the indices and
        // the results of finished runs to determine if the synchronization barrier is satisfied.
        // Using _results alone won't do it as the result could be null.
        // Note that we only need to record those >= _minUnfinishedId.
        private readonly HashSet<int> _nullRuns;
 
        // The number of ParameterSets generated so far. Used for indexing.
        private int _numGenerated;
 
        public DeterministicSweeperAsync(IHostEnvironment env, Options options)
        {
            _host = env.Register("DeterministicSweeperAsync", options.RandomSeed);
            _host.CheckValue(options.Sweeper, nameof(options.Sweeper), "Please specify a sweeper");
            _host.CheckUserArg(options.BatchSize > 0, nameof(options.BatchSize), "Batch size must be positive");
            _host.CheckUserArg(options.Relaxation >= 0, nameof(options.Relaxation), "Synchronization relaxation must be non-negative");
            _host.CheckUserArg(options.Relaxation <= options.BatchSize, nameof(options.Relaxation),
                "Synchronization relaxation cannot be larger than batch size");
            _batchSize = options.BatchSize;
            _baseSweeper = options.Sweeper.CreateComponent(_host);
            _host.CheckUserArg(!(_baseSweeper is NelderMeadSweeper) || options.Relaxation == 0, nameof(options.Relaxation),
                "Nelder-Mead requires full synchronization (relaxation = 0)");
 
            _cts = new CancellationTokenSource();
            _relaxation = options.Relaxation;
            _lock = new object();
            _results = new List<IRunResult>();
            _nullRuns = new HashSet<int>();
            _paramChannel = Channel.CreateUnbounded<ParameterSetWithId>(
                new UnboundedChannelOptions { SingleWriter = true });
 
            PrepareNextBatch(null);
        }
 
        private void PrepareNextBatch(IEnumerable<IRunResult> results)
        {
            _host.Check(!_disposed, "Creating parameters while sweeper is disposed");
            var paramSets = _baseSweeper.ProposeSweeps(_batchSize, results);
            if (Utils.Size(paramSets) == 0)
            {
                // Mark the queue as completed.
                _paramChannel.Writer.Complete();
                return;
            }
            // Assign an id to each ParameterSet and enque it.
            foreach (var paramSet in paramSets)
                _paramChannel.Writer.TryWrite(new ParameterSetWithId(_numGenerated++, paramSet));
            EnsureResultsSize();
        }
 
        private void EnsureResultsSize()
        {
            // Allocate the result slots for the new batch.
            while (_results.Count < _numGenerated)
                _results.Add(null);
        }
 
        public void Update(int id, IRunResult result)
        {
            if (_cts.IsCancellationRequested)
                return;
            _host.Check(0 <= id && id < _results.Count, "Invalid index");
            lock (_lock)
            {
                UpdateResult(id, result);
                UpdateBarrierStatus(id);
                if (CheckBarrier())
                    PrepareNextBatch(_results.GetRange(0, Math.Max(0, _numGenerated - _relaxation)));
            }
        }
 
        private void UpdateResult(int id, IRunResult result)
        {
            if (result == null)
                _nullRuns.Add(id);
            else
                _results[id] = result;
        }
 
        private bool CheckBarrier()
        {
            return _minUnfinishedId >= _numGenerated - _relaxation;
        }
 
        private void UpdateBarrierStatus(int id)
        {
            if (id == _minUnfinishedId)
            {
                while (++_minUnfinishedId < _numGenerated && _results[_minUnfinishedId] != null || _nullRuns.Remove(_minUnfinishedId))
                    ;
            }
        }
 
        public async Task<ParameterSetWithId> ProposeAsync()
        {
            if (_cts.IsCancellationRequested)
                return null;
            try
            {
                return await _paramChannel.Reader.ReadAsync(_cts.Token);
            }
            catch (InvalidOperationException)
            {
                // Do nothing. When the queue is empty and completed, InvalidOperationException will be thrown.
            }
            catch (OperationCanceledException)
            {
                // Nothing to do for canceled tasks.
            }
            return null;
        }
 
        public void Cancel()
        {
            _cts.Cancel();
        }
 
        public void Dispose()
        {
            lock (_lock)
            {
                if (!_disposed)
                {
                    (_baseSweeper as IDisposable)?.Dispose();
                    _disposed = true;
                    Cancel();
                }
            }
        }
    }
}