// File: ModelLoaders\FileModelLoader.cs
// Web Access
// Project: src\src\Microsoft.Extensions.ML\Microsoft.Extensions.ML.csproj (Microsoft.Extensions.ML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Diagnostics;
using System.IO;
using System.Threading;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.ML;
 
namespace Microsoft.Extensions.ML
{
    internal class FileModelLoader : ModelLoader, IDisposable
    {
        // Logger used for the reload begin/end/error messages.
        private readonly ILogger<FileModelLoader> _logger;
        // Path of the model file; assigned by Start.
        private string _filePath;
        // Watches the model file for changes; only created when Start is asked to watch the file.
        private FileSystemWatcher _watcher;
        // Token returned by GetReloadToken. Start creates it and the file-change handler swaps in a fresh one.
        private ModelReloadToken _reloadToken;
        // Most recently loaded model; assigned by LoadModel.
        private ITransformer _model;

        // MLContext taken from the injected options and used to load the model.
        private readonly MLContext _context;

        // Held while the file-change handler reloads the model.
        private readonly object _lock;
 
        /// <summary>
        /// Creates a loader that reads models from a file using the <see cref="MLContext"/>
        /// supplied through <paramref name="contextOptions"/>.
        /// </summary>
        /// <param name="contextOptions">Options wrapper whose <c>Value.MLContext</c> is used to load models.</param>
        /// <param name="logger">Logger used to report reload activity.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="logger"/> is null, or <paramref name="contextOptions"/>, its <c>Value</c>,
        /// or the value's <c>MLContext</c> is null.
        /// </exception>
        public FileModelLoader(IOptions<MLOptions> contextOptions, ILogger<FileModelLoader> logger)
        {
            _logger = logger ?? throw new ArgumentNullException(nameof(logger));

            // The null-conditional on contextOptions itself turns a null argument into the
            // intended ArgumentNullException instead of a NullReferenceException.
            _context = contextOptions?.Value?.MLContext ?? throw new ArgumentNullException(nameof(contextOptions));
            _lock = new object();
        }
 
        /// <summary>
        /// Loads the model stored at <paramref name="filePath"/> and, when requested, starts
        /// watching that file so the model is reloaded whenever the file changes.
        /// </summary>
        /// <param name="filePath">Path of the model file. When it has no directory part, the current directory is watched.</param>
        /// <param name="watchFile"><c>true</c> to create a <see cref="FileSystemWatcher"/> for the file.</param>
        /// <exception cref="ArgumentException">The file at <paramref name="filePath"/> does not exist.</exception>
        public void Start(string filePath, bool watchFile)
        {
            _filePath = filePath;
            _reloadToken = new ModelReloadToken();

            if (!File.Exists(filePath))
            {
                throw new ArgumentException($"The provided model file {filePath} doesn't exist.");
            }

            // A watcher left over from an earlier Start call would keep raising events (and stay
            // undisposed) once the field is overwritten, so release it before continuing.
            _watcher?.Dispose();
            _watcher = null;

            var directory = Path.GetDirectoryName(filePath);

            if (string.IsNullOrEmpty(directory))
            {
                directory = Directory.GetCurrentDirectory();
            }

            var file = Path.GetFileName(filePath);

            LoadModel();

            if (watchFile)
            {
                _watcher = new FileSystemWatcher(directory, file);

                // Attach the handler before enabling events so that a change arriving right after
                // the watcher is switched on is not dropped.
                _watcher.Changed += WatcherChanged;
                _watcher.EnableRaisingEvents = true;
            }
        }
 
        /// <summary>
        /// Handles <see cref="FileSystemWatcher.Changed"/> for the watched model file: reloads the
        /// model, then replaces the reload token and signals the one that was current before.
        /// Failures are logged rather than propagated to the watcher.
        /// </summary>
        /// <param name="sender">Event source; not used.</param>
        /// <param name="e">Event data; not used (the path stored by Start is reloaded).</param>
        private void WatcherChanged(object sender, FileSystemEventArgs e)
        {
            var timer = Stopwatch.StartNew();

            try
            {
                Logger.FileReloadBegin(_logger, _filePath);

                lock (_lock)
                {
                    //TODO: We get here multiple times when you copy and paste a file
                    //because of the way file watchers work. Need to think through the
                    //ramifications.
                    LoadModel();
                    Logger.ReloadingFile(_logger, _filePath, timer.Elapsed);
                }

                // Swap the token only after the load succeeded. If it were replaced first and
                // LoadModel threw, the old token would never be signaled, and anything still
                // registered on it would not hear about later, successful reloads.
                var previousToken = Interlocked.Exchange(ref _reloadToken, new ModelReloadToken());

                // Signal outside the lock so listener callbacks do not run while it is held.
                previousToken.OnReload();

                timer.Stop();
                Logger.FileReloadEnd(_logger, _filePath, timer.Elapsed);
            }
            catch (OperationCanceledException)
            {
                // This is a cancellation - if the app is shutting down we want to ignore it.
            }
            catch (Exception ex)
            {
                Logger.FileReloadError(_logger, _filePath, timer.Elapsed, ex);
            }
        }
 
        /// <summary>
        /// Returns the reload token currently held by this loader.
        /// </summary>
        /// <returns>The current reload token.</returns>
        /// <exception cref="InvalidOperationException">No reload token has been created yet.</exception>
        public override IChangeToken GetReloadToken()
        {
            var currentToken = _reloadToken;

            if (currentToken != null)
            {
                return currentToken;
            }

            throw new InvalidOperationException("Start must be called on a ModelLoader before it can be used.");
        }
 
