File: UnitTests\TestHosts.cs
Web Access
Project: src\test\Microsoft.ML.Core.Tests\Microsoft.ML.Core.Tests.csproj (Microsoft.ML.Core.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.RunTests
{
    /// <summary>
    /// Tests covering host registration and cancellation, plus delivery of
    /// <see cref="MLContext"/> log events.
    /// </summary>
    public class TestHosts : BaseTestClass
    {
        // Number of hosts registered directly under the main host by the cancellation tests.
        // These level-1 hosts are never canceled directly by TestCancellation.
        private const int LevelOneHostCount = 5;

        /// <summary>
        /// Creates the test class, forwarding the xUnit output helper to <see cref="BaseTestClass"/>.
        /// </summary>
        /// <param name="output">Output helper supplied by xUnit.</param>
        public TestHosts(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// Repeatedly grows a tree of hosts on a background thread while the test thread cancels
        /// randomly chosen hosts at level 2 or deeper, then asserts that every directly canceled
        /// host and every host recorded as registered beneath it reports itself as canceled.
        /// </summary>
        [Fact]
        public void TestCancellation()
        {
            const int rounds = 1000;
            const int registrations = 100;
            const int cancellations = 5;

            IHostEnvironment env = new MLContext(seed: 42);

            // Only the test thread draws from this instance; it exists to hand out distinct seeds.
            var seedSource = new Random();

            for (int round = 0; round < rounds; round++)
            {
                var mainHost = env.Register("Main");
                var children = new ConcurrentDictionary<IHost, List<IHost>>();
                var hosts = CreateLevelOneHosts(mainHost);

                // System.Random is not thread-safe, so the registering thread and the canceling
                // thread each get their own instance instead of sharing one.
                var registerRandom = new Random(seedSource.Next());
                var cancelRandom = new Random(seedSource.Next());

                var addThread = new Thread(
                () =>
                {
                    for (int i = 0; i < registrations; i++)
                    {
                        // Random.Next(max) excludes max, so passing Count makes every host eligible.
                        var parent = hosts.ElementAt(registerRandom.Next(hosts.Count));
                        int childLevel = parent.Item2 + 1;
                        var newHost = parent.Item1.Register(childLevel.ToString());
                        hosts.Add(new Tuple<IHost, int>(newHost, childLevel));

                        // Record every child, including the first one registered under a parent.
                        // Only this thread mutates the lists; the test thread reads them after Join().
                        children.GetOrAdd(parent.Item1, _ => new List<IHost>()).Add(newHost);
                    }
                });
                addThread.Start();

                var pending = new Queue<IHost>();
                for (int i = 0; i < cancellations; i++)
                {
                    // Keep drawing until we hit a host that is not yet canceled and is not a
                    // level-1 host. This spins while the registering thread has not produced
                    // such a host yet.
                    Tuple<IHost, int> candidate;
                    do
                    {
                        candidate = hosts.ElementAt(cancelRandom.Next(hosts.Count));
                    } while (IsCanceled(candidate.Item1) ||
                             // A threshold of 2 (rather than 3) is used because hosts deeper
                             // than level 2 are not guaranteed to exist.
                             candidate.Item2 < 2);

                    ((ICancelable)candidate.Item1).CancelExecution();
                    pending.Enqueue(candidate.Item1);

                    // Everything except the level-1 hosts is already canceled, so there is
                    // nothing left to pick; stop looking.
                    if (hosts.Count(q => IsCanceled(q.Item1)) == hosts.Count - LevelOneHostCount)
                    {
                        break;
                    }
                }

                addThread.Join();

                // Breadth-first walk from each directly canceled host through its recorded children.
                // NOTE(review): this includes children that may have been registered after their
                // parent was canceled; the test assumes those also report IsCanceled — confirm
                // against the Register implementation.
                while (pending.Count > 0)
                {
                    var currentHost = pending.Dequeue();
                    Assert.True(IsCanceled(currentHost));

                    List<IHost> registered;
                    if (children.TryGetValue(currentHost, out registered))
                    {
                        foreach (var child in registered)
                            pending.Enqueue(child);
                    }
                }
            }
        }

        /// <summary>
        /// Registers one child under each level-1 host, cancels execution on the
        /// <see cref="MLContext"/>, and asserts that all ten hosts report themselves as canceled.
        /// </summary>
        [Fact]
        public void TestCancellationApi()
        {
            IHostEnvironment env = new MLContext(seed: 42);
            var mainHost = env.Register("Main");
            var hosts = CreateLevelOneHosts(mainHost);

            for (int i = 0; i < LevelOneHostCount; i++)
            {
                var tuple = hosts.ElementAt(i);
                var newHost = tuple.Item1.Register((tuple.Item2 + 1).ToString());
                hosts.Add(new Tuple<IHost, int>(newHost, tuple.Item2 + 1));
            }

            ((MLContext)env).CancelExecution();

            //Ensure all created hosts are canceled.
            //5 parent and one child for each.
            Assert.Equal(2 * LevelOneHostCount, hosts.Count);

            foreach (var host in hosts)
                Assert.True(IsCanceled(host.Item1));
        }

        /// <summary>
        /// Tests that MLContext's Log event intercepts messages properly.
        /// </summary>
        [Fact]
        public void LogEventProcessesMessages()
        {
            var messages = new List<string>();

            var env = new MLContext(1);
            env.Log += (sender, e) => messages.Add(e.Message);

            // Create a dummy text loader; doing so is expected to emit at least one log message.
            env.Data.CreateTextLoader(new TextLoader.Options { Columns = new[] { new TextLoader.Column("TestColumn", DataKind.Single, 0) } });

            Assert.NotEmpty(messages);
        }

        /// <summary>
        /// Registers hosts named "1" through "5" under <paramref name="mainHost"/> and returns
        /// them paired with their level (1).
        /// </summary>
        /// <param name="mainHost">Host under which the level-1 hosts are registered.</param>
        /// <returns>A collection holding one (host, level) tuple per registered host.</returns>
        private static BlockingCollection<Tuple<IHost, int>> CreateLevelOneHosts(IHost mainHost)
        {
            var hosts = new BlockingCollection<Tuple<IHost, int>>();
            for (int id = 1; id <= LevelOneHostCount; id++)
                hosts.Add(new Tuple<IHost, int>(mainHost.Register(id.ToString()), 1));
            return hosts;
        }

        /// <summary>
        /// Reads <see cref="ICancelable.IsCanceled"/> from <paramref name="host"/>. A direct cast
        /// is used so a host that is not cancelable fails with an InvalidCastException rather
        /// than a NullReferenceException.
        /// </summary>
        private static bool IsCanceled(IHost host)
        {
            return ((ICancelable)host).IsCanceled;
        }
    }
}