// File: Data\Repository.cs
// Web Access
// Project: src\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj (Microsoft.ML.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Reflection;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML;
 
/// <summary>
/// For saving a model into a repository.
/// Classes implementing <see cref="ICanSaveModel"/> should do an explicit implementation of <see cref="Save(ModelSaveContext)"/>.
/// Classes inheriting <see cref="ICanSaveModel"/> from a base class should override the function invoked by <see cref="Save(ModelSaveContext)"/>
/// in that base class, if there is one.
/// </summary>
public interface ICanSaveModel
{
    /// <summary>
    /// Saves this object using the given save context.
    /// </summary>
    /// <param name="ctx">The context to save into.</param>
    void Save(ModelSaveContext ctx);
}
 
/// <summary>
/// For saving to a single stream. Note that this interface is mostly deprecated in favor of
/// saving more comprehensive and composable "model" objects, via <see cref="ICanSaveModel"/>.
/// </summary>
[BestFriend]
internal interface ICanSaveInBinaryFormat
{
    /// <summary>
    /// Saves this object through the given binary writer.
    /// </summary>
    /// <param name="writer">The writer to save to.</param>
    void SaveAsBinary(BinaryWriter writer);
}
 
/// <summary>
/// Abstraction around a <see cref="ZipArchive"/> or other hierarchical storage.
/// The repository tracks the entries that are currently open and, optionally, owns a temporary
/// directory used to back entry streams. Both are released when the repository is disposed.
/// </summary>
[BestFriend]
internal abstract class Repository : IDisposable
{
    /// <summary>
    /// A single open entry of a <see cref="Repository"/>. Disposing the entry notifies the owning
    /// repository, which is in charge of closing <see cref="Stream"/>.
    /// </summary>
    public sealed class Entry : IDisposable
    {
        // The parent repository. Cleared once this entry has been disposed, which makes
        // any further Dispose call a no-op.
        private Repository _rep;

        /// <summary>
        /// The relative path of this entry.
        /// </summary>
        public string Path { get; }

        /// <summary>
        /// The stream for this entry. This is either a memory stream or a file stream in
        /// the temporary directory. In either case, it is seekable and NOT the actual
        /// archive stream.
        /// </summary>
        public Stream Stream { get; }

        internal Entry(Repository rep, string path, Stream stream)
        {
            _rep = rep;
            Path = path;
            Stream = stream;
        }

        /// <summary>
        /// Notifies the owning repository that this entry is done. Only the first call has an effect.
        /// </summary>
        public void Dispose()
        {
            if (_rep != null)
            {
                // Tell the repository that we're disposed. Note that the repository "owns" the stream
                // so is in charge of closing it.
                _rep.OnDispose(this);
                _rep = null;
            }
        }
    }

    // These are the open entries that may contain streams into our DirTemp.
    private readonly List<Entry> _open;

    private bool _disposed;

    private readonly IExceptionContext _ectx;

    // This is a temporary directory that we create. It is essentially treated like an un-managed resource,
    // hence the need for the complete dispose pattern. Note that it is optional - if we use memory
    // streams for everything, we don't need it. This ability is needed for Scope or other environments
    // where access to the file system is restricted.
    protected readonly string DirTemp;

    // Maps from relative path to the corresponding absolute path in the temp directory.
    // This is populated as we decompress streams in the archive, so we don't de-compress
    // more than once.
    // REVIEW: Should we garbage collect to some degree? Currently we don't delete any
    // of these temp files until the repository is disposed.
    protected readonly ConcurrentDictionary<string, string> PathMap;

    /// <summary>
    /// Exception context.
    /// </summary>
    public IExceptionContext ExceptionContext => _ectx;

    /// <summary>
    /// Whether <see cref="Dispose(bool)"/> has already completed for this repository.
    /// </summary>
    protected bool Disposed => _disposed;

