File: FileLoaderTests.cs
Web Access
Project: src\test\Microsoft.Extensions.ML.Tests\Microsoft.Extensions.ML.Tests.csproj (Microsoft.Extensions.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using System.Threading;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Microsoft.ML.TestFrameworkCommon;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.Extensions.ML
{
    public class FileLoaderTests : BaseTestClass
    {
        public FileLoaderTests(ITestOutputHelper output) : base(output)
        {
        }
 
        [Fact]
        public void throw_until_started()
        {
            var services = new ServiceCollection()
                .AddOptions()
                .AddLogging();
            var sp = services.BuildServiceProvider();
 
            var loaderUnderTest = ActivatorUtilities.CreateInstance<FileModelLoader>(sp);
            Assert.Throws<InvalidOperationException>(() => loaderUnderTest.GetModel());
            Assert.Throws<InvalidOperationException>(() => loaderUnderTest.GetReloadToken());
        }
 
        [Fact]
        public void can_load_model()
        {
            var services = new ServiceCollection()
                .AddOptions()
                .AddLogging();
            var sp = services.BuildServiceProvider();
 
            var loaderUnderTest = ActivatorUtilities.CreateInstance<FileModelLoader>(sp);
            loaderUnderTest.Start(Path.Combine("TestModels", "SentimentModel.zip"), false);
 
            var model = loaderUnderTest.GetModel();
            var context = sp.GetRequiredService<IOptions<MLOptions>>().Value.MLContext;
            var engine = context.Model.CreatePredictionEngine<SentimentData, SentimentPrediction>(model);
 
            var prediction = engine.Predict(new SentimentData() { SentimentText = "This is great" });
            Assert.True(prediction.Sentiment);
        }
 
        [Fact]
        public void can_reload_model()
        {
            var services = new ServiceCollection()
                .AddOptions()
                .AddLogging();
            var sp = services.BuildServiceProvider();
 
            var loaderUnderTest = ActivatorUtilities.CreateInstance<FileLoaderMock>(sp);
            loaderUnderTest.Start("testdata.txt", true);
 
            using AutoResetEvent changed = new AutoResetEvent(false);
            using IDisposable changeTokenRegistration = ChangeToken.OnChange(
                        () => loaderUnderTest.GetReloadToken(),
                        () => changed.Set());
 
            File.WriteAllText("testdata.txt", "test");
 
            Assert.True(changed.WaitOne(AsyncTestHelper.UnexpectedTimeout), "FileLoader ChangeToken didn't fire before the allotted time.");
        }
 
 
        private class FileLoaderMock : FileModelLoader
        {
            public FileLoaderMock(IOptions<MLOptions> contextOptions, ILogger<FileModelLoader> logger)
                : base(contextOptions, logger)
            {
            }
 
            internal override void LoadModel()
            {
            }
        }
 
        public class SentimentData
        {
            [ColumnName("Label"), LoadColumn(0)]
            public bool Sentiment;
 
            [LoadColumn(1)]
            public string SentimentText;
        }
 
        public class SentimentPrediction
        {
            [ColumnName("PredictedLabel")]
            public bool Sentiment;
 
            public float Score;
        }
    }
}