File: Data\ProgressReporter.cs
Web Access
Project: src\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj (Microsoft.ML.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Microsoft.ML.Internal.Utilities;
 
namespace Microsoft.ML.Runtime;
 
/// <summary>
/// The progress reporting classes used by <see cref="HostEnvironmentBase{THostEnvironmentBase}"/> descendants.
/// </summary>
[BestFriend]
internal static class ProgressReporting
{
    /// <summary>
    /// The progress channel for <see cref="ConsoleEnvironment"/>.
    /// This is coupled with a <see cref="ProgressTracker"/> that aggregates all events and returns them on demand.
    /// </summary>
    public sealed class ProgressChannel : IProgressChannel
    {
        private readonly IExceptionContext _ectx;
 
        /// <summary>
        /// The pair of (header, fill action) is updated atomically.
        /// </summary>
        private Tuple<ProgressHeader, Action<IProgressEntry>> _headerAndAction;
 
        /// <summary>
        /// Normally this would be a readonly field, but we null it in Dispose to prevent a memory leak.
        /// </summary>
        private ProgressTracker _tracker;
 
        private readonly ConcurrentDictionary<int, SubChannel> _subChannels;
        private volatile int _maxSubId;
        private bool _isDisposed;
 
        public string Name { get; }
 
        /// <summary>
        /// Creates a channel for the computation called <paramref name="computationName"/> and immediately
        /// announces it to <paramref name="tracker"/> by logging a start event.
        /// </summary>
        /// <param name="ectx">The exception context used for validation; a null value passes the check performed here.</param>
        /// <param name="tracker">The tracker that receives this channel's events. Must not be null.</param>
        /// <param name="computationName">The computation name. Must be non-empty.</param>
        public ProgressChannel(IExceptionContext ectx, ProgressTracker tracker, string computationName)
        {
            Contracts.CheckValueOrNull(ectx);
            _ectx = ectx;
            _ectx.CheckValue(tracker, nameof(tracker));
            _ectx.CheckNonEmpty(computationName, nameof(computationName));

            // Begin with a header that carries no names and with no fill delegate.
            var initialHeader = new ProgressHeader(null);
            _headerAndAction = new Tuple<ProgressHeader, Action<IProgressEntry>>(initialHeader, null);

            _maxSubId = 0;
            _subChannels = new ConcurrentDictionary<int, SubChannel>();
            _tracker = tracker;
            Name = computationName;

            // Register with the tracker last: logging the start event hands 'this' to the tracker,
            // which reads Name.
            Start();
        }
 
        /// <summary>
        /// Replaces the current header and fill delegate. Both are stored in one tuple and published with a
        /// single reference assignment, so readers never see a header paired with the wrong delegate.
        /// </summary>
        /// <param name="header">The new header.</param>
        /// <param name="fillAction">The delegate that fills an entry when progress is pulled.</param>
        public void SetHeader(ProgressHeader header, Action<IProgressEntry> fillAction)
        {
            var pair = new Tuple<ProgressHeader, Action<IProgressEntry>>(header, fillAction);
            _headerAndAction = pair;
        }
 
        /// <summary>
        /// Pushes a checkpoint entry to the tracker. The supplied <paramref name="values"/> are consumed in
        /// order: first the metrics, then the progress values, then the progress limits, each group sized by
        /// the current header. Slots that receive no value stay null. A NaN limit is stored as null.
        /// </summary>
        /// <param name="values">The checkpoint values; may be null or shorter than the total number of slots.</param>
        public void Checkpoint(params double?[] values)
        {
            _ectx.AssertValueOrNull(values);
            _ectx.Check(!_isDisposed, "Can't report checkpoints after disposing");
            var checkpoint = new ProgressEntry(true, _headerAndAction.Item1);

            int supplied = Utils.Size(values);
            int next = 0;

            // Metrics come first.
            int slot = 0;
            while (slot < checkpoint.Metrics.Length && next < supplied)
            {
                checkpoint.Metrics[slot] = values[next];
                slot++;
                next++;
            }

            // Then the progress values.
            slot = 0;
            while (slot < checkpoint.Progress.Length && next < supplied)
            {
                checkpoint.Progress[slot] = values[next];
                slot++;
                next++;
            }

            // Finally the limits; NaN means 'no limit known' and is recorded as null.
            slot = 0;
            while (slot < checkpoint.ProgressLim.Length && next < supplied)
            {
                double? limit = values[next];
                next++;
                bool isNaN = limit.HasValue && double.IsNaN(limit.Value);
                checkpoint.ProgressLim[slot] = isNaN ? null : limit;
                slot++;
            }

            // Anything left over means the caller passed more values than the header has slots for.
            _ectx.Check(next == supplied, "Too many values provided in Checkpoint");
            _tracker.Log(this, ProgressEvent.EventKind.Progress, checkpoint);
        }
 
        /// <summary>
        /// Logs this channel's start event with the coupled tracker.
        /// </summary>
        private void Start() => _tracker.Log(this, ProgressEvent.EventKind.Start, null);
 
        /// <summary>
        /// Logs this channel's stop event with the coupled tracker.
        /// </summary>
        private void Stop() => _tracker.Log(this, ProgressEvent.EventKind.Stop, null);
 
        /// <summary>
        /// Marks the channel as disposed, logs the stop event and drops the references to the tracker and to
        /// the header/fill-action pair. Calling it again after that is a no-op.
        /// </summary>
        public void Dispose()
        {
            if (!_isDisposed)
            {
                _isDisposed = true;

                // The stop event has to be logged while the tracker reference is still available.
                Stop();

                Contracts.Assert(_subChannels.Count == 0);

                // The fill delegate may capture arbitrary objects through its closure. If the tracker (or
                // anything else) keeps this channel alive longer than the operation ran, holding on to the
                // delegate would leak those objects, so both references are released here.
                _headerAndAction = null;
                _tracker = null;
            }
        }
 
        /// <summary>
        /// Pull the current progress by invoking the fill delegate, if any, and merge in the values of all
        /// registered subordinate channels.
        /// </summary>
        /// <returns>
        /// A non-checkpoint entry. If the channel has already been disposed, an entry with an empty header
        /// is returned instead of failing.
        /// </returns>
        public ProgressEntry GetProgress()
        {
            // Read the pair once, so the header and the action always come from the same pair, even if outdated.
            var cache = _headerAndAction;
            if (cache == null)
            {
                // Dispose() clears the pair. The tracker only checks the 'finished' flag before polling,
                // so a poll can still arrive while or after the channel is being disposed on another thread.
                // Report an empty entry rather than throwing a NullReferenceException into the poller.
                return new ProgressEntry(false, new ProgressHeader(null));
            }

            var fillAction = cache.Item2;
            var entry = new ProgressEntry(false, cache.Item1);

            if (fillAction == null)
                Contracts.Assert(entry.Header.MetricNames.Count == 0 && entry.Header.UnitNames.Count == 0);
            else
                fillAction(entry);

            return BuildJointEntry(entry);
        }
 
        /// <summary>
        /// Starts a subordinate channel at level 1 underneath this channel.
        /// </summary>
        /// <param name="name">The name of the subordinate computation. It is not used by this implementation.</param>
        /// <returns>The new subordinate channel.</returns>
        public IProgressChannel StartProgressChannel(string name) => StartProgressChannel(1);
 
        /// <summary>
        /// Creates a subordinate channel at the given <paramref name="level"/> and assigns it a fresh,
        /// unique id taken from an atomically incremented counter.
        /// </summary>
        /// <param name="level">The depth of the new subordinate; used to order subordinates when reporting.</param>
        /// <returns>The new subordinate channel, which registers itself with this channel on construction.</returns>
        private IProgressChannel StartProgressChannel(int level)
        {
            var newId = Interlocked.Increment(ref _maxSubId);

            // The SubChannel constructor takes (root, id, level). The id is the key under which the channel
            // is registered in _subChannels, so it must be the unique counter value and not the level;
            // otherwise siblings at the same level collide on the same key.
            return new SubChannel(this, id: newId, level: level);
        }
 
        /// <summary>
        /// Unregisters the subordinate channel stored under <paramref name="id"/>.
        /// </summary>
        /// <param name="id">The id of the subordinate channel that stopped.</param>
        private void SubChannelStopped(int id)
        {
            // Removing an id that is already gone is fine, so the result is deliberately discarded.
            _subChannels.TryRemove(id, out _);
        }
 
        /// <summary>
        /// Registers <paramref name="channel"/> under <paramref name="id"/> so that it is included when the
        /// joint progress entry is built.
        /// </summary>
        /// <param name="id">The id to register the channel under.</param>
        /// <param name="channel">The subordinate channel that started.</param>
        private void SubChannelStarted(int id, SubChannel channel)
        {
            SubChannel registered = _subChannels.GetOrAdd(id, channel);

            // If a different channel comes back, the id was already taken by someone else.
            Contracts.Assert(ReferenceEquals(registered, channel));
        }
 
        /// <summary>
        /// Combines <paramref name="rootEntry"/> with the current progress of every registered subordinate
        /// channel. The root's names and values come first, followed by those of the subordinates ordered by
        /// their level. When there are no subordinates, <paramref name="rootEntry"/> is returned unchanged.
        /// </summary>
        /// <param name="rootEntry">The entry produced for this (root) channel.</param>
        /// <returns>Either <paramref name="rootEntry"/> itself or a new entry with a concatenated header.</returns>
        private ProgressEntry BuildJointEntry(ProgressEntry rootEntry)
        {
            bool hasSubChannels = _maxSubId != 0 && _subChannels.Count != 0;
            if (!hasSubChannels)
                return rootEntry;

            // REVIEW: consider caching the headers, in case the sub-reporters haven't changed.
            // This is not anticipated to be a perf-critical path though.
            var unitNames = new List<string>();
            var metricNames = new List<string>();
            var progressValues = new List<double?>();
            var progressLimits = new List<double?>();
            var metricValues = new List<double?>();

            // Appends the names and values of one entry to the accumulated lists.
            void Append(ProgressEntry part)
            {
                unitNames.AddRange(part.Header.UnitNames);
                metricNames.AddRange(part.Header.MetricNames);
                progressValues.AddRange(part.Progress);
                progressLimits.AddRange(part.ProgressLim);
                metricValues.AddRange(part.Metrics);
            }

            Append(rootEntry);

            // Work on a snapshot, since subordinates may start or stop while we iterate.
            SubChannel[] snapshot = _subChannels.Values.ToArray();
            foreach (var subChannel in snapshot.OrderBy(x => x.Level))
                Append(subChannel.GetProgress());

            var jointHeader = new ProgressHeader(metricNames.ToArray(), unitNames.ToArray());
            var jointEntry = new ProgressEntry(false, jointHeader);
            progressValues.CopyTo(jointEntry.Progress);
            progressLimits.CopyTo(jointEntry.ProgressLim);
            metricValues.CopyTo(jointEntry.Metrics);
            return jointEntry;
        }
 
        /// <summary>
        /// A subordinate progress channel that lives underneath a root <see cref="ProgressChannel"/>.
        ///
        /// Subordinates do not send Start/Stop events or checkpoints to the tracker. Instead, they register
        /// with (and unregister from) their root, which polls them whenever its own status is requested and
        /// appends their values to its entry, ordered by <see cref="Level"/>.
        /// The relative order of subordinates that share a level is not defined.
        /// </summary>
        private sealed class SubChannel : IProgressChannel
        {
            private readonly ProgressChannel _root;
            private readonly int _id;
            // How deep this subordinate sits below the root.
            private readonly int _level;

            /// <summary>
            /// The (header, fill action) pair. It is always replaced as a whole, with one reference assignment.
            /// </summary>
            private Tuple<ProgressHeader, Action<IProgressEntry>> _headerAndAction;

            /// <summary>
            /// Creates the subordinate and registers it with <paramref name="root"/> under <paramref name="id"/>.
            /// </summary>
            /// <param name="root">The root channel this subordinate reports through.</param>
            /// <param name="id">The key under which the subordinate is registered with the root.</param>
            /// <param name="level">The depth of the subordinate; must be non-negative.</param>
            public SubChannel(ProgressChannel root, int id, int level)
            {
                Contracts.AssertValue(root);
                Contracts.Assert(level >= 0);

                _root = root;
                _id = id;
                _level = level;

                var initialHeader = new ProgressHeader(null);
                _headerAndAction = new Tuple<ProgressHeader, Action<IProgressEntry>>(initialHeader, null);

                // Make this subordinate visible to the root.
                _root.SubChannelStarted(_id, this);
            }

            /// <summary>
            /// The depth of this subordinate below the root.
            /// </summary>
            public int Level => _level;

            /// <summary>
            /// Replaces the header and the fill delegate as one pair.
            /// </summary>
            public void SetHeader(ProgressHeader header, Action<IProgressEntry> fillAction)
            {
                _headerAndAction = new Tuple<ProgressHeader, Action<IProgressEntry>>(header, fillAction);
            }

            /// <summary>
            /// Produces the current progress by running the fill delegate, if one is set, on a fresh entry.
            /// </summary>
            public ProgressEntry GetProgress()
            {
                // Take one snapshot, so the header and the delegate are guaranteed to belong together,
                // even if the pair has been replaced in the meantime.
                var snapshot = _headerAndAction;
                var result = new ProgressEntry(false, snapshot.Item1);
                var fill = snapshot.Item2;

                if (fill != null)
                    fill(result);
                else
                    Contracts.Assert(result.Header.MetricNames.Count == 0 && result.Header.UnitNames.Count == 0);

                return result;
            }

            /// <summary>
            /// Starts another subordinate of the same root, one level deeper than this one.
            /// The <paramref name="name"/> is not used.
            /// </summary>
            public IProgressChannel StartProgressChannel(string name) => _root.StartProgressChannel(_level + 1);

            /// <summary>
            /// Subordinate checkpoints are ignored.
            /// </summary>
            public void Checkpoint(params Double?[] values)
            {
                // We are ignoring all checkpoints from subordinates.
                // REVIEW: maybe this could be changed in the future. Right now it seems that
                // this limitation is reasonable.
            }

            /// <summary>
            /// Unregisters this subordinate from its root.
            /// </summary>
            public void Dispose() => _root.SubChannelStopped(_id);
        }
    }
 
    /// <summary>
    /// This class listens to the progress reporting channels, caches all checkpoints and
    /// start/stop events and, on demand, requests current progress on all active calculations.
    ///
    /// The public methods of this class should only be called from one thread.
    /// </summary>
    public sealed class ProgressTracker
    {
        private readonly IExceptionContext _ectx;
        private readonly object _lock;
 
        /// <summary>
        /// Log of pending events.
        /// </summary>
        private readonly ConcurrentQueue<ProgressEvent> _pendingEvents;
 
        /// <summary>
        /// For each calculation, its properties.
        /// This list is protected by <see cref="_lock"/>, and it's updated every time a new calculation starts.
        /// The entries are cleaned up when the start and stop events are reported (that is, after the first
        /// pull request after the calculation's 'Stop' event).
        /// </summary>
        private readonly List<CalculationInfo> _infos;
 
        /// <summary>
        /// This is a 'process index' that gets incremented whenever a new calculation is started.
        /// </summary>
        private int _index;
 
        /// <summary>
        /// The set of used process names.
        /// </summary>
        private readonly HashSet<string> _namesUsed;
 
        /// <summary>
        /// The per-calculation record kept by the tracker.
        ///
        /// A new record is created for every 'start' event, so a computation that is started several times
        /// gets several records.
        /// </summary>
        private sealed class CalculationInfo
        {
            /// <summary>
            /// The unique, auto-assigned ID of the calculation.
            /// </summary>
            public readonly int Index;

            /// <summary>
            /// The display name, derived from the name the pipe provided.
            /// </summary>
            public readonly string Name;

            /// <summary>
            /// The UTC time at which this record was created.
            /// </summary>
            public readonly DateTime StartTime;

            /// <summary>
            /// The channel this calculation reports through.
            /// </summary>
            public readonly ProgressChannel Channel;

            /// <summary>
            /// A log of pending checkpoint entries.
            /// </summary>
            public readonly ConcurrentQueue<KeyValuePair<DateTime, ProgressEntry>> PendingCheckpoints;

            /// <summary>
            /// Whether the calculation has finished.
            /// </summary>
            public bool IsFinished;

            /// <summary>
            /// Creates the record and stamps it with the current UTC time.
            /// </summary>
            /// <param name="index">The unique ID; must be positive.</param>
            /// <param name="name">The display name; must be non-empty.</param>
            /// <param name="channel">The reporting channel; must not be null.</param>
            public CalculationInfo(int index, string name, ProgressChannel channel)
            {
                Contracts.Assert(index > 0);
                Contracts.AssertNonEmpty(name);
                Contracts.AssertValue(channel);

                Channel = channel;
                Name = name;
                Index = index;
                StartTime = DateTime.UtcNow;
                PendingCheckpoints = new ConcurrentQueue<KeyValuePair<DateTime, ProgressEntry>>();
            }
        }
 
        /// <summary>
        /// Creates a tracker with no calculations, no pending events and no used names.
        /// </summary>
        /// <param name="ectx">The exception context; must not be null.</param>
        public ProgressTracker(IExceptionContext ectx)
        {
            Contracts.CheckValue(ectx, nameof(ectx));

            _ectx = ectx;
            _namesUsed = new HashSet<string>();
            _infos = new List<CalculationInfo>();
            _pendingEvents = new ConcurrentQueue<ProgressEvent>();
            _lock = new object();
        }
 
        /// <summary>
        /// Records an event coming from <paramref name="source"/>.
        ///
        /// A start event creates a new calculation record under a unique name: the channel's name, with
        /// " #2", " #3", ... appended when the name is already in use. A stop event marks the calculation
        /// as finished. Any other event is treated as a progress checkpoint. In all cases a matching
        /// <see cref="ProgressEvent"/> is queued for the next <see cref="GetAllProgress"/> call.
        /// </summary>
        /// <param name="source">The channel that raised the event.</param>
        /// <param name="kind">The kind of event.</param>
        /// <param name="entry">The checkpoint entry for progress events; null for start and stop events.</param>
        public void Log(ProgressChannel source, ProgressEvent.EventKind kind, ProgressEntry entry)
        {
            _ectx.AssertValue(source);
            _ectx.AssertValueOrNull(entry);

            if (kind == ProgressEvent.EventKind.Start)
            {
                _ectx.Assert(entry == null);
                lock (_lock)
                {
                    // Find a name nobody else is using: the plain name first, then numbered variants.
                    string baseName = source.Name;
                    string uniqueName = baseName;
                    for (int suffix = 2; !_namesUsed.Add(uniqueName); suffix++)
                        uniqueName = string.Format("{0} #{1}", baseName, suffix);

                    var started = new CalculationInfo(++_index, uniqueName, source);
                    _infos.Add(started);
                    _pendingEvents.Enqueue(new ProgressEvent(started.Index, started.Name, started.StartTime, ProgressEvent.EventKind.Start));
                }
                return;
            }

            // Anything but a start event refers to an existing record and leaves _infos unchanged.
            CalculationInfo info;
            lock (_lock)
            {
                info = _infos.Find(x => x.Channel == source);
                if (info == null)
                    throw _ectx.Except("Event sent after the calculation lifetime expired.");
            }

            if (kind == ProgressEvent.EventKind.Stop)
            {
                _ectx.Assert(entry == null);
                info.IsFinished = true;
                _pendingEvents.Enqueue(new ProgressEvent(info.Index, info.Name, info.StartTime, ProgressEvent.EventKind.Stop));
            }
            else
            {
                _ectx.Assert(entry != null);
                _ectx.Assert(kind == ProgressEvent.EventKind.Progress);
                _ectx.Assert(!info.IsFinished);
                _pendingEvents.Enqueue(new ProgressEvent(info.Index, info.Name, info.StartTime, entry));
            }
        }
 
        /// <summary>
        /// Collects the progress reports of all current calculations.
        /// For every calculation the following events are returned, each exactly once:
        /// * The start event.
        /// * Every checkpoint.
        /// * The stop event, once the calculation has finished.
        ///
        /// A calculation that has none of these events pending is asked ('pulled') for its current progress
        /// instead, and the answer is returned as an event. Records of finished calculations are dropped
        /// from the tracker during this call.
        /// </summary>
        /// <returns>The drained pending events, followed by the pulled progress events.</returns>
        public List<ProgressEvent> GetAllProgress()
        {
            var result = new List<ProgressEvent>();
            var reported = new HashSet<int>();

            // Drain everything that was pushed since the last call and remember who reported.
            while (_pendingEvents.TryDequeue(out ProgressEvent pending))
            {
                reported.Add(pending.Index);
                result.Add(pending);
            }

            // Determine which calculations stayed silent, and forget the finished ones.
            List<CalculationInfo> silent;
            lock (_lock)
            {
                silent = _infos.FindAll(x => !reported.Contains(x.Index));
                _infos.RemoveAll(x => x.IsFinished);
            }

            foreach (var info in silent)
            {
                // A calculation can finish while this method is running. Its stop event will show up in the
                // next report; here we only make a best-effort attempt not to invoke the delegate of a
                // calculation that has already finished.
                if (!info.IsFinished)
                {
                    var pulled = info.Channel.GetProgress();
                    result.Add(new ProgressEvent(info.Index, info.Name, info.StartTime, pulled));
                }
            }

            return result;
        }
 
        /// <summary>
        /// Discards all pending events, forgets the used names and restarts the calculation index at zero.
        /// The list of calculation records is not modified by this method.
        /// </summary>
        public void Reset()
        {
            lock (_lock)
            {
                while (_pendingEvents.TryDequeue(out _))
                {
                    // Keep discarding until the queue reports that it is empty.
                }

                _index = 0;
                _namesUsed.Clear();
            }
        }
    }
 
    /// <summary>
    /// An array-backed implementation of <see cref="IProgressEntry"/>.
    /// </summary>
    public sealed class ProgressEntry : IProgressEntry
    {
        /// <summary>
        /// The header (names of metrics and units).
        /// The contents of the header should be treated as read-only. The calculation itself doesn't even
        /// need to access the header, since it will know it anyway.
        /// </summary>
        public readonly ProgressHeader Header;

        /// <summary>
        /// Whether the progress entry is a 'checkpoint' (that is, it's being pushed by the component).
        /// </summary>
        public readonly bool IsCheckpoint;

        /// <summary>
        /// The actual progress (amount of completed units), in the units that are contained in the header.
        /// Parallel to the header's <see cref="ProgressHeader.UnitNames"/>. Null value indicates 'not applicable now'.
        ///
        /// The computation should not modify these arrays directly, and instead rely on <see cref="SetMetric"/>,
        /// <see cref="SetProgress(int,double)"/> and <see cref="SetProgress(int,double,double)"/>.
        /// </summary>
        public readonly double?[] Progress;

        /// <summary>
        /// The limit value of each progress unit.
        /// Parallel to the header's <see cref="ProgressHeader.UnitNames"/>. Null value indicates unbounded or unknown.
        /// </summary>
        public readonly double?[] ProgressLim;

        /// <summary>
        /// The reported metrics. Parallel to the header's <see cref="ProgressHeader.MetricNames"/>.
        /// Null value indicates unknown.
        /// </summary>
        public readonly double?[] Metrics;

        /// <summary>
        /// Creates the progress entry corresponding to a given header. The value arrays are sized by the
        /// header's unit and metric names, and every slot starts out as null.
        /// </summary>
        /// <param name="isCheckpoint">Whether the entry is pushed by the component rather than pulled.</param>
        /// <param name="header">The header describing the entry; must not be null.</param>
        public ProgressEntry(bool isCheckpoint, ProgressHeader header)
        {
            Contracts.CheckValue(header, nameof(header));
            Header = header;
            IsCheckpoint = isCheckpoint;

            int unitCount = header.UnitNames.Count;
            Progress = new double?[unitCount];
            ProgressLim = new double?[unitCount];
            Metrics = new double?[header.MetricNames.Count];
        }

        /// <summary>
        /// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
        /// and the limit value for the progress becomes 'unknown'.
        /// </summary>
        /// <param name="index">The index of the progress unit; must be within <see cref="Progress"/>.</param>
        /// <param name="value">The amount of completed units.</param>
        public void SetProgress(int index, double value)
        {
            Contracts.Check(0 <= index && index < Progress.Length);
            Progress[index] = value;
            ProgressLim[index] = null;
        }

        /// <summary>
        /// Set the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
        /// and the limit value to <paramref name="lim"/>. A NaN limit is stored as 'unknown' (null).
        /// </summary>
        /// <param name="index">The index of the progress unit; must be within <see cref="Progress"/>.</param>
        /// <param name="value">The amount of completed units.</param>
        /// <param name="lim">The limit of the progress unit, or NaN if it is not known.</param>
        public void SetProgress(int index, double value, double lim)
        {
            // The Check is enforced in all builds, so a second debug-only assertion of the same
            // condition would add nothing.
            Contracts.Check(0 <= index && index < Progress.Length);
            Progress[index] = value;

            if (double.IsNaN(lim))
                ProgressLim[index] = null;
            else
                ProgressLim[index] = lim;
        }

        /// <summary>
        /// Sets the metric with index <paramref name="index"/> to <paramref name="value"/>.
        /// </summary>
        /// <param name="index">The index of the metric; must be within <see cref="Metrics"/>.</param>
        /// <param name="value">The metric value.</param>
        public void SetMetric(int index, double value)
        {
            Contracts.Check(0 <= index && index < Metrics.Length);
            Metrics[index] = value;
        }
    }
 
    /// <summary>
    /// An event about the progress of a calculation: its start, its stop, or a progress entry.
    /// </summary>
    public sealed class ProgressEvent
    {
        // REVIEW: Separate kind for checkpoint?
        public enum EventKind
        {
            Start,
            Progress,
            Stop
        }

        /// <summary>The ID of the calculation the event belongs to.</summary>
        public readonly int Index;

        /// <summary>The name of the calculation the event belongs to.</summary>
        public readonly string Name;

        // REVIEW: Maybe switch to the stopwatch-based wall clock?
        /// <summary>The start time of the calculation, as supplied by the creator of the event.</summary>
        public readonly DateTime StartTime;

        /// <summary>The UTC time at which this event object was created.</summary>
        public readonly DateTime EventTime;

        /// <summary>The kind of event.</summary>
        public readonly EventKind Kind;

        /// <summary>The progress entry for progress events; null for events created with an explicit kind.</summary>
        public readonly ProgressEntry ProgressEntry;

        /// <summary>
        /// Creates a progress event that carries <paramref name="entry"/>.
        /// </summary>
        /// <param name="index">The calculation ID; must be non-negative.</param>
        /// <param name="name">The calculation name; must be non-empty.</param>
        /// <param name="startTime">The start time of the calculation.</param>
        /// <param name="entry">The progress entry; must not be null.</param>
        public ProgressEvent(int index, string name, DateTime startTime, ProgressEntry entry)
            : this(index, name, startTime, EventKind.Progress)
        {
            // The index and name have already been validated by the chained constructor.
            Contracts.CheckValue(entry, nameof(entry));
            ProgressEntry = entry;
        }

        /// <summary>
        /// Creates an event of the given <paramref name="kind"/> without a progress entry.
        /// </summary>
        /// <param name="index">The calculation ID; must be non-negative.</param>
        /// <param name="name">The calculation name; must be non-empty.</param>
        /// <param name="startTime">The start time of the calculation.</param>
        /// <param name="kind">The kind of event.</param>
        public ProgressEvent(int index, string name, DateTime startTime, EventKind kind)
        {
            Contracts.CheckParam(index >= 0, nameof(index));
            Contracts.CheckNonEmpty(name, nameof(name));

            Kind = kind;
            ProgressEntry = null;
            EventTime = DateTime.UtcNow;
            StartTime = startTime;
            Name = name;
            Index = index;
        }
    }
}