// File: Model\ModelOperationsCatalog.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML
{
    /// <summary>
    /// Class used by <see cref="MLContext"/> to save and load trained models.
    /// </summary>
    public sealed class ModelOperationsCatalog : IInternalCatalog
    {
        internal const string SchemaEntryName = "Schema";
 
        IHostEnvironment IInternalCatalog.Environment => _env;
        private readonly IHostEnvironment _env;
 
        /// <summary>
        /// Creates the catalog on top of the given host environment. The environment is stored and
        /// used by every save/load/prediction-engine operation of this class.
        /// </summary>
        /// <param name="env">The host environment. It is only asserted (via <c>Contracts.AssertValue</c>)
        /// to be a valid value; the constructor is internal, so callers are expected to pass a real environment.</param>
        internal ModelOperationsCatalog(IHostEnvironment env)
        {
            Contracts.AssertValue(env);
            _env = env;
        }
 
        /// <summary>
        /// Writes a transformer model, together with the loader that produced its input data, to a stream.
        /// </summary>
        /// <typeparam name="TSource">The source type consumed by <paramref name="loader"/>.</typeparam>
        /// <param name="model">The trained model to be saved. This may be <see langword="null"/>, which is shorthand
        /// for an empty transformer chain: loading such a model with
        /// <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/> returns an empty
        /// <see cref="TransformerChain{TLastTransformer}"/>.</param>
        /// <param name="loader">The loader that was used to create data to train the model. Required.</param>
        /// <param name="stream">A writeable, seekable stream to save to. Required.</param>
        public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, Stream stream)
        {
            _env.CheckValue(loader, nameof(loader));
            _env.CheckValueOrNull(model);
            _env.CheckValue(stream, nameof(stream));

            // To keep this API's on-disk layout consistent, a non-null transformer is always wrapped
            // in a single-element chain; a null model is handed on as "no chain".
            TransformerChain<ITransformer> wrappedModel = null;
            if (model != null)
                wrappedModel = new TransformerChain<ITransformer>(model);

            var loaderWithModel = new CompositeDataLoader<TSource, ITransformer>(loader, wrappedModel);

            using (var repository = RepositoryWriter.CreateNew(stream, _env))
            {
                ModelSaveContext.SaveModel(repository, loaderWithModel, null);
                repository.Commit();
            }
        }
 
        /// <summary>
        /// Writes a transformer model, together with the loader that produced its input data, to a file.
        /// The file is created with <see cref="File.Create(string)"/> and the work is delegated to the
        /// stream-based overload.
        /// </summary>
        /// <typeparam name="TSource">The source type consumed by <paramref name="loader"/>.</typeparam>
        /// <param name="model">The trained model to be saved. This may be <see langword="null"/>, which is shorthand
        /// for an empty transformer chain: loading such a model with
        /// <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/> returns an empty
        /// <see cref="TransformerChain{TLastTransformer}"/>.</param>
        /// <param name="loader">The loader that was used to create data to train the model. Required.</param>
        /// <param name="filePath">Path where model should be saved. Must be non-empty.</param>
        public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, string filePath)
        {
            _env.CheckValueOrNull(model);
            _env.CheckValue(loader, nameof(loader));
            _env.CheckNonEmpty(filePath, nameof(filePath));

            using (var output = File.Create(filePath))
            {
                Save(model, loader, output);
            }
        }
 
        /// <summary>
        /// Writes a transformer model and, when supplied, the schema of its input data to a stream.
        /// </summary>
        /// <param name="model">The trained model to be saved. This may be <see langword="null"/>, which is shorthand
        /// for an empty transformer chain: loading such a model with <see cref="Load(Stream, out DataViewSchema)"/>
        /// returns an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
        /// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>,
        /// in which case no schema entry is written.</param>
        /// <param name="stream">A writeable, seekable stream to save to. Required.</param>
        public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
        {
            _env.CheckValueOrNull(model);
            _env.CheckValueOrNull(inputSchema);
            _env.CheckValue(stream, nameof(stream));

            using (var repository = RepositoryWriter.CreateNew(stream, _env))
            {
                // The model is stored under the directory name that CompositeDataLoader uses for its
                // transformer, which is also where Load(Stream, ...) looks for it.
                var modelDirectory = CompositeDataLoader<object, ITransformer>.TransformerDirectory;
                ModelSaveContext.SaveModel(repository, model, modelDirectory);

                SaveInputSchema(inputSchema, repository);
                repository.Commit();
            }
        }
 
        /// <summary>
        /// Writes a transformer model and, when supplied, the schema of its input data to a file.
        /// The file is created with <see cref="File.Create(string)"/> and the work is delegated to the
        /// stream-based overload.
        /// </summary>
        /// <param name="model">The trained model to be saved. This may be <see langword="null"/>, which is shorthand
        /// for an empty transformer chain: loading such a model with <see cref="Load(Stream, out DataViewSchema)"/>
        /// returns an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
        /// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param>
        /// <param name="filePath">Path where model should be saved. Must be non-empty.</param>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[Save](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/SaveLoadModel.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public void Save(ITransformer model, DataViewSchema inputSchema, string filePath)
        {
            _env.CheckValueOrNull(model);
            _env.CheckValueOrNull(inputSchema);
            _env.CheckNonEmpty(filePath, nameof(filePath));

            using (var output = File.Create(filePath))
            {
                Save(model, inputSchema, output);
            }
        }
 
        /// <summary>
        /// Writes <paramref name="inputSchema"/> into the repository entry named <see cref="SchemaEntryName"/>.
        /// Does nothing when <paramref name="inputSchema"/> is <see langword="null"/>.
        /// </summary>
        /// <param name="inputSchema">The schema to persist, or <see langword="null"/> to skip writing.</param>
        /// <param name="rep">The repository being written; the caller remains responsible for committing it.</param>
        private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
        {
            _env.AssertValueOrNull(inputSchema);
            _env.AssertValue(rep);

            if (inputSchema != null)
            {
                using (var channel = _env.Start("Saving Schema"))
                {
                    var schemaEntry = rep.CreateEntry(SchemaEntryName);

                    var saverArgs = new BinarySaver.Arguments { Silent = true };
                    var binarySaver = new BinarySaver(_env, saverArgs);

                    // The schema is wrapped in an EmptyDataView and written with the binary saver,
                    // which is the same format Load(Stream, ...) reads back with a BinaryLoader.
                    var schemaOnlyView = new EmptyDataView(_env, inputSchema);
                    DataSaverUtils.SaveDataView(channel, binarySaver, schemaOnlyView, schemaEntry.Stream, keepHidden: true);
                }
            }
        }
 
        /// <summary>
        /// Load the model and its input schema from a stream.
        /// </summary>
        /// <remarks>
        /// The stream is probed in this order: (1) a model saved with an explicit schema entry,
        /// (2) a model saved together with a data loader, (3) a model saved with neither,
        /// and finally (4) the legacy model format.
        /// </remarks>
        /// <param name="stream">A readable, seekable stream to load from.</param>
        /// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without
        /// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param>
        /// <returns>The loaded model.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[Save](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/SaveLoadModel.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
        {
            _env.CheckValue(stream, nameof(stream));

            using (var repository = RepositoryReader.Open(stream, _env))
            {
                // Case 1: the model was saved with an explicit schema entry.
                var schemaEntry = repository.OpenEntryOrNull(SchemaEntryName);
                if (schemaEntry != null)
                {
                    var schemaLoader = new BinaryLoader(_env, new BinaryLoader.Arguments(), schemaEntry.Stream);
                    inputSchema = schemaLoader.Schema;
                    ModelLoadContext.LoadModel<ITransformer, SignatureLoadModel>(_env, out var savedModel, repository,
                        CompositeDataLoader<object, ITransformer>.TransformerDirectory);
                    return savedModel;
                }

                // Case 2: the model was saved together with a data loader; the schema is the
                // output schema of the source loader once the transformer has been split off.
                ModelLoadContext.LoadModelOrNull<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var dataLoader, repository, null);
                if (dataLoader != null)
                {
                    var decomposed = DecomposeLoader(ref dataLoader);
                    inputSchema = dataLoader.GetOutputSchema();
                    return decomposed;
                }

                // Case 3: try to see if the model was saved without a loader or a schema.
                var foundBareModel = ModelLoadContext.LoadModelOrNull<ITransformer, SignatureLoadModel>(_env, out var bareModel, repository,
                    CompositeDataLoader<object, ITransformer>.TransformerDirectory);
                if (foundBareModel)
                {
                    inputSchema = null;
                    return bareModel;
                }

                // Case 4: try to load from legacy model format. Any failure on this last-resort
                // path is wrapped so the caller sees a single, descriptive error.
                try
                {
                    var legacyLoader = ModelFileUtils.LoadLoader(_env, repository, new MultiFileSource(null), false);
                    inputSchema = legacyLoader.Schema;
                    return TransformerChain.LoadFromLegacy(_env, stream);
                }
                catch (Exception ex)
                {
                    throw _env.Except(ex, "Could not load legacy format model");
                }
            }
        }
 
        /// <summary>
        /// Load the model and its input schema from a file. The file is opened for reading and the
        /// work is delegated to <see cref="Load(Stream, out DataViewSchema)"/>.
        /// </summary>
        /// <param name="filePath">Path to a file where the model should be read from. Must be non-empty.</param>
        /// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without
        /// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param>
        /// <returns>The loaded model.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[Save](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/SaveLoadModelFile.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public ITransformer Load(string filePath, out DataViewSchema inputSchema)
        {
            _env.CheckNonEmpty(filePath, nameof(filePath));

            using (var input = File.OpenRead(filePath))
            {
                return Load(input, out inputSchema);
            }
        }
 
        /// <summary>
        /// Given a loader, try to "decompose" it into a source loader and its transform, if any.
        /// When the loader is a composite, <paramref name="loader"/> is replaced by the inner source loader.
        /// If there is no transform, an empty chain is created to stand in for the trivial transformation;
        /// this method should never return <see langword="null"/>.
        /// </summary>
        /// <param name="loader">On input, the loader to decompose. On output, the source loader if the input
        /// was a composite loader; otherwise unchanged.</param>
        /// <returns>The transformer that was attached to the loader, or an empty transformer chain.</returns>
        private ITransformer DecomposeLoader(ref IDataLoader<IMultiStreamSource> loader)
        {
            _env.AssertValue(loader);

            // No transformer stored alongside the loader: rather than return null, hand back the
            // empty "trivial" transformer chain.
            if (!(loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite))
                return new TransformerChain<ITransformer>();

            loader = composite.Loader;
            var storedChain = composite.Transformer;
            var links = ((ITransformerChainAccessor)storedChain).Transformers;

            // The matching save method in this class wraps the caller's ITransformer in a single-element
            // chain, so a chain of length one is assumed to be such a wrapper and is unwrapped.
            // Any other length (possible with models written by older APIs, which are not incorrect,
            // just different from what this class writes) is returned as the chain itself.
            return links.Length == 1 ? links[0] : storedChain;
        }
 
        /// <summary>
        /// Load a transformer model and a data loader model from a stream.
        /// </summary>
        /// <param name="stream">A readable, seekable stream to load from.</param>
        /// <param name="loader">The data loader from the model stream. Note that if there is no data loader,
        /// this method will throw an exception. The scenario where no loader is stored in the stream should
        /// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param>
        /// <returns>The transformer model from the model stream. If the stored loader carries no transformer,
        /// an empty transformer chain is returned.</returns>
        public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader)
        {
            _env.CheckValue(stream, nameof(stream));

            // Open the repository with this catalog's environment, exactly as Load(Stream, ...) does,
            // so both load paths hand the same context to the reader.
            using (var rep = RepositoryReader.Open(stream, _env))
            {
                try
                {
                    ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out loader, rep, null);
                    return DecomposeLoader(ref loader);
                }
                catch (Exception ex)
                {
                    // Any failure while reading the loader is reported as "no loader in this model",
                    // with the original exception preserved as the inner exception.
                    throw _env.Except(ex, "Model does not contain an " + nameof(IDataLoader<IMultiStreamSource>) +
                        ". Perhaps this was saved with an " + nameof(DataViewSchema) + ", or even no information on its input at all. " +
                        "Consider using the " + nameof(Load) + " method instead.");
                }
            }
        }
 
        /// <summary>
        /// Load a transformer model and a data loader model from a file. The file is opened for reading
        /// and the work is delegated to the stream-based overload.
        /// </summary>
        /// <param name="filePath">Path to a file where the model should be read from. Must be non-empty.</param>
        /// <param name="loader">The data loader from the model file. Note that if there is no data loader,
        /// this method will throw an exception. The scenario where no loader is stored in the file should
        /// be handled instead using the <see cref="Load(string, out DataViewSchema)"/> method.</param>
        /// <returns>The transformer model from the model file.</returns>
        public ITransformer LoadWithDataLoader(string filePath, out IDataLoader<IMultiStreamSource> loader)
        {
            _env.CheckNonEmpty(filePath, nameof(filePath));

            using (var input = File.OpenRead(filePath))
            {
                return LoadWithDataLoader(input, out loader);
            }
        }
 
        /// <summary>
        /// Create a prediction engine for one-time prediction (default usage).
        /// </summary>
        /// <typeparam name="TSrc">The class that defines the input data.</typeparam>
        /// <typeparam name="TDst">The class that defines the output data.</typeparam>
        /// <param name="transformer">The transformer to use for prediction.</param>
        /// <param name="ignoreMissingColumns">Whether to ignore, rather than throw an exception for, a column
        /// that exists in <paramref name="outputSchemaDefinition"/> but whose corresponding member doesn't exist in
        /// <typeparamref name="TDst"/>.</param>
        /// <param name="inputSchemaDefinition">Additional settings of the input schema.</param>
        /// <param name="outputSchemaDefinition">Additional settings of the output schema.</param>
        /// <returns>A prediction engine built from <paramref name="transformer"/> in this catalog's environment.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[Save](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/SaveLoadModel.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransformer transformer,
            bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
            where TSrc : class
            where TDst : class, new()
            => transformer.CreatePredictionEngine<TSrc, TDst>(
                _env,
                ignoreMissingColumns,
                inputSchemaDefinition,
                outputSchemaDefinition);
 
        /// <summary>
        /// Create a prediction engine for one-time prediction, deriving the input schema definition
        /// of <typeparamref name="TSrc"/> from a <see cref="DataViewSchema"/>.
        /// It's mainly used in conjunction with <see cref="Load(Stream, out DataViewSchema)"/>,
        /// where the input schema is extracted while loading the model.
        /// </summary>
        /// <typeparam name="TSrc">The class that defines the input data.</typeparam>
        /// <typeparam name="TDst">The class that defines the output data.</typeparam>
        /// <param name="transformer">The transformer to use for prediction.</param>
        /// <param name="inputSchema">Input schema.</param>
        /// <returns>A prediction engine built from <paramref name="transformer"/>; unlike the default
        /// overload, it is created with missing-column ignoring turned off.</returns>
        public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransformer transformer, DataViewSchema inputSchema)
            where TSrc : class
            where TDst : class, new()
        {
            var inputDefinition = DataViewConstructionUtils.GetSchemaDefinition<TSrc>(_env, inputSchema);
            return transformer.CreatePredictionEngine<TSrc, TDst>(_env, false, inputDefinition);
        }
 
        /// <summary>
        /// Create a prediction engine for one-time prediction, configured through a
        /// <see cref="PredictionEngineOptions"/> instance.
        /// </summary>
        /// <typeparam name="TSrc">The class that defines the input data.</typeparam>
        /// <typeparam name="TDst">The class that defines the output data.</typeparam>
        /// <param name="transformer">The transformer to use for prediction.</param>
        /// <param name="options">Advanced configuration options. Must not be <see langword="null"/>; its
        /// <c>IgnoreMissingColumns</c>, <c>InputSchemaDefinition</c>, <c>OutputSchemaDefinition</c> and
        /// <c>OwnsTransformer</c> values are forwarded to the engine.</param>
        /// <returns>A prediction engine built from <paramref name="transformer"/> in this catalog's environment.</returns>
        public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransformer transformer, PredictionEngineOptions options)
            where TSrc : class
            where TDst : class, new()
        {
            // Validate up front so a missing options object is reported as an argument error that
            // names the parameter, instead of surfacing as a NullReferenceException below.
            _env.CheckValue(options, nameof(options));

            return transformer.CreatePredictionEngine<TSrc, TDst>(_env, options.IgnoreMissingColumns,
                options.InputSchemaDefinition, options.OutputSchemaDefinition, options.OwnsTransformer);
        }
    }
}