File: Commands\SavePredictorCommand.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.IO;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
// REVIEW: Fix these namespaces.
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.Tools;
 
// Declares SavePredictorCommand as a loadable class, passing its summary text, its Arguments type
// and the SignatureCommand signature type to the LoadableClass assembly attribute.
// NOTE(review): the trailing strings are presumably a display name followed by the names/aliases
// under which the command can be looked up -- confirm against LoadableClassAttribute.
[assembly: LoadableClass(SavePredictorCommand.Summary, typeof(SavePredictorCommand), typeof(SavePredictorCommand.Arguments), typeof(SignatureCommand),
    "Save Predictor As", "SavePredictorAs", "SavePredictor", "SaveAs", "SaveModel")]
 
namespace Microsoft.ML.Tools
{
    /// <summary>
    /// Command that opens a model file containing a predictor and hands it to
    /// <see cref="SavePredictorUtils"/> together with one output stream for every
    /// export format whose path was supplied.
    /// </summary>
    internal sealed class SavePredictorCommand : ICommand
    {
        /// <summary>
        /// Command-line arguments: one input model path plus one optional path per output format.
        /// </summary>
        public sealed class Arguments
        {
#pragma warning disable 649 // never assigned
            [Argument(ArgumentType.AtMostOnce, HelpText = "Model file containing the predictor", ShortName = "in")]
            public string InputModelFile;

            // Textual summary of the model; it may not hold everything needed to recreate the model.
            [Argument(ArgumentType.AtMostOnce, HelpText = "File to save model summary", ShortName = "sum")]
            public string SummaryFile;

            // Human-readable text rendering of the model.
            [Argument(ArgumentType.AtMostOnce, HelpText = "File to save in text format", ShortName = "text")]
            public string TextFile;

            // Bing INI rendering of the model.
            [Argument(ArgumentType.AtMostOnce, HelpText = "File to save in INI format", ShortName = "ini")]
            public string IniFile;

            // Model emitted as C++/C# source code.
            [Argument(ArgumentType.AtMostOnce, HelpText = "File to save in C++ code", ShortName = "code")]
            public string CodeFile;

            // Binary rendering of the model (for fast loading).
            [Argument(ArgumentType.AtMostOnce, HelpText = "File to save in binary format", ShortName = "bin")]
            public string BinaryFile;
#pragma warning restore 649 // never assigned
        }

        internal const string Summary = "Given a TLC model file with a predictor, we can output this same predictor in multiple export formats.";

        private readonly Arguments _args;
        private readonly IHost _host;

        /// <summary>
        /// Validates the environment and arguments, registers a host for this command and
        /// checks the directories of all requested output files.
        /// </summary>
        /// <param name="env">Host environment used to register this command's host.</param>
        /// <param name="args">Parsed arguments; the input model file must be specified.</param>
        public SavePredictorCommand(IHostEnvironment env, Arguments args)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register("SavePredictorCommand");
            _host.CheckValue(args, nameof(args));

            bool hasInputModel = !string.IsNullOrWhiteSpace(args.InputModelFile);
            _host.CheckUserArg(hasInputModel, nameof(args.InputModelFile), "Must specify input model file");

            _args = args;
            CheckOutputDirectories();
        }

        /// <summary>
        /// Runs the optional-directory check on every output path argument.
        /// </summary>
        private void CheckOutputDirectories()
        {
            Arguments outputs = _args;

            Utils.CheckOptionalUserDirectory(outputs.BinaryFile, nameof(Arguments.BinaryFile));
            Utils.CheckOptionalUserDirectory(outputs.CodeFile, nameof(Arguments.CodeFile));
            Utils.CheckOptionalUserDirectory(outputs.IniFile, nameof(Arguments.IniFile));
            Utils.CheckOptionalUserDirectory(outputs.SummaryFile, nameof(Arguments.SummaryFile));
            Utils.CheckOptionalUserDirectory(outputs.TextFile, nameof(Arguments.TextFile));
        }