    /// <summary>
    /// Initializes the shared repository state.
    /// </summary>
    /// <param name="needDir">When true, a fresh temporary directory is created and assigned to
    /// <see cref="DirTemp"/>. When false, <see cref="DirTemp"/> stays null and finalization is
    /// suppressed, since there is no directory to clean up.</param>
    /// <param name="ectx">The exception context used for checks and asserts; may be null.</param>
    internal Repository(bool needDir, IExceptionContext ectx)
    {
        Contracts.AssertValueOrNull(ectx);
        _ectx = ectx;

        PathMap = new ConcurrentDictionary<string, string>();
        _open = new List<Entry>();
        if (needDir)
            DirTemp = GetShortTempDir(ectx);
        else
            GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Creates and returns a new directory of the form <c>{temp}/ml_dotnet{N}/{random}</c>, where
    /// <c>{temp}</c> comes from the host environment when <paramref name="ectx"/> is an
    /// <see cref="IHostEnvironmentInternal"/> and from <see cref="Path.GetTempPath"/> otherwise,
    /// and <c>N</c> is the smallest non-negative number for which <c>ml_dotnet{N}</c> does not exist yet.
    /// </summary>
    private static string GetShortTempDir(IExceptionContext ectx)
    {
        string tempPath = ectx is IHostEnvironmentInternal iHostInternal ?
            iHostInternal.TempFilePath :
            Path.GetTempPath();
        string baseDir = Path.GetFullPath(tempPath);

        // Probe ml_dotnet0, ml_dotnet1, ... until we hit a name that is not taken.
        int dirNumber = 0;
        string mlNetTempDir;
        do
        {
            mlNetTempDir = Path.Combine(baseDir, $"ml_dotnet{dirNumber++}");
        }
        while (Directory.Exists(mlNetTempDir));

        // CreateDirectory creates the ml_dotnet{N} parent along with the random leaf.
        var path = Path.Combine(mlNetTempDir, Path.GetRandomFileName());
        Directory.CreateDirectory(path);
        return path;
    }

    ~Repository()
    {
        if (!Disposed)
            Dispose(false);
    }

    /// <summary>
    /// Disposes the repository: closes all open entries and removes the temporary directory, if any.
    /// Calling this more than once is harmless.
    /// </summary>
    public void Dispose()
    {
        if (!Disposed)
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }
    }

    /// <summary>
    /// Releases the open entries and the temporary directory. Must only be called while the
    /// repository is not yet disposed. Overrides should call this base implementation.
    /// </summary>
    /// <param name="disposing">True when called from <see cref="Dispose()"/>, false when called
    /// from the finalizer.</param>
    protected virtual void Dispose(bool disposing)
    {
        _ectx.Assert(!Disposed);

        // Close all temp files.
        try
        {
            DisposeAllEntries();
        }
        catch
        {
            _ectx.Assert(false, "Closing entries should not throw!");
        }

        // Delete the temp directory.
        if (DirTemp != null)
        {
            string parent = Path.GetDirectoryName(DirTemp);
            TryDeleteDirectory(DirTemp, recursive: true);

            // GetShortTempDir nests DirTemp inside an "ml_dotnet{N}" directory that it also creates.
            // Remove that parent as well, otherwise every repository leaves an empty directory behind
            // and later repositories have to probe past all of them. The delete is non-recursive, so
            // it only succeeds when nothing else lives in that directory.
            if (!string.IsNullOrEmpty(parent))
                TryDeleteDirectory(parent, recursive: false);
        }

        _disposed = true;
    }

    /// <summary>
    /// Best-effort directory removal. Failures are deliberately ignored, since this runs as part of
    /// disposal (possibly from the finalizer), which should not throw.
    /// </summary>
    private static void TryDeleteDirectory(string path, bool recursive)
    {
        try
        {
            Directory.Delete(path, recursive);
        }
        catch
        {
        }
    }

    /// <summary>
    /// Force all open entries to be disposed.
    /// </summary>
    protected void DisposeAllEntries()
    {
        // Disposing an entry is expected to remove it from _open (see OnDispose / RemoveEntry),
        // which is what makes this loop terminate.
        while (_open.Count > 0)
        {
            var ent = _open[_open.Count - 1];
            ent.Dispose();
        }
    }

    /// <summary>
    /// Remove the entry from _open. Note that under normal access patterns, entries are LIFO,
    /// so we search from the end of _open.
    /// </summary>
    protected void RemoveEntry(Entry ent)
    {
        for (int i = _open.Count; --i >= 0;)
        {
            if (_open[i] == ent)
            {
                _open.RemoveAt(i);
                return;
            }
        }
        _ectx.Assert(false, "Why wasn't the entry found?");
    }

