File: Utilities\LocalEnvironment.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    using Stopwatch = System.Diagnostics.Stopwatch;
 
    /// <summary>
    /// An ML.NET environment for local execution.
    /// </summary>
    internal sealed class LocalEnvironment : HostEnvironmentBase<LocalEnvironment>
    {
        /// <summary>
        /// Message channel used by <see cref="LocalEnvironment"/> and its hosts.
        /// Emits trace messages when it is created and when it is disposed; the disposal
        /// trace includes the time elapsed since construction.
        /// </summary>
        private sealed class Channel : ChannelBase
        {
            /// <summary>
            /// Started when the channel is constructed and stopped when it is disposed.
            /// </summary>
            public readonly Stopwatch Watch;

            public Channel(LocalEnvironment root, ChannelProviderBase parent, string shortName,
                Action<IMessageSource, ChannelMessage> dispatch)
                : base(root, parent, shortName, dispatch)
            {
                Watch = Stopwatch.StartNew();
                Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, MessageSensitivity.None, "Channel started"));
            }

            /// <summary>
            /// Dispatches a trace message reporting how long the channel has been alive.
            /// </summary>
            private void ChannelFinished()
                => Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, MessageSensitivity.None, "Channel finished. Elapsed {0:c}.", Watch.Elapsed));

            protected override void Dispose(bool disposing)
            {
                try
                {
                    if (disposing)
                    {
                        // The elapsed time is reported before the watch is stopped.
                        ChannelFinished();
                        Watch.Stop();

                        Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, MessageSensitivity.None, "Channel disposed"));
                    }
                }
                finally
                {
                    // Dispatch can run arbitrary code (e.g. custom listeners added via AddListener)
                    // and may throw; the base class must still get its chance to clean up.
                    base.Dispose(disposing);
                }
            }
        }

        /// <summary>
        /// Create an ML.NET <see cref="IHostEnvironment"/> for local execution.
        /// </summary>
        /// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
        public LocalEnvironment(int? seed = null)
            : base(seed, verbose: false)
        {
        }

        /// <summary>
        /// Add a custom listener to the messages of ML.NET components.
        /// </summary>
        /// <param name="listener">Callback receiving the message source and the channel message.</param>
        public void AddListener(Action<IMessageSource, ChannelMessage> listener)
            => AddListener<ChannelMessage>(listener);

        /// <summary>
        /// Remove a previously added custom listener.
        /// </summary>
        /// <param name="listener">The listener instance that was passed to <see cref="AddListener(Action{IMessageSource, ChannelMessage})"/>.</param>
        public void RemoveListener(Action<IMessageSource, ChannelMessage> listener)
            => RemoveListener<ChannelMessage>(listener);

        /// <summary>
        /// Creates a child <see cref="Host"/> registered under <paramref name="source"/>.
        /// </summary>
        protected override IHost RegisterCore(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
        {
            Contracts.AssertValue(rand);
            Contracts.AssertValueOrNull(parentFullName);
            Contracts.AssertNonEmpty(shortName);
            Contracts.Assert(source == this || source is Host);
            return new Host(source, shortName, parentFullName, rand, verbose);
        }

        /// <summary>
        /// Creates a <see cref="Channel"/> whose root is this environment.
        /// </summary>
        protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
        {
            Contracts.AssertValue(parent);
            Contracts.Assert(parent is LocalEnvironment);
            Contracts.AssertNonEmpty(name);
            return new Channel(this, parent, name, GetDispatchDelegate<ChannelMessage>());
        }

        /// <summary>
        /// Creates a generic message pipe for <typeparamref name="TMessage"/> rooted at this environment.
        /// </summary>
        protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase parent, string name)
        {
            Contracts.AssertValue(parent);
            Contracts.Assert(parent is LocalEnvironment);
            Contracts.AssertNonEmpty(name);
            return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
        }

        /// <summary>
        /// Host created by <see cref="LocalEnvironment"/> (directly or via another host).
        /// Copies the cancellation state of its source at construction time.
        /// </summary>
        private sealed class Host : HostBase
        {
            public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
                : base(source, shortName, parentFullName, rand, verbose)
            {
                // A host registered under an already-canceled source starts out canceled.
                IsCanceled = source.IsCanceled;
            }

            /// <summary>
            /// Creates a <see cref="Channel"/> owned by this host; the channel's root is the environment root.
            /// </summary>
            protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
            {
                Contracts.AssertValue(parent);
                Contracts.Assert(parent is Host);
                Contracts.AssertNonEmpty(name);
                return new Channel(Root, parent, name, GetDispatchDelegate<ChannelMessage>());
            }

            /// <summary>
            /// Creates a generic message pipe for <typeparamref name="TMessage"/> owned by this host.
            /// </summary>
            protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase parent, string name)
            {
                Contracts.AssertValue(parent);
                Contracts.Assert(parent is Host);
                Contracts.AssertNonEmpty(name);
                return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
            }

            /// <summary>
            /// Creates a nested <see cref="Host"/> under <paramref name="source"/>.
            /// </summary>
            protected override IHost RegisterCore(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, Random rand, bool verbose)
            {
                // Same sanity checks as the environment-level RegisterCore.
                Contracts.AssertValue(rand);
                Contracts.AssertValueOrNull(parentFullName);
                Contracts.AssertNonEmpty(shortName);
                Contracts.Assert(source == this || source is Host);
                return new Host(source, shortName, parentFullName, rand, verbose);
            }
        }
    }
 
}