        // REVIEW: Use the _env to emit messages instead of console.
        /// <summary>
        /// Opens the input model and every requested output file, then delegates the actual
        /// export to <see cref="SavePredictorUtils"/>. Outputs that were not requested are
        /// passed along as null handles/streams, which a using statement accepts.
        /// </summary>
        public void Run()
        {
            using (var modelFile = _host.OpenInputFile(_args.InputModelFile))
            using (var modelStream = modelFile.OpenReadStream())
            using (var binaryFile = CreateFile(_args.BinaryFile))
            using (var binaryStream = CreateStrm(binaryFile))
            using (var summaryFile = CreateFile(_args.SummaryFile))
            using (var summaryStream = CreateStrm(summaryFile))
            using (var textFile = CreateFile(_args.TextFile))
            using (var textStream = CreateStrm(textFile))
            using (var iniFile = CreateFile(_args.IniFile))
            using (var iniStream = CreateStrm(iniFile))
            using (var codeFile = CreateFile(_args.CodeFile))
            using (var codeStream = CreateStrm(codeFile))
            {
                SavePredictorUtils.SavePredictor(
                    _host,
                    modelStream,
                    binaryModelStream: binaryStream,
                    summaryModelStream: summaryStream,
                    textModelStream: textStream,
                    iniModelStream: iniStream,
                    codeModelStream: codeStream);
            }
        }

        /// <summary>
        /// Returns an output file handle for <paramref name="path"/>, or null when the path is
        /// null, empty or whitespace.
        /// </summary>
        private IFileHandle CreateFile(string path)
            => string.IsNullOrWhiteSpace(path) ? null : _host.CreateOutputFile(path);