    /// <summary>
    /// The entry is being disposed. Note that overrides should always call RemoveEntry, in addition to whatever
    /// they need to do with the corresponding stream.
    /// </summary>
    protected abstract void OnDispose(Entry ent);

    /// <summary>
    /// When considering entries inside one of our model archives, we want to ensure that we
    /// use a consistent directory separator. Zip archives are stored as flat lists of entries.
    /// When we load those entries into our look-up dictionary, we normalize forward slashes to
    /// the platform's <see cref="Path.DirectorySeparatorChar"/>. Backward slashes are left as they are.
    /// </summary>
    protected static string NormalizeForArchiveEntry(string path) => path?.Replace('/', Path.DirectorySeparatorChar);

    /// <summary>
    /// When building paths to our local file system, we want to force both forward and backward slashes
    /// to the system directory separator character. We do this for cases where we either used Windows-specific
    /// path building logic, or concatenated filesystem paths with zip archive entries on Linux.
    /// </summary>
    private static string NormalizeForFileSystem(string path) =>
        path?.Replace('/', Path.DirectorySeparatorChar).Replace('\\', Path.DirectorySeparatorChar);

    /// <summary>
    /// Constructs both the relative path to the entry and the absolute path of a corresponding
    /// temporary file. If createDir is true, makes sure the directory exists within the temp directory.
    /// </summary>
    /// <param name="pathEnt">Receives the relative entry path, normalized with
    /// <see cref="NormalizeForArchiveEntry"/>.</param>
    /// <param name="pathTemp">Receives the absolute path of a randomly named file directly under
    /// <see cref="DirTemp"/>, or null when there is no temp directory.</param>
    /// <param name="dir">Optional directory of the entry; must not contain "..".</param>
    /// <param name="name">Name of the entry; must be non-blank and must not contain "..".</param>
    /// <param name="createDir">Whether to create the entry's directory under the temp directory.</param>
    protected void GetPath(out string pathEnt, out string pathTemp, string dir, string name, bool createDir)
    {
        _ectx.Assert(!Disposed);
        _ectx.CheckValueOrNull(dir);
        _ectx.CheckParam(dir == null || !dir.Contains(".."), nameof(dir));
        _ectx.CheckParam(!string.IsNullOrWhiteSpace(name), nameof(name));
        _ectx.CheckParam(!name.Contains(".."), nameof(name));

        // The gymnastics below are meant to deal with bad invocations including absolute paths, etc.
        // That's why we go through it even if DirTemp is null.
        string root = Path.GetFullPath(DirTemp ?? @"x:\dummy");
        string entityPath = Path.GetFullPath(Path.Combine(root, dir ?? "", name));
        string tempPath = Path.GetFullPath(Path.Combine(root, Path.GetRandomFileName()));

        string parent = Path.GetDirectoryName(entityPath);
        _ectx.Check(parent != null);
        // Paths are compared as raw character sequences; a culture-sensitive comparison
        // is not appropriate for deciding whether the entry stays under the root.
        _ectx.Check(parent.StartsWith(root, StringComparison.Ordinal));

        // The character right after the root must be a separator, so that a sibling directory
        // whose name merely starts with the root's name is rejected.
        int ichSplit = root.Length;
        _ectx.Check(entityPath.Length > ichSplit && entityPath[ichSplit] == Path.DirectorySeparatorChar);

        if (createDir && DirTemp != null && parent.Length > ichSplit)
            Directory.CreateDirectory(parent);

        // Get the relative path portion. This is the archive entry name.
        pathEnt = entityPath.Substring(ichSplit + 1);
        _ectx.Check(Utils.Size(pathEnt) > 0);
        _ectx.Check(entityPath == Path.Combine(root, pathEnt));

        // Set pathTemp to non-null iff DirTemp is non-null.
        pathTemp = DirTemp != null ? tempPath : null;

        pathEnt = NormalizeForArchiveEntry(pathEnt);
        pathTemp = NormalizeForFileSystem(pathTemp);
    }

