File: Environment\ConsoleEnvironment.cs
Web Access
Project: src\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj (Microsoft.ML.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using System.Linq;
using System.Threading;
 
namespace Microsoft.ML.Runtime;
 
using Stopwatch = System.Diagnostics.Stopwatch;
 
/// <summary>
/// The console environment. As its name suggests, should be limited to those applications that deliberately want
/// console functionality.
/// </summary>
[BestFriend]
internal sealed class ConsoleEnvironment : HostEnvironmentBase<ConsoleEnvironment>
{
    public const string ComponentHistoryKey = "ComponentHistory";
 
    /// <summary>
    /// Writes channel messages, channel start/stop notifications and progress information to the
    /// configured writers. Every public method performs its output while holding a private lock.
    /// </summary>
    private sealed class ConsoleWriter
    {
        private readonly object _lock;
        private readonly ConsoleEnvironment _parent;
        private readonly TextWriter _out;
        private readonly TextWriter _err;
        // Optional additional writer that receives a copy of channel messages. May be null.
        private readonly TextWriter _test;

        // Whether the respective writer is the process console stream. Only then is it meaningful
        // to change Console.ForegroundColor while writing to it.
        private readonly bool _colorOut;
        private readonly bool _colorErr;

        // Progress reporting. Print up to 50 dots, if there's no meaningful (checkpoint) events.
        // At the end of 50 dots, print current metrics.
        private const int _maxDots = 50;
        private int _dots;

        /// <summary>
        /// Creates a writer bound to the given environment and output streams.
        /// </summary>
        /// <param name="parent">The owning environment; its sensitivity flags decide the line prefix.</param>
        /// <param name="outWriter">Writer for trace and info messages, channel notifications and progress.</param>
        /// <param name="errWriter">Writer for warnings and errors.</param>
        /// <param name="testWriter">Optional writer that additionally receives channel messages.</param>
        public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter, TextWriter testWriter = null)
        {
            Contracts.AssertValue(parent);
            Contracts.AssertValue(outWriter);
            Contracts.AssertValue(errWriter);
            _lock = new object();
            _parent = parent;
            _out = outWriter;
            _err = errWriter;
            _test = testWriter;

            // Each flag must be derived from its own writer: the error flag from the error writer.
            _colorOut = outWriter == Console.Out;
            _colorErr = errWriter == Console.Error;
        }

        /// <summary>
        /// Prints a single channel message. Trace messages are dropped unless the sender is verbose.
        /// Warnings and errors go to the error writer, everything else to the output writer.
        /// If a test writer is present, it receives the header, the "Warning: " tag and the message too.
        /// </summary>
        public void PrintMessage(IMessageSource sender, ChannelMessage msg)
        {
            bool isError = false;

            var messageColor = default(ConsoleColor);

            switch (msg.Kind)
            {
                case ChannelMessageKind.Trace:
                    if (!sender.Verbose)
                        return;
                    messageColor = ConsoleColor.Gray;
                    break;
                case ChannelMessageKind.Info:
                    break;
                case ChannelMessageKind.Warning:
                    messageColor = ConsoleColor.Yellow;
                    isError = true;
                    break;
                default:
                    Contracts.Assert(msg.Kind == ChannelMessageKind.Error);
                    messageColor = ConsoleColor.Red;
                    isError = true;
                    break;
            }

            lock (_lock)
            {
                EnsureNewLine(isError);
                var wr = isError ? _err : _out;
                // Color only when the writer that actually receives this message is the console.
                bool toColor = isError ? _colorErr : _colorOut;

                try
                {
                    // Info messages keep the console's current color.
                    if (toColor && msg.Kind != ChannelMessageKind.Info)
                        Console.ForegroundColor = messageColor;
                    string prefix = WriteAndReturnLinePrefix(msg.Sensitivity, wr);
                    var commChannel = sender as PipeBase<ChannelMessage>;
                    if (commChannel?.Verbose == true)
                    {
                        WriteHeader(wr, commChannel);
                        if (_test != null)
                            WriteHeader(_test, commChannel);
                    }
                    if (msg.Kind == ChannelMessageKind.Warning)
                    {
                        wr.Write("Warning: ");
                        _test?.Write("Warning: ");
                    }
                    _parent.PrintMessageNormalized(wr, msg.Message, true, prefix);
                    if (_test != null)
                        _parent.PrintMessageNormalized(_test, msg.Message, true);
                }
                finally
                {
                    // Restore the console color even if one of the writers throws.
                    if (toColor)
                        Console.ResetColor();
                }
            }
        }

        /// <summary>
        /// Returns <c>null</c> when the environment's sensitivity flags are <see cref="MessageSensitivity.All"/>
        /// or share at least one bit with <paramref name="sensitivity"/>; otherwise returns "SystemLog:".
        /// </summary>
        private string LinePrefix(MessageSensitivity sensitivity)
        {
            if (_parent._sensitivityFlags == MessageSensitivity.All || ((_parent._sensitivityFlags & sensitivity) != MessageSensitivity.None))
                return null;
            return "SystemLog:";
        }

        /// <summary>
        /// Writes the line prefix for <paramref name="sensitivity"/> to <paramref name="writer"/>, if there
        /// is one, and returns it (possibly <c>null</c>).
        /// </summary>
        private string WriteAndReturnLinePrefix(MessageSensitivity sensitivity, TextWriter writer)
        {
            string prefix = LinePrefix(sensitivity);
            if (prefix != null)
                writer.Write(prefix);
            return prefix;
        }

        /// <summary>
        /// Writes two spaces of indentation per level of channel depth, followed by the chain of names
        /// produced by <see cref="WriteName"/>. Expected to be called for verbose channels only.
        /// </summary>
        private void WriteHeader(TextWriter wr, PipeBase<ChannelMessage> commChannel)
        {
            Contracts.Assert(commChannel.Verbose);
            // REVIEW: Change this to use IndentingTextWriter.
            wr.Write(new string(' ', commChannel.Depth * 2));
            WriteName(wr, commChannel);
        }

        /// <summary>
        /// Writes "<c>ShortName: </c>" for the provider. When the provider is a <see cref="Channel"/>,
        /// the name of its parent is written first (recursively).
        /// </summary>
        private void WriteName(TextWriter wr, ChannelProviderBase provider)
        {
            var channel = provider as Channel;
            if (channel != null)
                WriteName(wr, channel.Parent);
            wr.Write("{0}: ", provider.ShortName);
        }

        /// <summary>
        /// Prints a "Started." line for a verbose channel. Does nothing for non-verbose channels.
        /// </summary>
        public void ChannelStarted(Channel channel)
        {
            if (!channel.Verbose)
                return;

            lock (_lock)
            {
                EnsureNewLine();
                WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
                WriteHeader(_out, channel);
                _out.WriteLine("Started.");
            }
        }

        /// <summary>
        /// Prints a "Finished." line and an elapsed-time line for a verbose channel.
        /// Does nothing for non-verbose channels.
        /// </summary>
        public void ChannelDisposed(Channel channel)
        {
            if (!channel.Verbose)
                return;

            lock (_lock)
            {
                EnsureNewLine();
                WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
                WriteHeader(_out, channel);
                _out.WriteLine("Finished.");
                EnsureNewLine();
                WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
                WriteHeader(_out, channel);
                _out.WriteLine("Elapsed {0:c}.", channel.Watch.Elapsed);
            }
        }

        /// <summary>
        /// Query all progress and:
        /// * If there's any checkpoint/start/stop event, print all of them.
        /// * If there's none, print a dot.
        /// * If there's <see cref="_maxDots"/> dots, print the current status for all running calculations.
        /// </summary>
        public void GetAndPrintAllProgress(ProgressReporting.ProgressTracker progressTracker)
        {
            Contracts.AssertValue(progressTracker);

            var entries = progressTracker.GetAllProgress();
            if (entries.Count == 0)
            {
                // There's no calculation running. Don't even print a dot.
                return;
            }

            // Lazily evaluated; enumerated exactly once below, under the lock.
            var checkpoints = entries.Where(
                x => x.Kind != ProgressReporting.ProgressEvent.EventKind.Progress || x.ProgressEntry.IsCheckpoint);

            lock (_lock)
            {
                bool anyCheckpoint = false;
                foreach (var ev in checkpoints)
                {
                    anyCheckpoint = true;
                    EnsureNewLine();
                    // We assume that things like status counters, which contain only things
                    // like loss function values, counts of rows, counts of items, etc., are
                    // not sensitive.
                    WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
                    switch (ev.Kind)
                    {
                        case ProgressReporting.ProgressEvent.EventKind.Start:
                            PrintOperationStart(_out, ev);
                            break;
                        case ProgressReporting.ProgressEvent.EventKind.Stop:
                            PrintOperationStop(_out, ev);
                            break;
                        case ProgressReporting.ProgressEvent.EventKind.Progress:
                            _out.Write("[{0}] ", ev.Index);
                            PrintProgressLine(_out, ev);
                            break;
                    }
                }
                if (anyCheckpoint)
                {
                    // At least one checkpoint has been printed, so there's no need for dots.
                    return;
                }

                if (PrintDot())
                {
                    // We need to print an extended status line. At this point, every event should be
                    // a non-checkpoint progress event.
                    bool needPrepend = entries.Count > 1;
                    foreach (var ev in entries)
                    {
                        Contracts.Assert(ev.Kind == ProgressReporting.ProgressEvent.EventKind.Progress);
                        Contracts.Assert(!ev.ProgressEntry.IsCheckpoint);
                        if (needPrepend)
                        {
                            EnsureNewLine();
                            WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
                            _out.Write("[{0}] ", ev.Index);
                        }
                        else
                        {
                            // This is the only case we are printing something at the end of the line of dots.
                            // So, we need to reset the dots counter.
                            _dots = 0;
                        }
                        PrintProgressLine(_out, ev);
                    }
                }
            }
        }

        /// <summary>Writes the "started" line for an operation.</summary>
        private static void PrintOperationStart(TextWriter writer, ProgressReporting.ProgressEvent ev)
        {
            writer.WriteLine("[{0}] '{1}' started.", ev.Index, ev.Name);
        }

        /// <summary>Writes the "finished" line for an operation, including its duration.</summary>
        private static void PrintOperationStop(TextWriter writer, ProgressReporting.ProgressEvent ev)
        {
            writer.WriteLine("[{0}] '{1}' finished in {2}.", ev.Index, ev.Name, ev.EventTime - ev.StartTime);
        }

        /// <summary>
        /// Writes one status line: the elapsed time in parentheses, then every non-null progress unit
        /// (with its limit, when available), then every non-null metric, and finally a line terminator.
        /// </summary>
        private void PrintProgressLine(TextWriter writer, ProgressReporting.ProgressEvent ev)
        {
            // Elapsed time. The format gets coarser as the duration grows.
            var elapsed = ev.EventTime - ev.StartTime;
            if (elapsed.TotalMinutes < 1)
                writer.Write("(00:{0:00.00})", elapsed.TotalSeconds);
            else if (elapsed.TotalHours < 1)
                writer.Write("({0:00}:{1:00.0})", elapsed.Minutes, elapsed.TotalSeconds - 60 * elapsed.Minutes);
            else
                writer.Write("({0:00}:{1:00}:{2:00})", elapsed.Hours, elapsed.Minutes, elapsed.Seconds);

            // Progress units. The first printed unit is preceded by a tab, the others by a comma.
            bool first = true;
            for (int i = 0; i < ev.ProgressEntry.Header.UnitNames.Count; i++)
            {
                if (ev.ProgressEntry.Progress[i] == null)
                    continue;
                writer.Write(first ? "\t" : ", ");
                first = false;
                writer.Write("{0}", ev.ProgressEntry.Progress[i]);
                if (ev.ProgressEntry.ProgressLim[i] != null)
                    writer.Write("/{0}", ev.ProgressEntry.ProgressLim[i].Value);
                writer.Write(" {0}", ev.ProgressEntry.Header.UnitNames[i]);
            }

            // Metrics.
            for (int i = 0; i < ev.ProgressEntry.Header.MetricNames.Count; i++)
            {
                if (ev.ProgressEntry.Metrics[i] == null)
                    continue;
                // REVIEW: print metrics prettier.
                writer.Write("\t{0}: {1}", ev.ProgressEntry.Header.MetricNames[i], ev.ProgressEntry.Metrics[i].Value);
            }

            writer.WriteLine();
        }

        /// <summary>
        /// If we printed any dots so far, finish the line. This call is expected to be protected by _lock.
        /// </summary>
        private void EnsureNewLine(bool isError = false)
        {
            if (_dots == 0)
                return;

            // If _err and _out are the same writer, we need to print a new line as well.
            // If _out and _err write to Console.Out and Console.Error respectively, in the
            // general user scenario they end up writing to the same underlying stream,
            // so write a new line to the stream anyway.
            if (isError && _err != _out && (_out != Console.Out || _err != Console.Error))
                return;

            _out.WriteLine();
            _dots = 0;
        }

        /// <summary>
        /// Print a progress dot. Returns whether it is 'time' to print more info. This call is expected
        /// to be protected by _lock.
        /// </summary>
        private bool PrintDot()
        {
            _out.Write(".");
            _dots++;
            return (_dots == _maxDots);
        }
    }
 
    /// <summary>
    /// Channel that notifies the root environment's console writer when it is created and when it
    /// is disposed, and that measures the time between the two.
    /// </summary>
    private sealed class Channel : ChannelBase
    {
        /// <summary>Started in the constructor, stopped when the channel is disposed.</summary>
        public readonly Stopwatch Watch;

        public Channel(ConsoleEnvironment root, ChannelProviderBase parent, string shortName,
            Action<IMessageSource, ChannelMessage> dispatch)
            : base(root, parent, shortName, dispatch)
        {
            var watch = new Stopwatch();
            watch.Start();
            Watch = watch;

            Root._consoleWriter.ChannelStarted(this);
        }

        /// <summary>
        /// On an explicit dispose, stops the watch and reports the channel as finished before
        /// delegating to the base implementation.
        /// </summary>
        protected override void Dispose(bool disposing)
        {
            if (!disposing)
            {
                base.Dispose(disposing);
                return;
            }

            Watch.Stop();
            Root._consoleWriter.ChannelDisposed(this);
            base.Dispose(disposing);
        }
    }
 
    // The writer that currently receives all console output of this environment. OutputRedirector
    // swaps it for a redirecting writer and puts the previous one back when it is disposed.
    private volatile ConsoleWriter _consoleWriter;
    // Sensitivity flags passed to the constructor. ConsoleWriter.LinePrefix compares them with a
    // message's sensitivity to decide whether the line gets the "SystemLog:" prefix.
    private readonly MessageSensitivity _sensitivityFlags;

    // This object is used to write to the test log along with the console if the host process is a test environment
    private readonly TextWriter _testWriter;
 
    /// <summary>
    /// Create an ML.NET <see cref="IHostEnvironment"/> for local execution, with console feedback.
    /// </summary>
    /// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
    /// <param name="verbose">Set to <c>true</c> for fully verbose logging.</param>
    /// <param name="sensitivity">Allowed message sensitivity.</param>
    /// <param name="outWriter">Text writer to print normal messages to. Falls back to <see cref="Console.Out"/> when <c>null</c>.</param>
    /// <param name="errWriter">Text writer to print error messages to. Falls back to <see cref="Console.Error"/> when <c>null</c>.</param>
    /// <param name="testWriter">Optional TextWriter to write messages if the host is a test environment.</param>
    public ConsoleEnvironment(int? seed = null, bool verbose = false,
        MessageSensitivity sensitivity = MessageSensitivity.All,
        TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
        : base(seed, verbose, nameof(ConsoleEnvironment))
    {
        Contracts.CheckValueOrNull(outWriter);
        Contracts.CheckValueOrNull(errWriter);

        _sensitivityFlags = sensitivity;
        _testWriter = testWriter;

        // Unspecified writers default to the process console streams.
        TextWriter effectiveOut = outWriter ?? Console.Out;
        TextWriter effectiveErr = errWriter ?? Console.Error;
        _consoleWriter = new ConsoleWriter(this, effectiveOut, effectiveErr, testWriter);

        // Route every channel message of this environment to the console writer.
        AddListener<ChannelMessage>(PrintMessage);
    }
 
    /// <summary>
    /// Polls the running calculations for their progress and writes the resulting messages to the console.
    /// When there is nothing to report, a dot is printed instead; once a fixed number of dots has
    /// accumulated, an ad-hoc status of all running calculations is printed.
    /// </summary>
    public void PrintProgress()
        => Root._consoleWriter.GetAndPrintAllProgress(ProgressTracker);
 
    /// <summary>
    /// Listener for channel messages: hands the message to the root environment's current console writer.
    /// </summary>
    private void PrintMessage(IMessageSource src, ChannelMessage msg)
        => Root._consoleWriter.PrintMessage(src, msg);
 
    /// <summary>
    /// Creates a new <see cref="Host"/> for the given source, which is expected to be either this
    /// environment or one of its hosts.
    /// </summary>
    protected override IHost RegisterCore(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
    {
        Contracts.AssertValue(rand);
        Contracts.AssertValueOrNull(parentFullName);
        Contracts.AssertNonEmpty(shortName);
        Contracts.Assert(source == this || source is Host);

        var host = new Host(source, shortName, parentFullName, rand, verbose);
        return host;
    }
 
    /// <summary>
    /// Creates a <see cref="Channel"/> rooted at this environment that dispatches
    /// <see cref="ChannelMessage"/> instances through this environment's dispatch delegate.
    /// </summary>
    protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
    {
        Contracts.AssertValue(parent);
        Contracts.Assert(parent is ConsoleEnvironment);
        Contracts.AssertNonEmpty(name);

        var dispatch = GetDispatchDelegate<ChannelMessage>();
        return new Channel(this, parent, name, dispatch);
    }
 
    /// <summary>
    /// Creates a pipe for messages of type <typeparamref name="TMessage"/> that dispatches through
    /// this environment's dispatch delegate.
    /// </summary>
    protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase parent, string name)
    {
        Contracts.AssertValue(parent);
        Contracts.Assert(parent is ConsoleEnvironment);
        Contracts.AssertNonEmpty(name);

        var dispatch = GetDispatchDelegate<TMessage>();
        return new Pipe<TMessage>(parent, name, dispatch);
    }
 
    /// <summary>
    /// Sends the channel output to the given writers until the returned object is disposed.
    /// </summary>
    /// <param name="newOutWriter">Writer that receives normal output. Must not be <c>null</c>.</param>
    /// <param name="newErrWriter">Writer that receives error output. Must not be <c>null</c>.</param>
    /// <returns>A handle whose disposal puts the previous console writer back in place.</returns>
    /// <remarks>This method is not thread-safe.</remarks>
    internal IDisposable RedirectChannelOutput(TextWriter newOutWriter, TextWriter newErrWriter)
    {
        Contracts.CheckValue(newOutWriter, nameof(newOutWriter));
        Contracts.CheckValue(newErrWriter, nameof(newErrWriter));

        var redirector = new OutputRedirector(this, newOutWriter, newErrWriter);
        return redirector;
    }
 
    /// <summary>
    /// Resets the environment's progress tracker.
    /// </summary>
    internal void ResetProgressChannel() => ProgressTracker.Reset();
 
    /// <summary>
    /// Installs a new <see cref="ConsoleWriter"/> on the root environment for as long as this object
    /// lives; disposing it puts the previously installed writer back.
    /// </summary>
    private sealed class OutputRedirector : IDisposable
    {
        private readonly ConsoleEnvironment _root;
        // The writer that was active before the redirection; null once it has been restored.
        private ConsoleWriter _oldConsoleWriter;
        private readonly ConsoleWriter _newConsoleWriter;

        public OutputRedirector(ConsoleEnvironment env, TextWriter newOutWriter, TextWriter newErrWriter)
        {
            Contracts.AssertValue(env);
            Contracts.AssertValue(newOutWriter);
            Contracts.AssertValue(newErrWriter);

            var root = env.Root;
            var replacement = new ConsoleWriter(root, newOutWriter, newErrWriter, root._testWriter);
            _root = root;
            _newConsoleWriter = replacement;

            // Swap the writers atomically and remember the one that was replaced.
            _oldConsoleWriter = Interlocked.Exchange(ref root._consoleWriter, replacement);
            Contracts.AssertValue(_oldConsoleWriter);
        }

        /// <summary>
        /// Restores the previous console writer. Calls after the first one do nothing.
        /// </summary>
        public void Dispose()
        {
            var previous = _oldConsoleWriter;
            if (previous != null)
            {
                Contracts.Assert(_root._consoleWriter == _newConsoleWriter);
                _root._consoleWriter = previous;
                _oldConsoleWriter = null;
            }
        }
    }
 
    /// <summary>
    /// Host registered with a <see cref="ConsoleEnvironment"/>. Its channels are rooted at the
    /// environment, so their output goes through the environment's console writer.
    /// </summary>
    private sealed class Host : HostBase
    {
        public Host(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
            : base(source, shortName, parentFullName, rand, verbose)
        {
            // Start out with the cancellation state of the source.
            IsCanceled = source.IsCanceled;
        }

        /// <summary>Creates a <see cref="Channel"/> whose root is the environment and whose parent is a host.</summary>
        protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
        {
            Contracts.AssertValue(parent);
            Contracts.Assert(parent is Host);
            Contracts.AssertNonEmpty(name);

            var dispatch = GetDispatchDelegate<ChannelMessage>();
            return new Channel(Root, parent, name, dispatch);
        }

        /// <summary>Creates a pipe for <typeparamref name="TMessage"/> whose parent is a host.</summary>
        protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase parent, string name)
        {
            Contracts.AssertValue(parent);
            Contracts.Assert(parent is Host);
            Contracts.AssertNonEmpty(name);

            var dispatch = GetDispatchDelegate<TMessage>();
            return new Pipe<TMessage>(parent, name, dispatch);
        }

        /// <summary>Creates a nested <see cref="Host"/> for the given source.</summary>
        protected override IHost RegisterCore(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
            => new Host(source, shortName, parentFullName, rand, verbose);
    }
}