        /// <summary>
        /// Returns a write stream for <paramref name="file"/>, or null when no handle was given.
        /// </summary>
        private Stream CreateStrm(IFileHandle file)
            => file?.CreateWriteStream();
    }
 
    /// <summary>
    /// Helpers for loading a predictor from a model stream and re-saving it in one or more
    /// export formats (binary, summary, text, INI, code).
    /// </summary>
    [BestFriend]
    internal static class SavePredictorUtils
    {
        /// <summary>
        /// Loads the predictor stored in <paramref name="modelStream"/> and writes it to every
        /// output stream that is not null.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="modelStream">Stream containing the model to load; must not be null.</param>
        /// <param name="binaryModelStream">Destination for the binary format, or null to skip.</param>
        /// <param name="summaryModelStream">Destination for the summary, or null to skip.</param>
        /// <param name="textModelStream">Destination for the text format, or null to skip.</param>
        /// <param name="iniModelStream">Destination for the INI format, or null to skip.</param>
        /// <param name="codeModelStream">Destination for the code format, or null to skip.</param>
        public static void SavePredictor(IHostEnvironment env, Stream modelStream, Stream binaryModelStream = null, Stream summaryModelStream = null,
            Stream textModelStream = null, Stream iniModelStream = null, Stream codeModelStream = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(modelStream, nameof(modelStream));

            // Only the non-binary writers receive the schema, so it is loaded only when one of them is requested.
            bool needNames = codeModelStream != null
                || iniModelStream != null
                || summaryModelStream != null
                || textModelStream != null;

            IPredictor predictor;
            RoleMappedSchema schema;
            LoadModel(env, modelStream, needNames, out predictor, out schema);
            using (var ch = env.Start("Saving predictor"))
            {
                SavePredictor(ch, predictor, schema, binaryModelStream, summaryModelStream, textModelStream,
                    iniModelStream, codeModelStream);
            }
        }

        /// <summary>
        /// Writes an already-loaded <paramref name="predictor"/> to every output stream that is not null.
        /// When no output stream is supplied, only an informational message is emitted.
        /// </summary>
        /// <param name="ch">Channel used for validation and informational messages; must not be null.</param>
        /// <param name="predictor">Predictor to save; must not be null.</param>
        /// <param name="schema">
        /// Role-mapped schema handed to the summary, text, INI and code writers. It must not be null
        /// when any of those outputs is requested; it may be null for binary-only output or when no
        /// output is requested at all.
        /// </param>
        /// <param name="binaryModelStream">Destination for the binary format, or null to skip.</param>
        /// <param name="summaryModelStream">Destination for the summary, or null to skip.</param>
        /// <param name="textModelStream">Destination for the text format, or null to skip.</param>
        /// <param name="iniModelStream">Destination for the INI format, or null to skip.</param>
        /// <param name="codeModelStream">Destination for the code format, or null to skip.</param>
        public static void SavePredictor(IChannel ch, IPredictor predictor, RoleMappedSchema schema,
            Stream binaryModelStream = null, Stream summaryModelStream = null, Stream textModelStream = null,
            Stream iniModelStream = null, Stream codeModelStream = null)
        {
            Contracts.CheckValue(ch, nameof(ch));
            ch.CheckValue(predictor, nameof(predictor));

            // The schema is consumed only by the summary/text/INI/code writers. Requiring it
            // unconditionally made binary-only saves and the "just check that the predictor loads"
            // case fail, because the stream-based overload skips loading the schema (leaving it
            // null) when none of those outputs is requested. Validate up front so that nothing is
            // written before a missing schema is reported.
            bool needSchema = summaryModelStream != null
                || textModelStream != null
                || iniModelStream != null
                || codeModelStream != null;
            if (needSchema)
                ch.CheckValue(schema, nameof(schema));

            int count = 0;
            if (binaryModelStream != null)
            {
                ch.Info("Saving predictor as binary");
                // leaveOpen: true -- the caller owns the stream and disposes it.
                using (var writer = new BinaryWriter(binaryModelStream, Encoding.UTF8, true))
                    PredictorUtils.SaveBinary(ch, predictor, writer);
                count++;
            }

            if (summaryModelStream != null)
            {
                ch.Info("Saving predictor summary");

                using (StreamWriter writer = Utils.OpenWriter(summaryModelStream))
                    PredictorUtils.SaveSummary(ch, predictor, schema, writer);
                count++;
            }

            if (textModelStream != null)
            {
                ch.Info("Saving predictor as text");
                using (StreamWriter writer = Utils.OpenWriter(textModelStream))
                    PredictorUtils.SaveText(ch, predictor, schema, writer);
                count++;
            }

            if (iniModelStream != null)
            {
                ch.Info("Saving predictor as ini");
                using (StreamWriter writer = Utils.OpenWriter(iniModelStream))
                {
                    // Prefer the more modern INI export interface when the predictor implements it;
                    // otherwise fall back to the old utility method.
                    ICanSaveInIniFormat saver = predictor as ICanSaveInIniFormat;
                    if (saver != null)
                        saver.SaveAsIni(writer, schema);
                    else
                        PredictorUtils.SaveIni(ch, predictor, schema, writer);
                }
                count++;
            }

            if (codeModelStream != null)
            {
                ch.Info("Saving predictor as code");
                using (StreamWriter writer = Utils.OpenWriter(codeModelStream))
                    PredictorUtils.SaveCode(ch, predictor, schema, writer);
                count++;
            }

            // Note that we don't check for this case up front so this command can be used to simply
            // check that the predictor is loadable.
            if (count == 0)
                ch.Info("No files saved. Must specify at least one output file.");
        }

        /// <summary>
        /// Opens the model repository in <paramref name="modelStream"/> and loads its predictor,
        /// optionally together with a role-mapped schema.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="modelStream">Stream containing the model repository; must not be null.</param>
        /// <param name="loadNames">Whether to attempt to load a schema alongside the predictor.</param>
        /// <param name="predictor">Receives the predictor loaded from the repository.</param>
        /// <param name="schema">
        /// Receives the schema built from the stored role mappings or, failing that, from the stored
        /// feature names. Left null when <paramref name="loadNames"/> is false or when the
        /// repository provides neither.
        /// </param>
        public static void LoadModel(IHostEnvironment env, Stream modelStream, bool loadNames, out IPredictor predictor, out RoleMappedSchema schema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(modelStream, nameof(modelStream));

            schema = null;
            using (var rep = RepositoryReader.Open(modelStream, env))
            {
                ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(env, out predictor, rep, ModelFileUtils.DirPredictor);

                if (loadNames)
                {
                    var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, rep);
                    if (roles != null)
                    {
                        // The pipeline is loaded over a null file source; only its schema is used here.
                        var emptyView = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null));
                        schema = new RoleMappedSchema(emptyView.Schema, roles, opt: true);
                    }
                    else
                    {
                        // No role mappings stored: fall back to the stored feature names, if any.
                        FeatureNameCollection names;
                        if (ModelFileUtils.TryLoadFeatureNames(out names, rep))
                            schema = names.Schema;
                    }
                }
            }
        }
    }
}