File: ModelLoaders\UriModelLoader.cs
Web Access
Project: src\src\Microsoft.Extensions.ML\Microsoft.Extensions.ML.csproj (Microsoft.Extensions.ML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Diagnostics;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.ML;
 
namespace Microsoft.Extensions.ML
{
    /// <summary>
    /// Loads an ML.NET model from a URI. When the server returns an ETag for the model,
    /// the loader periodically re-checks the URI and reloads the model if the ETag changed,
    /// signalling the change through the token returned by <see cref="GetReloadToken"/>.
    /// </summary>
    internal class UriModelLoader : ModelLoader, IDisposable
    {
        //TODO: This should be able to be removed for HeaderNames.ETag
        private const string ETagHeader = "ETag";

        // Time budget for one timer-driven reload; checked before a changed model is downloaded.
        private const int TimeoutMilliseconds = 60000;

        private readonly MLContext _context;
        private TimeSpan? _timerPeriod;
        private Uri _uri;
        private ITransformer _model;
        private ModelReloadToken _reloadToken;
        private Timer _reloadTimer;
        private readonly object _reloadTimerLock;
        private string _eTag;
        private readonly ILogger _logger;
        private readonly CancellationTokenSource _stopping;
        private readonly HttpClient _httpClient;
        private bool _started;
        private bool _disposed;

        public UriModelLoader(IOptions<MLOptions> contextOptions, ILogger<UriModelLoader> logger)
        {
            _logger = logger ?? throw new ArgumentNullException(nameof(logger));
            _context = contextOptions.Value?.MLContext;
            _reloadTimerLock = new object();
            _reloadToken = new ModelReloadToken();
            _stopping = new CancellationTokenSource();
            // One client for the lifetime of the loader: HttpClient is meant to be reused,
            // and creating one per request can exhaust sockets.
            _httpClient = new HttpClient();
            _started = false;
        }

        /// <summary>
        /// Loads the model from <paramref name="uri"/> synchronously and, if the response
        /// carried an ETag, arms the reload timer.
        /// </summary>
        /// <param name="uri">Location of the serialized model.</param>
        /// <param name="period">Delay between the end of one reload check and the start of the next.</param>
        internal void Start(Uri uri, TimeSpan period)
        {
            _timerPeriod = period;
            _uri = uri;
            // Start is synchronous; the base LoadModelAsync uses ConfigureAwait(false) on its
            // awaits so that blocking here does not deadlock on a single-threaded context.
            if (LoadModelAsync().ConfigureAwait(false).GetAwaiter().GetResult())
            {
                StartReloadTimer();
            }
            _started = true;
        }

        // Timer callback. The timer is one-shot, so it is torn down here and re-armed only after
        // the reload finishes; timer-driven reloads therefore never overlap.
        private void ReloadTimerTick(object state)
        {
            _ = Task.Run(async () =>
            {
                StopReloadTimer();

                if (IsStopping)
                {
                    return;
                }

                await RunAsync().ConfigureAwait(false);

                // No-op once the loader has been disposed.
                StartReloadTimer();
            });
        }

        /// <summary>True once <see cref="Dispose"/> has signalled shutdown.</summary>
        internal bool IsStopping => _stopping.IsCancellationRequested;

        /// <summary>
        /// Performs one reload check: probes the ETag and, if it differs, reloads the model and
        /// fires the previous reload token. Failures are logged, never thrown.
        /// </summary>
        internal async Task RunAsync()
        {
            //TODO: Switch to ValueStopWatch
            var duration = Stopwatch.StartNew();

            // _stopping is never disposed, so creating the linked source cannot throw.
            using (var cancellation = CancellationTokenSource.CreateLinkedTokenSource(_stopping.Token))
            {
                cancellation.CancelAfter(TimeoutMilliseconds);
                try
                {
                    Logger.UriReloadBegin(_logger, _uri);

                    var eTagMatches = await MatchEtagAsync(_uri, _eTag).ConfigureAwait(false);
                    if (!eTagMatches)
                    {
                        // The overridable HTTP helpers take no token, so shutdown and the timeout
                        // are observed here, before starting the potentially large download.
                        cancellation.Token.ThrowIfCancellationRequested();

                        await LoadModelAsync().ConfigureAwait(false);

                        // Publish the new token before firing the old one, so callers of
                        // GetReloadToken inside OnReload receive the fresh token.
                        var previousToken = Interlocked.Exchange(ref _reloadToken, new ModelReloadToken());
                        previousToken.OnReload();
                    }

                    Logger.UriReloadEnd(_logger, _uri, duration.Elapsed);
                }
                catch (Exception ex) when (IsStopping && (ex is OperationCanceledException || ex is ObjectDisposedException))
                {
                    // The loader is being disposed; an interrupted reload is expected, not an error.
                }
                catch (Exception ex)
                {
                    // Includes the timeout: a cancellation that is not a shutdown is a failure.
                    Logger.UriReloadError(_logger, _uri, duration.Elapsed, ex);
                }
            }
        }

        /// <summary>
        /// Sends a HEAD request to <paramref name="uri"/> and compares its ETag with <paramref name="eTag"/>.
        /// </summary>
        /// <returns>
        /// <c>true</c> if the server reports the same ETag; <c>false</c> if it differs or the
        /// response carries no ETag (the model is then treated as changed).
        /// </returns>
        internal virtual async Task<bool> MatchEtagAsync(Uri uri, string eTag)
        {
            using (var headRequest = new HttpRequestMessage(HttpMethod.Head, uri))
            using (var resp = await _httpClient.SendAsync(headRequest).ConfigureAwait(false))
            {
                var currentETag = GetETag(resp);
                return currentETag != null && string.Equals(currentETag, eTag, StringComparison.Ordinal);
            }
        }

        /// <summary>Arms the one-shot reload timer unless it is already armed or the loader is disposed.</summary>
        internal void StartReloadTimer()
        {
            lock (_reloadTimerLock)
            {
                if (_disposed || _reloadTimer != null)
                {
                    return;
                }

                // Infinite period: ReloadTimerTick re-arms the timer after each reload completes.
                _reloadTimer = new Timer(ReloadTimerTick, this, _timerPeriod.Value, Timeout.InfiniteTimeSpan);
            }
        }

        /// <summary>Disarms the reload timer; safe to call when no timer is armed.</summary>
        internal void StopReloadTimer()
        {
            lock (_reloadTimerLock)
            {
                _reloadTimer?.Dispose();
                _reloadTimer = null;
            }
        }

        /// <summary>
        /// Downloads and deserializes the model from the configured URI, replacing the current model.
        /// </summary>
        /// <returns>
        /// <c>true</c> if the response carried an ETag (which is stored for later change checks);
        /// otherwise <c>false</c>.
        /// </returns>
        /// <exception cref="HttpRequestException">The server returned a non-success status code.</exception>
        internal virtual async Task<bool> LoadModelAsync()
        {
            //TODO: We probably need some sort of retry policy for this.
            try
            {
                using (var resp = await _httpClient.GetAsync(_uri).ConfigureAwait(false))
                {
                    // Fail with a clear HTTP error instead of handing an error page to Model.Load.
                    resp.EnsureSuccessStatusCode();

                    using (var stream = await resp.Content.ReadAsStreamAsync().ConfigureAwait(false))
                    {
                        _model = _context.Model.Load(stream, out _);
                    }

                    var eTag = GetETag(resp);
                    if (eTag == null)
                    {
                        return false;
                    }

                    _eTag = eTag;
                    return true;
                }
            }
            catch (Exception ex)
            {
                Logger.UriLoadError(_logger, _uri, ex);
                throw;
            }
        }

        // Returns the first raw ETag header value of the response, or null when none is present.
        private static string GetETag(HttpResponseMessage response)
        {
            return response.Headers.TryGetValues(ETagHeader, out var values) ? values.FirstOrDefault() : null;
        }

        public override ITransformer GetModel()
        {
            if (!_started) throw new InvalidOperationException("Start must be called on a ModelLoader before it can be used.");

            return _model;
        }

        public override IChangeToken GetReloadToken()
        {
            if (!_started) throw new InvalidOperationException("Start must be called on a ModelLoader before it can be used.");

            return _reloadToken;
        }

        /// <summary>
        /// Stops periodic reloading, signals shutdown to any in-flight reload, and releases the
        /// HTTP client. Calling it more than once has no further effect.
        /// </summary>
        public void Dispose()
        {
            lock (_reloadTimerLock)
            {
                if (_disposed)
                {
                    return;
                }

                // Set under the lock so a concurrent StartReloadTimer cannot re-arm the timer.
                _disposed = true;
                _reloadTimer?.Dispose();
                _reloadTimer = null;
            }

            // Makes IsStopping true so interrupted reloads are ignored rather than logged as errors.
            // _stopping is deliberately not disposed: a reload already running may still read its
            // Token to create a linked source.
            _stopping.Cancel();
            _httpClient.Dispose();
        }

        internal static class EventIds
        {
            public static readonly EventId UriReloadBegin = new EventId(100, "UriReloadBegin");
            public static readonly EventId UriReloadEnd = new EventId(101, "UriReloadEnd");
            public static readonly EventId UriReloadError = new EventId(102, "UriReloadError");
            public static readonly EventId UriLoadError = new EventId(103, "UriLoadError");
        }

        private static class Logger
        {
            private static readonly Action<ILogger, Uri, Exception> _uriReloadBegin = LoggerMessage.Define<Uri>(
                LogLevel.Debug,
                EventIds.UriReloadBegin,
                "URI reload '{uri}'");

            private static readonly Action<ILogger, Uri, double, Exception> _uriReloadEnd = LoggerMessage.Define<Uri, double>(
                LogLevel.Debug,
                EventIds.UriReloadEnd,
                "URI reload '{uri}' completed after {ElapsedMilliseconds}ms");

            private static readonly Action<ILogger, Uri, double, Exception> _uriReloadError = LoggerMessage.Define<Uri, double>(
                LogLevel.Error,
                EventIds.UriReloadError,
                "URI reload for {uri} threw an unhandled exception after {ElapsedMilliseconds}ms");

            private static readonly Action<ILogger, Uri, Exception> _uriLoadError = LoggerMessage.Define<Uri>(
                LogLevel.Error,
                EventIds.UriLoadError,
                "Error loading {uri}");

            public static void UriReloadBegin(ILogger logger, Uri uri)
            {
                _uriReloadBegin(logger, uri, null);
            }

            public static void UriReloadEnd(ILogger logger, Uri uri, TimeSpan duration)
            {
                _uriReloadEnd(logger, uri, duration.TotalMilliseconds, null);
            }

            public static void UriReloadError(ILogger logger, Uri uri, TimeSpan duration, Exception exception)
            {
                _uriReloadError(logger, uri, duration.TotalMilliseconds, exception);
            }

            public static void UriLoadError(ILogger logger, Uri uri, Exception exception)
            {
                _uriLoadError(logger, uri, exception);
            }
        }
    }
}