File: Utilities\ResourceManagerUtils.cs
Web Access
Project: src\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj (Microsoft.ML.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Internal.Utilities
{
    /// <summary>
    /// This class takes care of downloading resources needed by ML.NET components. Resources are located in
    /// a pre-defined location, that can be overridden by defining Environment variable <see cref="CustomResourcesUrlEnvVariable"/>.
    /// </summary>
    [BestFriend]
    internal sealed class ResourceManagerUtils
    {
        // Backing field for Instance; stays null until the first access.
        private static volatile ResourceManagerUtils _instance;

        /// <summary>
        /// Gets the shared <see cref="ResourceManagerUtils"/> instance, creating it on first access.
        /// </summary>
        public static ResourceManagerUtils Instance
        {
            get
            {
                var existing = _instance;
                if (existing != null)
                    return existing;

                // Install a fresh instance only if the field is still null. A non-null return value is the
                // instance that some other thread installed first, and that is the one to hand out.
                var installedByOther = Interlocked.CompareExchange(ref _instance, new ResourceManagerUtils(), null);
                if (installedByOther != null)
                    return installedByOther;

                return _instance;
            }
        }
 
        // Base address used when no custom resources url is configured in the environment.
        private const string DefaultUrl = "https://aka.ms/mlnet-resources/";

        /// <summary>
        /// The base url resources are downloaded from: the value of the environment variable named by
        /// <see cref="CustomResourcesUrlEnvVariable"/> when it holds a non-empty string, and
        /// <see cref="DefaultUrl"/> otherwise. The environment is read on every access.
        /// </summary>
        private static string MlNetResourcesUrl
        {
            get
            {
                string customUrl = Environment.GetEnvironmentVariable(CustomResourcesUrlEnvVariable);
                return string.IsNullOrEmpty(customUrl) ? DefaultUrl : customUrl;
            }
        }
 
        /// <summary>
        /// Name of the environment variable holding a timeout period (in milliseconds) for downloading resources.
        /// When the variable is set to text that parses as an integer, that value replaces the timeout passed in
        /// by the caller; any other value is ignored.
        /// </summary>
        public const string TimeoutEnvVariable = "MICROSOFTML_RESOURCE_TIMEOUT";
 
        /// <summary>
        /// Name of the environment variable holding an optional url to download resources from. When it is set
        /// to a non-empty value, that value is used as the base url instead of the built-in default.
        /// </summary>
        public const string CustomResourcesUrlEnvVariable = "MICROSOFTML_RESOURCE_URL";
 
        /// <summary>
        /// Immutable outcome of ensuring a resource: the local file involved, plus the error text and the
        /// source url when they are known.
        /// </summary>
        public sealed class ResourceDownloadResults
        {
            /// <summary>The file name or path this result refers to, exactly as given to the constructor.</summary>
            public readonly string FileName;

            // Null or empty when no error was reported.
            internal readonly string ErrorMessage;

            // Url the download was attempted from; null when not supplied.
            internal readonly string DownloadUrl;

            /// <summary>
            /// Creates a result from its three parts; no validation is performed.
            /// </summary>
            /// <param name="fileName">The file name or path the result refers to.</param>
            /// <param name="errorMessage">The error text, or null/empty when there was no error.</param>
            /// <param name="downloadUrl">The url the download was attempted from, if any.</param>
            public ResourceDownloadResults(string fileName, string errorMessage, string downloadUrl = null)
            {
                this.DownloadUrl = downloadUrl;
                this.ErrorMessage = errorMessage;
                this.FileName = fileName;
            }
        }
 
        /// <summary>
        /// Private, so that instances cannot be created from outside this class.
        /// </summary>
        private ResourceManagerUtils() { }
 
        /// <summary>
        /// Builds a download url by appending a suffix to the configured resources base url.
        /// </summary>
        /// <param name="suffix">Text appended verbatim after the base url; no separator is inserted.</param>
        /// <returns>The base url immediately followed by <paramref name="suffix"/>.</returns>
        public static string GetUrl(string suffix) => string.Concat(MlNetResourcesUrl, suffix);
 
        /// <summary>
        /// Returns a <see cref="Task"/> that makes sure a resource is available locally, downloading it when the file is
        /// not already on disk, and reports the path it was (or should have been) saved to together with any error.
        /// </summary>
        /// <remarks>
        /// The function <see cref="ResourceManagerUtils.DownloadResource"/> checks whether or not the absolute URL with the
        /// default host "aka.ms" formed from <paramref name="relativeUrl"/> redirects to the default Microsoft homepage.
        /// As such, only absolute URLs with the host "aka.ms" are supported with <see cref="ResourceManagerUtils.EnsureResourceAsync"/>.
        /// </remarks>
        /// <param name="env">The host environment.</param>
        /// <param name="ch">A channel to provide information about the download.</param>
        /// <param name="relativeUrl">The relative url from which to download.
        /// This is appended to the url defined in <see cref="MlNetResourcesUrl"/>, separated by a single '/'.</param>
        /// <param name="fileName">The name of the file to save.</param>
        /// <param name="dir">The directory where the file should be saved to. Unless a custom location is configured through the
        /// environment, the file will be saved in a directory with the specified name inside a folder called "mlnet-resources"
        /// in the <see cref="Environment.SpecialFolder.LocalApplicationData"/> directory.</param>
        /// <param name="timeout">An integer indicating the number of milliseconds to wait before timing out while downloading a resource.</param>
        /// <returns>The download results, containing the file path where the resources was (or should have been) downloaded to, and an error message
        /// (or null if there was no error).</returns>
        /// <exception cref="ArgumentNullException">A download is needed and <paramref name="relativeUrl"/> is null.</exception>
        /// <exception cref="NotSupportedException">The resulting absolute URL does not have the host "aka.ms".</exception>
        public async Task<ResourceDownloadResults> EnsureResourceAsync(IHostEnvironment env, IChannel ch, string relativeUrl, string fileName, string dir, int timeout)
        {
            var filePath = GetFilePath(ch, fileName, dir, out var error);

            // Nothing to download when the file is already there, and no point trying when its directory is unusable.
            if (File.Exists(filePath) || !string.IsNullOrEmpty(error))
                return new ResourceDownloadResults(filePath, error);

            if (relativeUrl == null)
                throw new ArgumentNullException(nameof(relativeUrl));

            // Join the two parts with exactly one '/'. Path.Combine must not be used for URLs: it is a file-system API
            // that inserts the platform's directory separator ('\' on Windows) and drops the base entirely when the
            // second part looks like a rooted path.
            var baseUrl = MlNetResourcesUrl;
            var combinedUrl = baseUrl.TrimEnd('/') + "/" + relativeUrl.TrimStart('/');
            if (!Uri.TryCreate(combinedUrl, UriKind.Absolute, out var absoluteUrl))
            {
                return new ResourceDownloadResults(filePath,
                    $"Could not create a valid URI from the base URI '{baseUrl}' and the relative URI '{relativeUrl}'");
            }

            if (absoluteUrl.Host != "aka.ms")
                throw new NotSupportedException("The function ResourceManagerUtils.EnsureResourceAsync only supports downloading from URLs of the host \"aka.ms\"");

            var downloadError = await DownloadFromUrlWithRetryAsync(env, ch, absoluteUrl.AbsoluteUri, fileName, timeout, filePath);
            return new ResourceDownloadResults(filePath, downloadError, absoluteUrl.AbsoluteUri);
        }
 
        /// <summary>
        /// Calls <see cref="DownloadFromUrlAsync"/> up to <paramref name="retryTimes"/> times, stopping at the first
        /// attempt that reports no error.
        /// </summary>
        /// <param name="env">The host environment, passed through to the download.</param>
        /// <param name="ch">Channel that receives a warning for every attempt that throws.</param>
        /// <param name="url">The absolute url to download from.</param>
        /// <param name="fileName">The resource name, used in log messages.</param>
        /// <param name="timeout">Timeout in milliseconds for each individual attempt.</param>
        /// <param name="filePath">Destination path of the downloaded file.</param>
        /// <param name="retryTimes">Maximum number of attempts.</param>
        /// <returns>Null or empty when an attempt succeeded; otherwise the messages of all failed attempts,
        /// one per line.</returns>
        private async Task<string> DownloadFromUrlWithRetryAsync(IHostEnvironment env, IChannel ch, string url, string fileName,
            int timeout, string filePath, int retryTimes = 5)
        {
            const int retryDelayMilliseconds = 10 * 1000;
            var downloadResult = "";

            for (int i = 0; i < retryTimes; ++i)
            {
                try
                {
                    var thisDownloadResult = await DownloadFromUrlAsync(env, ch, url, fileName, timeout, filePath);

                    if (string.IsNullOrEmpty(thisDownloadResult))
                        return thisDownloadResult;

                    // This has to be a regular literal: the verbatim form @"\n" is a backslash followed by the
                    // letter 'n', which would glue all the messages onto one line.
                    downloadResult += thisDownloadResult + "\n";

                    // Back off before the next attempt; after the last one there is nothing left to wait for.
                    if (i + 1 < retryTimes)
                        await Task.Delay(retryDelayMilliseconds);
                }
                catch (Exception ex)
                {
                    downloadResult += $"DownloadFailed with exception {ex.Message}" + "\n";
                    // ignore any Exception and retrying download
                    ch.Warning($"{i + 1} - th try: Dowload {fileName} from {url} fail with exception {ex.Message}");
                }
            }

            return downloadResult;
        }
 
        /// <summary>
        /// Runs a single download attempt and gives up once the timeout has elapsed. The timeout is
        /// <paramref name="timeout"/> milliseconds unless overridden through <see cref="TimeoutEnvVariable"/>.
        /// </summary>
        /// <returns>Returns the error message if an error occurred, null if download was successful.</returns>
        private async Task<string> DownloadFromUrlAsync(IHostEnvironment env, IChannel ch, string url, string fileName, int timeout, string filePath)
        {
            using (var client = new HttpClient())
            using (var downloadCancel = new CancellationTokenSource())
            {
                // Create the timeout before starting the download: Task.Delay rejects an invalid timeout by throwing,
                // and that must not leave a running download behind that uses the disposed client.
                UpdateTimeout(ref timeout);
                var timeoutTask = Task.Delay(timeout).ContinueWith(task => default(Exception), TaskScheduler.Default);

                var t = Task.Run(() => DownloadResource(env, ch, client, new Uri(url), filePath, fileName, downloadCancel.Token));
                ch.Info($"Downloading {fileName} from {url} to {filePath}");

                var completedTask = await Task.WhenAny(t, timeoutTask);
                if (completedTask == t && completedTask.CompletedResult() == null)
                    return null;

                // Either the download reported an error or the timeout won the race.
                downloadCancel.Cancel();
                if (File.Exists(filePath))
                    TryDelete(ch, filePath);

                var failure = await t;
                if (failure != null)
                    return failure.Message;

                // The timeout fired, yet the download task ended without an error, so there is no exception to
                // take a message from. Decide by what is on disk now.
                return File.Exists(filePath) ? null : $"{fileName}: Download timed out";
            }
        }
 
        /// <summary>
        /// Best-effort file removal: any exception thrown by the deletion is swallowed.
        /// </summary>
        /// <param name="ch">Channel that receives a warning when the deletion fails.</param>
        /// <param name="filePath">The file to remove.</param>
        /// <param name="warn">When false, a failed deletion is ignored without any message.</param>
        private static void TryDelete(IChannel ch, string filePath, bool warn = true)
        {
            try
            {
                File.Delete(filePath);
            }
            catch (Exception e)
            {
                if (!warn)
                    return;

                ch.Warning($"File '{filePath}' could not be deleted: {e.Message}");
            }
        }
 
        /// <summary>
        /// Replaces <paramref name="timeout"/> with the value of the environment variable named by
        /// <see cref="TimeoutEnvVariable"/> when that variable holds text that parses as an integer;
        /// leaves it untouched otherwise.
        /// </summary>
        private static void UpdateTimeout(ref int timeout)
        {
            string configured = Environment.GetEnvironmentVariable(TimeoutEnvVariable);
            if (string.IsNullOrWhiteSpace(configured))
                return;

            if (int.TryParse(configured, out int parsed))
                timeout = parsed;
        }
 
        /// <summary>
        /// Get the path where the resource should be downloaded to, creating its directory when it is missing.
        /// If the environment variable named by <see cref="Utils.CustomSearchDirEnvVariable"/> is defined, download to
        /// the "dir" directory inside the location defined there. Otherwise, download to the "dir" directory
        /// inside <see cref="Environment.SpecialFolder.LocalApplicationData"/>\mlnet-resources\.
        /// </summary>
        /// <param name="ch">Not used by this method.</param>
        /// <param name="fileName">The name of the file inside the target directory.</param>
        /// <param name="dir">The sub-directory the file belongs to.</param>
        /// <param name="error">Null on success; otherwise a description of the directory that could not be created.</param>
        /// <returns>The full path of the file. It is returned even when <paramref name="error"/> is set.</returns>
        private static string GetFilePath(IChannel ch, string fileName, string dir, out string error)
        {
            // Octal 0700 (read/write/execute for the owner only), in decimal because C# has no octal literals.
            const int ownerOnlyMode = 448;

            var envDir = Environment.GetEnvironmentVariable(Utils.CustomSearchDirEnvVariable);
            var appDataBaseDir = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData);
            var appDataDir = Path.Combine(appDataBaseDir, "mlnet-resources");
            var absDir = Path.Combine(string.IsNullOrEmpty(envDir) ? appDataDir : envDir, dir);
            var filePath = Path.Combine(absDir, fileName);
            error = null;

            // The application data folder only matters when no custom location is configured.
            if (!Directory.Exists(appDataBaseDir) && string.IsNullOrEmpty(envDir))
            {
                try
                {
                    Directory.CreateDirectory(appDataBaseDir);

                    // On unix, create with 0700 perms as per XDG base dir spec
                    if (Environment.OSVersion.Platform == PlatformID.Unix)
                        chmod(appDataBaseDir, ownerOnlyMode);
                }
                catch (Exception e)
                {
                    error = $"Error trying to create directory {appDataBaseDir}: {e.Message}.\nPlease fix your " +
                        "filesystem permissions, or try setting the " +
                        $"'{Utils.CustomSearchDirEnvVariable}' environment variable to a writable folder";
                    return filePath;
                }
            }

            // TODO: Also confirm write permission on the directory (maybe use Security.Permissions.FileIOPermission)
            if (!Directory.Exists(absDir))
            {
                try
                {
                    Directory.CreateDirectory(absDir);

                    // On unix, create with 0700 perms as per XDG base dir spec. The permissions go on the directory
                    // that was just created; the application data folder may be a pre-existing, unrelated directory.
                    if (Environment.OSVersion.Platform == PlatformID.Unix)
                        chmod(absDir, ownerOnlyMode);
                }
                catch (Exception e)
                {
                    error = $"Error trying to create directory {absDir}: {e.Message}.\nPlease try setting the " +
                        $"'{Utils.CustomSearchDirEnvVariable}' environment variable to a writable folder";
                }
            }
            return filePath;
        }
 
        /// <summary>
        /// Downloads <paramref name="uri"/> into <paramref name="path"/>, holding a named mutex derived from
        /// <paramref name="fileName"/> for the duration so that the same resource is fetched by one caller at a time.
        /// </summary>
        /// <remarks>
        /// The work is done synchronously and the returned task is already completed. A <see cref="Mutex"/> has to be
        /// released by the thread that acquired it, while the code after an <c>await</c> may run on a different thread,
        /// so nothing may be awaited between acquiring and releasing the mutex. The caller in this class starts this
        /// method through <c>Task.Run</c>, so the blocking happens on a thread-pool thread.
        /// </remarks>
        /// <param name="env">The host environment, used to create the temporary output file.</param>
        /// <param name="ch">Channel for progress and error messages.</param>
        /// <param name="httpClient">The client that performs the request.</param>
        /// <param name="uri">The absolute uri to download.</param>
        /// <param name="path">The final location of the downloaded file.</param>
        /// <param name="fileName">The resource name, used for the mutex name and in log messages.</param>
        /// <param name="ct">Token that cancels the request and the copy to disk.</param>
        /// <returns>A task whose result is null when the file is in place, or the exception describing why the
        /// download failed. Failures that are not reported as a result fault the task instead.</returns>
        private Task<Exception> DownloadResource(IHostEnvironment env, IChannel ch, HttpClient httpClient, Uri uri, string path, string fileName, CancellationToken ct)
        {
            try
            {
                return Task.FromResult(DownloadResourceUnderMutex(env, ch, httpClient, uri, path, fileName, ct));
            }
            catch (Exception e)
            {
                // Surface the failure through the returned task rather than throwing synchronously.
                return Task.FromException<Exception>(e);
            }
        }

        // Synchronous body of DownloadResource; deliberately free of awaits so the mutex stays on one thread.
        private Exception DownloadResourceUnderMutex(IHostEnvironment env, IChannel ch, HttpClient httpClient, Uri uri, string path, string fileName, CancellationToken ct)
        {
            if (File.Exists(path))
                return null;

            using (var mutex = new Mutex(false, "Resource" + fileName))
            {
                try
                {
                    mutex.WaitOne();
                }
                catch (AbandonedMutexException)
                {
                    // A previous owner went away without releasing the mutex. Ownership has still been handed to
                    // this thread, so the download can go ahead.
                }

                string tempPath = null;
                try
                {
                    // Another caller may have completed the download while this one was waiting for the mutex.
                    if (File.Exists(path))
                        return null;

                    const int blockSize = 4096;
                    tempPath = Path.GetFullPath(Path.Combine(Path.GetDirectoryName(path), "temp-resource-" + Guid.NewGuid().ToString()));

                    using (var response = httpClient.GetAsync(uri, ct).GetAwaiter().GetResult())
                    {
                        response.EnsureSuccessStatusCode();
                        if (uri.Host == "aka.ms" && IsRedirectToDefaultPage(uri.AbsoluteUri))
                            throw new NotSupportedException($"The provided url ({uri}) redirects to the default url ({DefaultUrl})");

                        // Write to a temporary file first, so that a partial download never appears under the final name.
                        using (var fh = env.CreateOutputFile(tempPath))
                        using (var ws = fh.CreateWriteStream())
                        using (var stream = response.Content.ReadAsStreamAsync().GetAwaiter().GetResult())
                        {
                            stream.CopyToAsync(ws, blockSize, ct).GetAwaiter().GetResult();

                            if (ct.IsCancellationRequested)
                            {
                                ch.Error($"{fileName}: Download timed out");
                                return ch.Except("Download timed out");
                            }
                        }
                    }

                    File.Move(tempPath, path);
                    ch.Info($"{fileName}: Download complete");
                    return null;
                }
                catch (Exception e) when (e is WebException || e is HttpRequestException)
                {
                    // HttpClient reports request failures as HttpRequestException; WebException is kept for the
                    // redirect check, which goes through WebRequest.
                    ch.Error($"{fileName}: Could not download. HttpClient returned the following error: {e.Message}");
                    return e;
                }
                finally
                {
                    if (tempPath != null)
                        TryDelete(ch, tempPath, warn: false);
                    mutex.ReleaseMutex();
                }
            }
        }
 
        /// <summary>This method checks whether or not the provided aka.ms url redirects to
        /// Microsoft's homepage, as the default faulty aka.ms URLs redirect to https://www.microsoft.com/?ref=aka .</summary>
        /// <remarks>
        /// The request is sent with automatic redirects switched off. Depending on the runtime, the redirect response
        /// is then either returned normally or delivered inside a <see cref="WebException"/>; both are inspected.
        /// </remarks>
        /// <param name="url"> The provided url to check </param>
        /// <returns>True when the server answers with a redirect to https://www.microsoft.com/?ref=aka; false in every
        /// other case, including urls that are not HTTP(S) and requests that fail without any HTTP response.</returns>
        public bool IsRedirectToDefaultPage(string url)
        {
            try
            {
                // Only HTTP(S) requests can be redirected to the default aka.ms webpage; for any other scheme
                // (a file url, for instance) the cast yields null.
                var httpWebRequest = WebRequest.Create(url) as HttpWebRequest;
                if (httpWebRequest == null)
                    return false;

                httpWebRequest.AllowAutoRedirect = false;
                using (var response = httpWebRequest.GetResponse() as HttpWebResponse)
                    return IsDefaultPageRedirect(response);
            }
            catch (WebException e)
            {
                // Response is null when the failure happened before any HTTP response arrived (DNS, connect, timeout).
                using (var response = e.Response as HttpWebResponse)
                    return IsDefaultPageRedirect(response);
            }
        }

        // True when the response is a redirect whose target is the page aka.ms sends unknown links to.
        private static bool IsDefaultPageRedirect(HttpWebResponse response)
        {
            return response != null
                && response.StatusCode == HttpStatusCode.Redirect
                && response.Headers["Location"] == "https://www.microsoft.com/?ref=aka";
        }
 
        /// <summary>
        /// Finds the first result that carries an error and formats a message for it.
        /// </summary>
        /// <param name="errorMessage">Receives the formatted message, or null when no result carries an error.
        /// The download url is included in the message when the result has one.</param>
        /// <param name="result">The results to inspect, in order.</param>
        /// <returns>The first result with a non-empty error message, or null when there is none.</returns>
        public static ResourceDownloadResults GetErrorMessage(out string errorMessage, params ResourceDownloadResults[] result)
        {
            var firstFailure = result.FirstOrDefault(candidate => !string.IsNullOrEmpty(candidate.ErrorMessage));
            if (firstFailure == null)
            {
                errorMessage = null;
                return null;
            }

            errorMessage = string.IsNullOrEmpty(firstFailure.DownloadUrl)
                ? $"Error downloading resource: {firstFailure.ErrorMessage}"
                : $"Error downloading resource from '{firstFailure.DownloadUrl}': {firstFailure.ErrorMessage}";
            return firstFailure;
        }
 
#pragma warning disable IDE1006
        // P/Invoke declaration bound to "chmod" in "libc". Within this file it is only called when
        // Environment.OSVersion.Platform is PlatformID.Unix, with mode 448 (octal 0700). The int return value
        // is not inspected by those callers.
        // NOTE(review): the IDE1006 suppression is presumably there because the method keeps the lowercase
        // native name - confirm against the project's analyzer configuration.
        [DllImport("libc", SetLastError = true)]
        private static extern int chmod(string pathname, int mode);
#pragma warning restore IDE1006
    }
}