File: PoolLoader.cs
Web Access
Project: src\src\Microsoft.Extensions.ML\Microsoft.Extensions.ML.csproj (Microsoft.Extensions.ML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Threading;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.ObjectPool;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.ML;
 
namespace Microsoft.Extensions.ML
{
    /// <summary>
    /// Encapsulates the data and logic required for loading and reloading PredictionEngine object pools.
    /// </summary>
    /// <remarks>
    /// A pool is built from the model returned by the <see cref="ModelLoader"/> when this object is
    /// constructed, and a fresh pool is built and swapped in every time the loader's reload token fires.
    /// </remarks>
    internal class PoolLoader<TData, TPrediction> : IDisposable
        where TData : class
        where TPrediction : class, new()
    {
        // Replaced from the reload callback (which may run on a different thread than the consumers);
        // written via Interlocked.Exchange and read via Volatile.Read so readers observe the latest pool.
        private DefaultObjectPool<PredictionEngine<TData, TPrediction>> _pool;
        private readonly IDisposable _changeTokenRegistration;

        /// <summary>
        /// Resolves the ML context, builds the initial prediction engine pool and subscribes to model reloads.
        /// </summary>
        /// <param name="sp">Service provider used to resolve the configured <see cref="MLOptions"/>.</param>
        /// <param name="poolOptions">Options that supply the <see cref="ModelLoader"/> for this pool.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="sp"/> or <paramref name="poolOptions"/> is null, the resolved options have no
        /// <c>MLContext</c>, or <paramref name="poolOptions"/> has no <c>ModelLoader</c>.
        /// </exception>
        public PoolLoader(IServiceProvider sp, PredictionEnginePoolOptions<TData, TPrediction> poolOptions)
        {
            // Fail with a descriptive argument exception rather than a NullReferenceException below.
            if (sp is null)
            {
                throw new ArgumentNullException(nameof(sp));
            }

            if (poolOptions is null)
            {
                throw new ArgumentNullException(nameof(poolOptions));
            }

            var contextOptions = sp.GetRequiredService<IOptions<MLOptions>>();
            Context = contextOptions.Value.MLContext ?? throw new ArgumentNullException(nameof(contextOptions));
            Loader = poolOptions.ModelLoader ?? throw new ArgumentNullException(nameof(poolOptions));

            LoadPool();

            // ChangeToken.OnChange obtains a token from the producer, runs the consumer when it fires and then
            // re-subscribes with a fresh token, so every model reload rebuilds the pool.
            _changeTokenRegistration = ChangeToken.OnChange(
                () => Loader.GetReloadToken(),
                () => LoadPool());
        }

        /// <summary>
        /// Gets the loader that supplies the model and its reload token.
        /// </summary>
        public ModelLoader Loader { get; }

        /// <summary>
        /// Gets the ML context handed to each pool policy.
        /// </summary>
        private MLContext Context { get; }

        /// <summary>
        /// Gets the current prediction engine pool.
        /// </summary>
        /// <remarks>
        /// After a model reload this returns a different pool instance; a reference obtained earlier keeps
        /// pointing at the previous pool.
        /// </remarks>
        public ObjectPool<PredictionEngine<TData, TPrediction>> PredictionEnginePool => Volatile.Read(ref _pool);

        /// <summary>
        /// Builds a pool policy from the loader's current model and atomically replaces the active pool.
        /// </summary>
        private void LoadPool()
        {
            var predictionEnginePolicy = new PredictionEnginePoolPolicy<TData, TPrediction>(Context, Loader.GetModel());

            // NOTE(review): the previous pool returned by Exchange is discarded, not disposed; engines it still
            // holds are left to the GC. Confirm this is acceptable, since consumers may still be renting from it.
            Interlocked.Exchange(ref _pool, new DefaultObjectPool<PredictionEngine<TData, TPrediction>>(predictionEnginePolicy));
        }

        /// <summary>
        /// Stops listening for model reloads.
        /// </summary>
        /// <remarks>
        /// Only the reload subscription is released; the current pool and the engines it holds are not disposed.
        /// </remarks>
        public void Dispose()
        {
            _changeTokenRegistration?.Dispose();
        }
    }
}