// File: Utilities\ModelFileUtils.cs
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// (Extracted from a web source browser; the "Web Access" navigation label was not part of the code.)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Model
{
    using ColumnRole = RoleMappedSchema.ColumnRole;
    using Conditional = System.Diagnostics.ConditionalAttribute;
 
    /// <summary>
    /// This class provides utilities for loading components from the model file generated by MAML commands.
    /// </summary>
    [BestFriend]
    internal static class ModelFileUtils
    {
        public const string DirPredictor = "Predictor";
        public const string DirDataLoaderModel = "DataLoaderModel";
        public const string DirTransformerChain = TransformerChain.LoaderSignature;
        public const string SchemaEntryName = ModelOperationsCatalog.SchemaEntryName;
        // ResultsProcessor needs access to this constant.
        public const string DirTrainingInfo = "TrainingInfo";
 
        private const string RoleMappingFile = "RoleMapping.txt";
 
        /// <summary>
        /// Opens <paramref name="modelStream"/> as a model repository and loads the loader and
        /// transforms stored in it.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="modelStream">The model stream. Must not be null.</param>
        /// <param name="files">The data source to initialize the loader with. Must not be null.</param>
        /// <param name="extractInnerPipe">Whether to extract the transforms and loader from the wrapped CompositeDataLoader.</param>
        /// <returns>The created data view.</returns>
        public static IDataView LoadPipeline(IHostEnvironment env, Stream modelStream, IMultiStreamSource files, bool extractInnerPipe = false)
        {
            // REVIEW: Should not duplicate loading loader/transforms code. This method should call LoadLoader.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(modelStream, nameof(modelStream));
            env.CheckValue(files, nameof(files));

            // The repository reader is scoped to this call; the actual work is done by the
            // overload that takes a repository reader.
            using (var repository = RepositoryReader.Open(modelStream, env))
                return LoadPipeline(env, repository, files, extractInnerPipe);
        }
 
        /// <summary>
        /// Loads and returns the loader and transforms from the specified repository reader.
        /// When the repository contains a schema entry, the result is the stored transformer chain
        /// applied to a binary loader over that entry; otherwise the data loader model entry is loaded.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="rep">The repository reader. Must not be null.</param>
        /// <param name="files">The data source to initialize the loader with. Must not be null.</param>
        /// <param name="extractInnerPipe">Whether to extract the transforms and loader from the wrapped CompositeDataLoader.
        /// Only consulted when the data loader model entry is used.</param>
        /// <returns>The created data view.</returns>
        public static IDataView LoadPipeline(IHostEnvironment env, RepositoryReader rep, IMultiStreamSource files, bool extractInnerPipe = false)
        {
            // REVIEW: Should not duplicate loading loader/transforms code. This method should call LoadLoader.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(rep, nameof(rep));
            env.CheckValue(files, nameof(files));

            // First choice: a schema entry plus a transformer chain.
            // NOTE(review): the schema entry is not disposed here; presumably the binary loader keeps
            // reading from its stream after this method returns - confirm before changing.
            var schemaEntry = rep.OpenEntryOrNull(SchemaEntryName);
            if (schemaEntry != null)
            {
                var schemaLoader = new BinaryLoader(env, new BinaryLoader.Arguments(), schemaEntry.Stream);
                ModelLoadContext.LoadModel<ITransformer, SignatureLoadModel>(env, out var chain, rep, DirTransformerChain);
                return chain.Transform(schemaLoader);
            }

            // Otherwise load the data loader model entry; OpenEntry is the non-"OrNull" variant.
            using (var loaderEntry = rep.OpenEntry(DirDataLoaderModel, ModelLoadContext.ModelStreamName))
            {
                env.Assert(loaderEntry.Stream.Position == 0);
                ModelLoadContext.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out ILegacyDataLoader legacyLoader, rep, loaderEntry, DirDataLoaderModel, files);

                // Unwrap the composite loader only when the caller asked for the inner pipe.
                if (extractInnerPipe && legacyLoader is LegacyCompositeDataLoader composite)
                    return composite.View;

                return legacyLoader;
            }
        }
 
        /// <summary>
        /// Opens <paramref name="modelStream"/> as a model repository, loads all transforms from it,
        /// applies them sequentially to <paramref name="data"/>, and returns the resulting data.
        /// If there are no transforms in the stream, or if there's no DataLoader stream at all
        /// (this can happen if the model is produced by old TL), returns the source data.
        /// If the DataLoader stream is invalid, throws.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="data">The starting data view. Must not be null.</param>
        /// <param name="modelStream">The model stream. Must not be null.</param>
        /// <returns>The resulting data view.</returns>
        public static IDataView LoadTransforms(IHostEnvironment env, IDataView data, Stream modelStream)
        {
            // REVIEW: Consolidate with LoadTransformChain in DataDiagnosticsCommand.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValue(modelStream, nameof(modelStream));

            // Delegate to the repository-based overload; the reader is disposed on the way out.
            using (var repository = RepositoryReader.Open(modelStream, env))
                return LoadTransforms(env, data, repository);
        }
 
        /// <summary>
        /// Loads all transforms from the repository, applies them sequentially to the provided data, and returns
        /// the resulting data. If there's no DataLoader entry in the repository
        /// (this can happen if the model is produced by old TL), returns the source data unchanged.
        /// If the DataLoader stream is invalid, throws.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="data">The starting data view. Must not be null.</param>
        /// <param name="rep">The repository reader. Must not be null.</param>
        /// <returns>The resulting data view.</returns>
        public static IDataView LoadTransforms(IHostEnvironment env, IDataView data, RepositoryReader rep)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValue(rep, nameof(rep));
            using (var ent = rep.OpenEntryOrNull(DirDataLoaderModel, ModelLoadContext.ModelStreamName))
            {
                // No data loader model entry: nothing to apply, hand back the input as is.
                if (ent == null)
                    return data;

                // The load context is disposable, so release it as soon as the transforms are loaded
                // instead of leaving it to the garbage collector.
                using (var ctx = new ModelLoadContext(rep, ent, DirDataLoaderModel))
                {
                    // The always-true predicate selects every stored transform.
                    return LegacyCompositeDataLoader.LoadSelectedTransforms(ctx, data, env, x => true);
                }
            }
        }
 
        /// <summary>
        /// Opens <paramref name="modelStream"/> as a model repository and loads the predictor from it.
        /// Returns null iff there's no predictor.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="modelStream">The model stream. Must not be null.</param>
        /// <returns>The loaded predictor, or null when none is stored.</returns>
        public static IPredictor LoadPredictorOrNull(IHostEnvironment env, Stream modelStream)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckValue(modelStream, nameof(modelStream));

            using (var repository = RepositoryReader.Open(modelStream, env))
            {
                return LoadPredictorOrNull(env, repository);
            }
        }
 
        /// <summary>
        /// Loads a predictor from the <see cref="DirPredictor"/> directory of the repository.
        /// Returns null iff there's no predictor.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="rep">The repository reader. Must not be null.</param>
        /// <returns>The loaded predictor, or null when none is stored.</returns>
        public static IPredictor LoadPredictorOrNull(IHostEnvironment env, RepositoryReader rep)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckValue(rep, nameof(rep));

            ModelLoadContext.LoadModelOrNull<IPredictor, SignatureLoadModel>(env, out IPredictor loaded, rep, DirPredictor);
            return loaded;
        }
 
        /// <summary>
        /// Creates the save context used to write the data loader model into the given repository,
        /// targeting the <see cref="DirDataLoaderModel"/> directory.
        /// </summary>
        /// <param name="rep">The repository writer. Must not be null.</param>
        /// <returns>A new save context for the data loader model.</returns>
        public static ModelSaveContext GetDataModelSavingContext(RepositoryWriter rep)
        {
            Contracts.CheckValue(rep, nameof(rep));

            var saveContext = new ModelSaveContext(rep, DirDataLoaderModel, ModelLoadContext.ModelStreamName);
            return saveContext;
        }
 
        /// <summary>
        /// Loads data view (loader and transforms) from <paramref name="rep"/> if <paramref name="loadTransforms"/> is set to true,
        /// otherwise loads loader only.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="rep">The repository reader. Must not be null.</param>
        /// <param name="files">The data source to initialize the loader with. Must not be null.</param>
        /// <param name="loadTransforms">Whether the transforms should be loaded together with the loader.</param>
        /// <returns>The loaded data loader.</returns>
        public static ILegacyDataLoader LoadLoader(IHostEnvironment env, RepositoryReader rep, IMultiStreamSource files, bool loadTransforms)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(rep, nameof(rep));
            env.CheckValue(files, nameof(files));

            string modelDir = "";
            Repository.Entry modelEntry = null;

            // If loadTransforms is false, load the loader only, not the transforms: try the
            // nested "Loader" directory first.
            if (!loadTransforms)
            {
                modelDir = Path.Combine(DirDataLoaderModel, "Loader");
                modelEntry = rep.OpenEntryOrNull(modelDir, ModelLoadContext.ModelStreamName);
            }

            // Either loadTransforms is true, or it's not a composite loader: use the top-level entry.
            if (modelEntry == null)
            {
                modelDir = DirDataLoaderModel;
                modelEntry = rep.OpenEntry(modelDir, ModelLoadContext.ModelStreamName);
            }

            env.CheckDecode(modelEntry != null, "Loader is not found.");
            env.AssertNonEmpty(modelDir);

            ILegacyDataLoader loader;
            using (modelEntry)
            {
                env.Assert(modelEntry.Stream.Position == 0);
                ModelLoadContext.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out loader, rep, modelEntry, modelDir, files);
            }

            return loader;
        }
 
        // REVIEW: consider adding an overload that returns ReadOnlyMemory<char>.
        /// <summary>
        /// Optionally loads feature names from the <see cref="DirTrainingInfo"/> directory of the repository.
        /// Returns false iff no stream was found for feature names, in which case
        /// <paramref name="featureNames"/> is set to null.
        /// </summary>
        /// <param name="featureNames">The loaded feature names, or null when the entry is absent.</param>
        /// <param name="rep">The repository reader. Must not be null.</param>
        /// <returns>True when the feature names entry was found and loaded; false otherwise.</returns>
        public static bool TryLoadFeatureNames(out FeatureNameCollection featureNames, RepositoryReader rep)
        {
            Contracts.CheckValue(rep, nameof(rep));

            using (var namesEntry = rep.OpenEntryOrNull(DirTrainingInfo, "FeatureNames.bin"))
            {
                // Missing entry: report failure with a null result.
                if (namesEntry == null)
                {
                    featureNames = null;
                    return false;
                }

                using (var loadContext = new ModelLoadContext(rep, namesEntry, DirTrainingInfo))
                {
                    featureNames = FeatureNameCollection.Create(loadContext);
                    return true;
                }
            }
        }
 
        /// <summary>
        /// Save schema associations of role/column-name in <paramref name="rep"/>.
        /// The pairs are written as a two-column ("Role", "Column") text file, sorted by role name.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="ch">The channel; only used here for assertions.</param>
        /// <param name="schema">The role-mapped schema whose role/column pairs are saved.</param>
        /// <param name="rep">The repository writer that receives the role mapping entry.</param>
        internal static void SaveRoleMappings(IHostEnvironment env, IChannel ch, RoleMappedSchema schema, RepositoryWriter rep)
        {
            // REVIEW: Should we also save this stuff, for instance, in some portion of the
            // score command or transform?
            Contracts.AssertValue(env);
            env.AssertValue(ch);
            ch.AssertValue(schema);

            var viewBuilder = new ArrayDataViewBuilder(env);

            // OrderBy is stable, so there is no danger in it "reordering" columns
            // when a role is filled by multiple columns.
            var orderedPairs = schema.GetColumnRoleNames().OrderBy(pair => pair.Key.Value).ToArray();
            string[] roleNames = orderedPairs.Select(pair => pair.Key.Value).ToArray();
            string[] columnNames = orderedPairs.Select(pair => pair.Value).ToArray();
            viewBuilder.AddColumn("Role", roleNames);
            viewBuilder.AddColumn("Column", columnNames);

            using (var mappingEntry = rep.CreateEntry(DirTrainingInfo, RoleMappingFile))
            {
                // REVIEW: It seems very important that the role mappings be easily human
                // interpretable and even manipulable. The trade-off of relying on the text
                // saver/loader is that special characters like '\n' in names won't be
                // reinterpretable; such names are assumed not to occur in practice.
                var textSaver = new TextSaver(env, new TextSaver.Arguments() { Dense = true, Silent = true });
                var mappingView = viewBuilder.GetDataView();
                textSaver.SaveData(mappingEntry.Stream, mappingView, Utils.GetIdentityPermutation(mappingView.Schema.Count));
            }
        }
 
        /// <summary>
        /// Opens <paramref name="modelStream"/> as a model repository and returns the role/column-name
        /// pairs stored in it, as produced by the repository-based overload.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="modelStream">The model stream. Must not be null.</param>
        /// <returns>The role/column-name pairs, or null when the repository has no role mapping entry.</returns>
        public static IEnumerable<KeyValuePair<ColumnRole, string>> LoadRoleMappingsOrNull(IHostEnvironment env, Stream modelStream)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(modelStream, nameof(modelStream));

            using (var repository = RepositoryReader.Open(modelStream, env))
                return LoadRoleMappingsOrNull(env, repository);
        }
 
        /// <summary>
        /// Return role/column-name pairs loaded from a repository.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="rep">The repository reader. Must not be null.</param>
        /// <returns>
        /// The role/column-name pairs read from the role mapping file in the
        /// <see cref="DirTrainingInfo"/> directory, or null when that file is not present.
        /// </returns>
        public static IEnumerable<KeyValuePair<ColumnRole, string>> LoadRoleMappingsOrNull(IHostEnvironment env, RepositoryReader rep)
        {
            Contracts.CheckValue(env, nameof(env));
            // Fail with an argument error instead of a null dereference on the first use of rep.
            env.CheckValue(rep, nameof(rep));
            var h = env.Register("RoleMappingUtils");

            // Probe for the role mapping file. The entry is only needed for the existence test;
            // the text loader below re-opens it through the stream wrapper.
            using (var entry = rep.OpenEntryOrNull(DirTrainingInfo, RoleMappingFile))
            {
                if (entry == null)
                    return null;
            }

            var list = new List<KeyValuePair<string, string>>();

            using (var ch = h.Start("Loading role mappings"))
            {
                // REVIEW: Should really validate the schema here, and consider
                // ignoring this stream if it isn't as expected.
                var repoStreamWrapper = new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile);
                var loader = new TextLoader(env, dataSample: repoStreamWrapper).Load(repoStreamWrapper);

                using (var cursor = loader.GetRowCursorForAllColumns())
                {
                    // Column 0 is read as the role name and column 1 as the column name.
                    var roleGetter = cursor.GetGetter<ReadOnlyMemory<char>>(cursor.Schema[0]);
                    var colGetter = cursor.GetGetter<ReadOnlyMemory<char>>(cursor.Schema[1]);
                    var role = default(ReadOnlyMemory<char>);
                    var col = default(ReadOnlyMemory<char>);
                    while (cursor.MoveNext())
                    {
                        roleGetter(ref role);
                        colGetter(ref col);
                        string roleStr = role.ToString();
                        string colStr = col.ToString();

                        h.CheckDecode(!string.IsNullOrWhiteSpace(roleStr), "Role name must not be empty");
                        h.CheckDecode(!string.IsNullOrWhiteSpace(colStr), "Column name must not be empty");
                        list.Add(new KeyValuePair<string, string>(roleStr, colStr));
                    }
                }
            }

            return TrainUtils.CheckAndGenerateCustomColumns(env, list.ToArray());
        }
 
        /// <summary>
        /// Opens <paramref name="modelStream"/> as a model repository and returns the
        /// <see cref="RoleMappedSchema"/> stored in it, or <c>null</c> if there were no
        /// role mappings present.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="modelStream">The model stream. Must not be null.</param>
        /// <returns>The role-mapped schema, or <c>null</c> when no role mappings are stored.</returns>
        public static RoleMappedSchema LoadRoleMappedSchemaOrNull(IHostEnvironment env, Stream modelStream)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(modelStream, nameof(modelStream));

            using (var repository = RepositoryReader.Open(modelStream, env))
                return LoadRoleMappedSchemaOrNull(env, repository);
        }
 
        /// <summary>
        /// Returns the <see cref="RoleMappedSchema"/> from a repository, or <c>null</c> if there were no
        /// role mappings present.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="rep">The repository reader. Must not be null.</param>
        /// <returns>
        /// A schema built from the loaded pipeline's schema and the stored role mappings,
        /// or <c>null</c> when the repository has no role mappings.
        /// </returns>
        public static RoleMappedSchema LoadRoleMappedSchemaOrNull(IHostEnvironment env, RepositoryReader rep)
        {
            Contracts.CheckValue(env, nameof(env));
            // Validate up front so a null repository is reported as an argument error.
            env.CheckValue(rep, nameof(rep));
            var h = env.Register("RoleMappingUtils");

            var roleMappings = LoadRoleMappingsOrNull(env, rep);
            if (roleMappings == null)
                return null;

            // Only the schema of the pipeline is needed, so the loader is given a source built from null.
            var pipe = LoadLoader(h, rep, new MultiFileSource(null), loadTransforms: true);
            return new RoleMappedSchema(pipe.Schema, roleMappings);
        }
 
        /// <summary>
        /// The RepositoryStreamWrapper is a IMultiStreamSource wrapper of a Stream object in a repository.
        /// It exposes exactly one stream (index 0) and has no file path.
        /// It is used to deserialize the role mapping file (RoleMapping.txt) from a model repository.
        /// </summary>
        private sealed class RepositoryStreamWrapper : IMultiStreamSource
        {
            private readonly RepositoryReader _repository;
            private readonly string _directory;
            private readonly string _filename;

            /// <summary>
            /// Creates a wrapper over the entry <paramref name="filename"/> in <paramref name="directory"/>
            /// of <paramref name="repository"/>. The entry itself is opened lazily by <see cref="Open"/>.
            /// </summary>
            public RepositoryStreamWrapper(RepositoryReader repository, string directory, string filename)
            {
                Contracts.CheckValue(repository, nameof(repository));
                Contracts.CheckNonWhiteSpace(directory, nameof(directory));
                Contracts.CheckNonWhiteSpace(filename, nameof(filename));

                _repository = repository;
                _directory = directory;
                _filename = filename;
            }

            /// <summary>The number of streams in this source; always one.</summary>
            public int Count { get { return 1; } }

            /// <summary>
            /// Returns null, since the wrapped entry has no path to report.
            /// <paramref name="index"/> must be zero.
            /// </summary>
            public string GetPathOrNull(int index)
            {
                Contracts.Check(index == 0);
                return null;
            }

            /// <summary>
            /// Opens the wrapped repository entry and returns a stream over it. Disposing the returned
            /// stream disposes the entry. <paramref name="index"/> must be zero. Throws when the entry
            /// is missing from the repository.
            /// </summary>
            public Stream Open(int index)
            {
                // Validated in every build flavor, consistently with GetPathOrNull: a debug-only
                // assert would let an out-of-range index silently open the single entry in release.
                Contracts.Check(index == 0);
                var ent = _repository.OpenEntryOrNull(_directory, _filename);
                if (ent == null)
                    throw Contracts.Except($"File '{_filename}' is missing from the repository");
                return new EntryStream(ent);
            }

            /// <summary>
            /// Opens the wrapped entry (see <see cref="Open"/>) and wraps it in a text reader.
            /// </summary>
            public TextReader OpenTextReader(int index) { return new StreamReader(Open(index)); }

            /// <summary>
            /// A custom entry stream wrapper that includes custom dispose logic for disposing the entry
            /// when the stream is disposed. All stream operations are forwarded to the entry's stream.
            /// </summary>
            private sealed class EntryStream : Stream
            {
                // Guards against disposing the entry more than once.
                private bool _disposed;

                private readonly Repository.Entry _entry;

                public override bool CanRead
                {
                    get
                    {
                        AssertValid();
                        return _entry.Stream.CanRead;
                    }
                }

                public override bool CanSeek
                {
                    get
                    {
                        AssertValid();
                        return _entry.Stream.CanSeek;
                    }
                }

                public override bool CanWrite
                {
                    get
                    {
                        AssertValid();
                        return _entry.Stream.CanWrite;
                    }
                }

                public override long Length
                {
                    get
                    {
                        AssertValid();
                        return _entry.Stream.Length;
                    }
                }

                public override long Position
                {
                    get
                    {
                        AssertValid();
                        return _entry.Stream.Position;
                    }

                    set
                    {
                        AssertValid();
                        _entry.Stream.Position = value;
                    }
                }

                /// <summary>
                /// Wraps <paramref name="entry"/>; both the entry and its stream must be non-null.
                /// </summary>
                public EntryStream(Repository.Entry entry)
                {
                    Contracts.CheckValue(entry, nameof(entry));
                    Contracts.CheckValue(entry.Stream, nameof(entry.Stream));
                    _entry = entry;
                }

                public override void Flush()
                {
                    AssertValid();
                    _entry.Stream.Flush();
                }

                public override long Seek(long offset, SeekOrigin origin)
                {
                    AssertValid();
                    return _entry.Stream.Seek(offset, origin);
                }

                public override void SetLength(long value)
                {
                    AssertValid();
                    _entry.Stream.SetLength(value);
                }

                public override int Read(byte[] buffer, int offset, int count)
                {
                    AssertValid();
                    return _entry.Stream.Read(buffer, offset, count);
                }

                public override void Write(byte[] buffer, int offset, int count)
                {
                    AssertValid();
                    _entry.Stream.Write(buffer, offset, count);
                }

                /// <summary>
                /// Disposes the wrapped entry (once) when the stream is explicitly disposed.
                /// </summary>
                protected override void Dispose(bool disposing)
                {
                    // The entry is a managed object, so per the dispose pattern it is only
                    // touched when disposing is true, and only the first time.
                    if (disposing && !_disposed)
                    {
                        AssertValid();
                        _entry.Dispose();
                        _disposed = true;
                    }
                    base.Dispose(disposing);
                }

                // Debug-only sanity check that the wrapped entry and its stream are still present.
                [Conditional("DEBUG")]
                private void AssertValid()
                {
#if DEBUG
                    Contracts.AssertValue(_entry);
                    Contracts.AssertValue(_entry.Stream);
#endif
                }
            }
        }
    }
}