        /// <summary>
        /// Returns the most recently loaded model.
        /// </summary>
        /// <returns>The model currently held by this loader.</returns>
        /// <exception cref="InvalidOperationException">No model has been loaded yet.</exception>
        public override ITransformer GetModel()
            => _model ?? throw new InvalidOperationException("Start must be called on a ModelLoader before it can be used.");
 
        /// <summary>
        /// Tries to open <paramref name="fullPath"/>, retrying while the open fails with an
        /// <see cref="IOException"/> (for example because another process still has the file locked).
        /// </summary>
        /// <param name="fullPath">Path of the file to open.</param>
        /// <param name="mode">File mode passed to the <see cref="FileStream"/> constructor.</param>
        /// <param name="access">File access passed to the <see cref="FileStream"/> constructor.</param>
        /// <param name="share">File sharing passed to the <see cref="FileStream"/> constructor.</param>
        /// <returns>The opened stream, or <c>null</c> when every attempt failed. The caller owns the stream.</returns>
        private FileStream WaitForFile(string fullPath, FileMode mode, FileAccess access, FileShare share)
        {
            // 100 attempts with a 50 ms pause after each failure: about 5 seconds in total.
            const int maxAttempts = 100;
            const int retryDelayMilliseconds = 50;

            for (int attempt = 0; attempt < maxAttempts; attempt++)
            {
                try
                {
                    return new FileStream(fullPath, mode, access, share);
                }
                catch (IOException)
                {
                    // The constructor threw, so no stream exists to dispose; just wait and retry.
                    Thread.Sleep(retryDelayMilliseconds);
                }
            }

            return null;
        }
 
        /// <summary>
        /// Opens the model file for reading and replaces the current model with the one loaded from it.
        /// Declared internal virtual for testing purposes.
        /// </summary>
        /// <exception cref="IOException">The file could not be opened by <c>WaitForFile</c>.</exception>
        internal virtual void LoadModel()
        {
            FileStream modelStream = WaitForFile(_filePath, FileMode.Open, FileAccess.Read, FileShare.Read)
                ?? throw new IOException($"Model file {_filePath} still got locked after 5 seconds, fail to reload.");

            // The stream is only needed while loading; close it as soon as the model is read.
            using (modelStream)
            {
                _model = _context.Model.Load(modelStream, out _);
            }
        }
 
        /// <summary>
        /// Disposes the file watcher, if one was created.
        /// </summary>
        public void Dispose()
        {
            var watcher = _watcher;

            if (watcher != null)
            {
                watcher.Dispose();
            }
        }
 
        /// <summary>
        /// Event ids attached to the log messages written by this loader.
        /// Each id's name is derived from its field name with <c>nameof</c>, so the two cannot drift apart.
        /// </summary>
        internal static class EventIds
        {
            public static readonly EventId FileReloadBegin = new EventId(100, nameof(FileReloadBegin));
            public static readonly EventId FileReloadEnd = new EventId(101, nameof(FileReloadEnd));
            public static readonly EventId FileReload = new EventId(102, nameof(FileReload));
            public static readonly EventId FileReloadError = new EventId(103, nameof(FileReloadError));
        }
 
        /// <summary>
        /// Pre-defined log messages for model file reloads, built once with <c>LoggerMessage.Define</c>.
        /// </summary>
        private static class Logger
        {
            // Debug: a reload has started.
            private static readonly Action<ILogger, string, Exception> _fileReloadBegin = LoggerMessage.Define<string>(
                LogLevel.Debug,
                EventIds.FileReloadBegin,
                "File reload for '{filePath}'");

            // Debug: the whole reload handler has finished.
            private static readonly Action<ILogger, string, double, Exception> _fileReloadEnd = LoggerMessage.Define<string, double>(
                LogLevel.Debug,
                EventIds.FileReloadEnd,
                "File reload for '{filePath}' completed after {ElapsedMilliseconds}ms");

            // Error: the reload handler caught an exception.
            private static readonly Action<ILogger, string, double, Exception> _fileReloadError = LoggerMessage.Define<string, double>(
                LogLevel.Error,
                EventIds.FileReloadError,
                "File reload for '{filePath}' threw an unhandled exception after {ElapsedMilliseconds}ms");

            // Information: the model itself has been reloaded. Uses its own event id (FileReload)
            // so this message can be told apart from the FileReloadEnd debug message.
            private static readonly Action<ILogger, string, double, Exception> _fileReload = LoggerMessage.Define<string, double>(
                LogLevel.Information,
                EventIds.FileReload,
                "Reloading file '{filePath}' completed after {ElapsedMilliseconds}ms");

            /// <summary>Logs that a reload of <paramref name="filePath"/> has begun.</summary>
            public static void FileReloadBegin(ILogger logger, string filePath)
            {
                _fileReloadBegin(logger, filePath, null);
            }

            /// <summary>Logs that the reload of <paramref name="filePath"/> finished, with the elapsed time in milliseconds.</summary>
            public static void FileReloadEnd(ILogger logger, string filePath, TimeSpan duration)
            {
                _fileReloadEnd(logger, filePath, duration.TotalMilliseconds, null);
            }

            /// <summary>Logs <paramref name="exception"/> raised while reloading <paramref name="filePath"/>, with the elapsed time in milliseconds.</summary>
            public static void FileReloadError(ILogger logger, string filePath, TimeSpan duration, Exception exception)
            {
                _fileReloadError(logger, filePath, duration.TotalMilliseconds, exception);
            }

            /// <summary>Logs that the model in <paramref name="filePath"/> was reloaded, with the elapsed time in milliseconds.</summary>
            public static void ReloadingFile(ILogger logger, string filePath, TimeSpan duration)
            {
                _fileReload(logger, filePath, duration.TotalMilliseconds, null);
            }
        }
 
    }
}