    /// <summary>
    /// Wraps the stream in a new <see cref="Entry"/> owned by this repository and records it as open.
    /// </summary>
    protected Entry AddEntry(string pathEnt, Stream stream)
    {
        _ectx.Assert(!Disposed);
        _ectx.AssertValue(stream);

        var ent = new Entry(this, pathEnt, stream);
        _open.Add(ent);
        return ent;
    }
}
 
/// <summary>
/// A <see cref="Repository"/> that writes entries into a new zip archive on top of a caller-supplied
/// stream. Entries are first written to their own seekable stream (a temp file or a memory stream)
/// and copied into the archive once they have been closed.
/// </summary>
[BestFriend]
internal sealed class RepositoryWriter : Repository
{
    private const string DirTrainingInfo = "TrainingInfo";

    private ZipArchive _archive;

    // Entries that have been disposed by their user but not yet copied into the archive.
    // Set to null once the writer is disposed.
    private Queue<KeyValuePair<string, Stream>> _closed;

    /// <summary>
    /// Creates a writer over <paramref name="stream"/> and writes the product version into
    /// the "TrainingInfo/Version.txt" entry.
    /// </summary>
    /// <param name="stream">The stream the zip archive is written to. It is left open by the writer.</param>
    /// <param name="ectx">Optional exception context.</param>
    /// <param name="useFileSystem">Whether entries are backed by temp files (true) or memory streams (false).</param>
    public static RepositoryWriter CreateNew(Stream stream, IExceptionContext ectx = null, bool useFileSystem = true)
    {
        Contracts.CheckValueOrNull(ectx);
        ectx.CheckValue(stream, nameof(stream));
        var rep = new RepositoryWriter(stream, ectx, useFileSystem);

        try
        {
            using (var ent = rep.CreateEntry(DirTrainingInfo, "Version.txt"))
            using (var writer = Utils.OpenWriter(ent.Stream))
                writer.WriteLine(GetProductVersion());
        }
        catch
        {
            // The caller never gets to see the writer, so nobody else can release its
            // temp directory and archive before finalization.
            rep.Dispose();
            throw;
        }
        return rep;
    }

