File: StopTrainingManagerTests.cs
Web Access
Project: src\test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj (Microsoft.ML.AutoML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.AutoML.Test
{
    public class StopTrainingManagerTests : BaseTestClass
    {
        public StopTrainingManagerTests(ITestOutputHelper output)
            : base(output)
        {
        }
 
        [Fact]
        public void CancellationTokenStopTrainingManager_isStopTrainingRequested_test()
        {
            var cts = new CancellationTokenSource();
            var manager = new CancellationTokenStopTrainingManager(cts.Token, null);
 
            var isOnStopTrainingGetInvoked = false;
            manager.OnStopTraining += (o, e) =>
            {
                isOnStopTrainingGetInvoked = true;
            };
 
            manager.IsStopTrainingRequested().Should().BeFalse();
 
            cts.Cancel();
 
            manager.IsStopTrainingRequested().Should().BeTrue();
            isOnStopTrainingGetInvoked.Should().BeTrue();
        }
 
        [Fact]
        public async Task TimeoutTrainingStopManager_isStopTrainingRequested_test()
        {
            var manager = new TimeoutTrainingStopManager(TimeSpan.FromSeconds(1), null);
 
            var isOnStopTrainingGetInvoked = false;
            manager.OnStopTraining += (o, e) =>
            {
                isOnStopTrainingGetInvoked = true;
            };
 
            var waitForCompleted = Task.Run(async () =>
            {
                while (!isOnStopTrainingGetInvoked)
                {
                    await Task.Delay(1000);
                }
            });
 
            await waitForCompleted;
            manager.IsStopTrainingRequested().Should().BeTrue();
            isOnStopTrainingGetInvoked.Should().BeTrue();
        }
 
        [Fact]
        public async Task AggregateTrainingStopManager_isStopTrainingRequested_test()
        {
            var cts = new CancellationTokenSource();
            var timeoutManager = new TimeoutTrainingStopManager(TimeSpan.FromSeconds(1), null);
            var cancellationManager = new CancellationTokenStopTrainingManager(cts.Token, null);
            var aggregationManager = new AggregateTrainingStopManager(null, timeoutManager, cancellationManager);
 
            var isOnStopTrainingGetInvoked = false;
            aggregationManager.OnStopTraining += (o, e) =>
            {
                isOnStopTrainingGetInvoked = true;
            };
 
            var waitForCompleted = Task.Run(async () =>
            {
                while (!isOnStopTrainingGetInvoked)
                {
                    await Task.Delay(1000);
                }
            });
 
            await waitForCompleted;
            aggregationManager.IsStopTrainingRequested().Should().BeTrue();
            timeoutManager.IsStopTrainingRequested().Should().BeTrue();
            cancellationManager.IsStopTrainingRequested().Should().BeFalse();
            isOnStopTrainingGetInvoked.Should().BeTrue();
 
            aggregationManager = new AggregateTrainingStopManager(null, new TimeoutTrainingStopManager(TimeSpan.FromSeconds(100), null), cancellationManager);
            isOnStopTrainingGetInvoked = false;
            aggregationManager.OnStopTraining += (o, e) =>
            {
                isOnStopTrainingGetInvoked = true;
            };
            cts.Cancel();
 
            await waitForCompleted;
            aggregationManager.IsStopTrainingRequested().Should().BeTrue();
            cancellationManager.IsStopTrainingRequested().Should().BeTrue();
            isOnStopTrainingGetInvoked.Should().BeTrue();
        }
    }
}