// File: Scenarios\Api\Estimators\MultithreadedPrediction.cs
// Web Access
// Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Threading.Tasks;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Trainers;
using Xunit;
 
namespace Microsoft.ML.Tests.Scenarios.Api
{
    public partial class ApiScenariosTests
    {
        /// <summary>
        /// Multi-threaded prediction. A twist on "Simple train and predict" that accounts for
        /// multiple threads wanting predictions at the same time. Because we deliberately do not
        /// reallocate internal memory buffers on every single prediction, the PredictionEngine
        /// (or its estimator/transformer based successor) is, like most stateful .NET objects,
        /// fundamentally not thread safe. This is deliberate and as designed. However, some mechanism
        /// to enable multi-threaded scenarios (for example, a web server servicing requests) should be
        /// possible and performant in the new API.
        /// </summary>
        /// <remarks>
        /// The mechanism exercised here is a single shared engine whose <c>Predict</c> calls are
        /// serialized with a lock while <c>Parallel.ForEach</c> fans the test examples out to workers.
        /// </remarks>
        [Fact]
        public void MultithreadedPrediction()
        {
            var ml = new MLContext(seed: 1);
            var data = ml.Data.LoadFromTextFile<SentimentData>(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true);

            // Pipeline: featurize the text, cache the featurized data, then train SDCA
            // with the trainer's NumberOfThreads option set to 1.
            var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText")
                .AppendCacheCheckpoint(ml)
                .Append(ml.BinaryClassification.Trainers.SdcaNonCalibrated(
                    new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1 }));

            // Train.
            var model = pipeline.Fit(data);

            // Create the single prediction engine shared by every worker below.
            var engine = ml.Model.CreatePredictionEngine<SentimentData, SentimentPrediction>(model);

            // Dedicated gate for serializing access to the engine; locking on a private object
            // avoids taking the monitor of an instance that other code could also lock on.
            var engineGate = new object();

            // Take the examples out of the test data and run predictions on top.
            var testData = ml.Data.CreateEnumerable<SentimentData>(
                ml.Data.LoadFromTextFile<SentimentData>(GetDataPath(TestDatasets.Sentiment.testFilename), hasHeader: true), false);

            Parallel.ForEach(testData, input =>
            {
                // The engine is not thread safe (see summary), so only one worker may predict at a time.
                lock (engineGate)
                {
                    var prediction = engine.Predict(input);

                    // Consume the result so the test verifies that every call produced an output.
                    Assert.NotNull(prediction);
                }
            });
        }
    }
}