    private RepositoryWriter(Stream stream, IExceptionContext ectx, bool useFileSystem = true)
        : base(useFileSystem, ectx)
    {
        _archive = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: true);
        _closed = new Queue<KeyValuePair<string, Stream>>();
    }

    /// <summary>
    /// Creates a new entry with the given name and no directory.
    /// </summary>
    public Entry CreateEntry(string name)
    {
        return CreateEntry(null, name);
    }

    /// <summary>
    /// Creates a new entry. Any previously closed entries are written to the archive first.
    /// Throws if an entry with the same path was already created.
    /// </summary>
    public Entry CreateEntry(string dir, string name)
    {
        ExceptionContext.Check(!Disposed);

        Flush();

        GetPath(out string pathEnt, out string pathTemp, dir, name, true);
        if (!PathMap.TryAdd(pathEnt, pathTemp))
            throw ExceptionContext.ExceptParam(nameof(name), "Duplicate entry: '{0}'", pathEnt);

        // pathTemp is null exactly when there is no temp directory; fall back to memory then.
        Stream stream = pathTemp != null
            ? new FileStream(pathTemp, FileMode.CreateNew)
            : (Stream)new MemoryStream();

        return AddEntry(pathEnt, stream);
    }

    // The entry is being disposed. Note that this isn't supposed to throw, so we simply queue
    // the stream so it can be written to the archive when it IS legal to throw.
    protected override void OnDispose(Entry ent)
    {
        ExceptionContext.AssertValue(ent);
        RemoveEntry(ent);

        // _closed is null once the writer itself is being disposed; at that point nothing
        // will be written to the archive anymore, so just close the stream.
        if (_closed != null)
            _closed.Enqueue(new KeyValuePair<string, Stream>(ent.Path, ent.Stream));
        else
            ent.Stream.CloseEx();
    }

    /// <summary>
    /// Closes any streams still waiting to be flushed (without writing them), disposes the archive
    /// and then lets the base class close the open entries and remove the temp directory.
    /// </summary>
    protected override void Dispose(bool disposing)
    {
        ExceptionContext.Assert(!Disposed);

        if (_closed != null)
        {
            while (_closed.Count > 0)
            {
                var kvp = _closed.Dequeue();
                kvp.Value.CloseEx();
            }
            _closed = null;
        }

        if (_archive != null)
        {
            // Disposal is not supposed to throw, so failures are deliberately ignored here.
            // Commit disposes the archive itself beforehand so that it can report them.
            try
            {
                _archive.Dispose();
            }
            catch
            {
            }
            _archive = null;
        }

        // Close all the streams.
        base.Dispose(disposing);
    }

    // Write "closed" entries to the archive.
    private void Flush()
    {
        ExceptionContext.Assert(!Disposed);
        ExceptionContext.AssertValue(_closed);
        ExceptionContext.AssertValue(_archive);

        while (_closed.Count > 0)
        {
            string path = null;
            var kvp = _closed.Dequeue();
            using (var src = kvp.Value)
            {
                // Remember the backing file, if any, so it can be removed once the stream is closed.
                if (src is FileStream fs)
                    path = fs.Name;

                var ae = _archive.CreateEntry(kvp.Key);
                using (var dst = ae.Open())
                {
                    src.Position = 0;
                    src.CopyTo(dst);
                }
            }

            if (!string.IsNullOrEmpty(path))
                File.Delete(path);
        }
    }

    /// <summary>
    /// Commit the writing of the repository. This signals successful completion of the write.
    /// The writer is disposed afterwards, whether or not finalizing the archive succeeded.
    /// </summary>
    public void Commit()
    {
        ExceptionContext.Check(!Disposed);
        ExceptionContext.AssertValue(_closed);

        DisposeAllEntries();
        Flush();

        try
        {
            // Disposing the archive is what finishes writing it. Do it here, where throwing is
            // legal, instead of relying on Dispose(bool), which swallows every failure and would
            // make a failed final write look like a successful commit.
            var archive = _archive;
            _archive = null;
            archive?.Dispose();
        }
        finally
        {
            // The public Dispose also suppresses finalization.
            Dispose();
        }
    }

    /// <summary>
    /// Returns the first constructor argument of the <see cref="AssemblyInformationalVersionAttribute"/>
    /// declared on this type's assembly, as a string.
    /// </summary>
    /// <exception cref="ApplicationException">The assembly does not declare the attribute.</exception>
    private static string GetProductVersion()
    {
        var assembly = typeof(RepositoryWriter).Assembly;

        var informationalVersionAttribute = assembly.CustomAttributes.FirstOrDefault(a =>
                a.AttributeType == typeof(AssemblyInformationalVersionAttribute));

        if (informationalVersionAttribute == null)
        {
            throw new ApplicationException($"Cannot determine product version from assembly {assembly.FullName}.");
        }

        return informationalVersionAttribute.ConstructorArguments
            .First()
            .Value
            .ToString();
    }
}
 
/// <summary>
/// A <see cref="Repository"/> that reads entries from a zip archive on top of a caller-supplied
/// stream. Each opened entry is extracted to its own seekable stream (a temp file or a memory stream).
/// </summary>
[BestFriend]
internal sealed class RepositoryReader : Repository
{
    private readonly ZipArchive _archive;

    // Maps from a normalized path to the entry in the _archive. This is needed since
    // a zip might use / or \ for directory separation.
    private readonly Dictionary<string, ZipArchiveEntry> _entries;

    /// <summary>
    /// Opens a reader over the zip archive contained in <paramref name="stream"/>.
    /// </summary>
    /// <param name="stream">The stream holding the zip archive. It is left open by the reader.</param>
    /// <param name="ectx">Optional exception context.</param>
    /// <param name="useFileSystem">Whether entries are extracted to temp files (true) or memory streams (false).</param>
    public static RepositoryReader Open(Stream stream, IExceptionContext ectx = null, bool useFileSystem = true)
    {
        Contracts.CheckValueOrNull(ectx);
        ectx.CheckValue(stream, nameof(stream));
        return new RepositoryReader(stream, ectx, useFileSystem);
    }

    private RepositoryReader(Stream stream, IExceptionContext ectx, bool useFileSystem)
        : base(useFileSystem, ectx)
    {
        try
        {
            _archive = new ZipArchive(stream, ZipArchiveMode.Read, true);
        }
        catch (Exception ex)
        {
            // The base constructor may already have created the temp directory, and the caller
            // never receives this instance, so release it here rather than waiting for finalization.
            Dispose();
            throw ExceptionContext.ExceptDecode(ex, "Failed to open a zip archive");
        }

        _entries = new Dictionary<string, ZipArchiveEntry>();
        foreach (var entry in _archive.Entries)
        {
            var path = NormalizeForArchiveEntry(entry.FullName);
            _entries[path] = entry;
        }
    }

    /// <summary>
    /// Opens the entry with the given name and no directory. Throws if it does not exist.
    /// </summary>
    public Entry OpenEntry(string name)
    {
        return OpenEntry(null, name);
    }

    /// <summary>
    /// Opens the entry with the given directory and name. Throws if it does not exist.
    /// </summary>
    public Entry OpenEntry(string dir, string name)
    {
        var ent = OpenEntryOrNull(dir, name);
        if (ent != null)
            return ent;

        // Recompute the entry path only to report it in the error message.
        GetPath(out string pathEnt, out _, dir, name, false);
        throw ExceptionContext.Except("Repository doesn't contain entry {0}", pathEnt);
    }

    /// <summary>
    /// Opens the entry with the given name and no directory, or returns null if it does not exist.
    /// </summary>
    public Entry OpenEntryOrNull(string name)
    {
        return OpenEntryOrNull(null, name);
    }

    /// <summary>
    /// Opens the entry with the given directory and name, or returns null if the archive has no such entry.
    /// When a temp directory is in use, the entry is extracted to a temp file the first time it is
    /// opened and that file is reused by later opens; otherwise it is copied into a memory stream.
    /// </summary>
    public Entry OpenEntryOrNull(string dir, string name)
    {
        ExceptionContext.Check(!Disposed);

        GetPath(out string pathEnt, out string pathTemp, dir, name, false);

        Stream stream;
        // Already-extracted files are tracked under the lower-cased entry path.
        string pathLower = pathEnt.ToLowerInvariant();
        if (PathMap.TryGetValue(pathLower, out string pathAbs))
        {
            stream = new FileStream(pathAbs, FileMode.Open, FileAccess.Read, FileShare.Read);
        }
        else
        {
            if (!_entries.TryGetValue(pathEnt, out ZipArchiveEntry entry))
            {
                //Read old zip file that use backslash in filename
                var pathEntTmp = pathEnt.Replace("/", "\\");
                if (!_entries.TryGetValue(pathEntTmp, out entry))
                {
                    return null;
                }
            }

            if (pathTemp != null)
            {
                // Extract to a temporary file.
                Directory.CreateDirectory(Path.GetDirectoryName(pathTemp));
                entry.ExtractToFile(pathTemp);
                if (!PathMap.TryAdd(pathLower, pathTemp))
                    throw ExceptionContext.ExceptParam(nameof(name), "Duplicate entry: '{0}'", pathLower);

                stream = new FileStream(pathTemp, FileMode.Open, FileAccess.Read, FileShare.Read);
            }
            else
            {
                // Extract to a memory stream.
                ExceptionContext.CheckDecode(entry.Length < int.MaxValue, "Repository stream too large to read into memory");
                stream = new MemoryStream((int)entry.Length);
                using (var src = entry.Open())
                    src.CopyTo(stream);
                stream.Position = 0;
            }
        }

        return AddEntry(pathEnt, stream);
    }

    protected override void OnDispose(Entry ent)
    {
        ExceptionContext.AssertValue(ent);
        RemoveEntry(ent);
        ent.Stream.CloseEx();
    }

    /// <summary>
    /// Disposes the zip archive owned by this reader, then lets the base class close the open
    /// entries and remove the temp directory.
    /// </summary>
    protected override void Dispose(bool disposing)
    {
        ExceptionContext.Assert(!Disposed);

        // Entry streams are extracted copies (temp files or memory streams), never streams into
        // the archive, so the archive can be released independently of them. _archive is null
        // when the constructor failed to open it.
        if (disposing && _archive != null)
        {
            try
            {
                _archive.Dispose();
            }
            catch
            {
                // Disposal is not supposed to throw.
            }
        }

        base.Dispose(disposing);
    }
}