File: Model\Pfa\SavePfaCommand.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
 
[assembly: LoadableClass(SavePfaCommand.Summary, typeof(SavePfaCommand), typeof(SavePfaCommand.Arguments), typeof(SignatureCommand),
    "Save PFA", "SavePfa", DocName = "command/SavePfa.md")]
 
namespace Microsoft.ML.Model.Pfa
{
    internal sealed class SavePfaCommand : DataCommand.ImplBase<SavePfaCommand.Arguments>
    {
        public const string Summary = "Given a data model, write out the corresponding PFA.";
        public const string LoadName = "SavePfa";
 
        public sealed class Arguments : DataCommand.ArgumentsBase
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "The path to write the output PFA too. Leave empty for stdout.", SortOrder = 1)]
            public string Pfa;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "The 'name' property in the output PFA program. By default this will be the extension-less name ", NullName = "<Auto>", SortOrder = 3)]
            public string Name;
 
            // This option is a bit problematic, since you can only really set something if it has the same exact type.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should allow set operations.", ShortName = "set", SortOrder = 3, Hide = true)]
            public bool AllowSet;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of input column names to drop", ShortName = "idrop", SortOrder = 4)]
            public string InputsToDrop;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of output column names to drop", ShortName = "odrop", SortOrder = 5)]
            public string OutputsToDrop;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the inputs should also map to the outputs.", ShortName = "input", SortOrder = 6)]
            public bool KeepInput;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 7)]
            public bool? LoadPredictor;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Format option for the JSON exporter.", ShortName = "format", SortOrder = 8)]
            public Formatting Formatting = Formatting.Indented;
        }
 
        private readonly string _outputModelPath;
        private readonly string _name;
        private readonly bool _allowSet;
        private readonly bool _keepInput;
        private readonly bool? _loadPredictor;
        private readonly HashSet<string> _inputsToDrop;
        private readonly HashSet<string> _outputsToDrop;
        private readonly Formatting _formatting;
 
        public SavePfaCommand(IHostEnvironment env, Arguments args)
                : base(env, args, LoadName)
        {
            Host.CheckValue(args, nameof(args));
            Utils.CheckOptionalUserDirectory(args.Pfa, nameof(args.Pfa));
            _outputModelPath = string.IsNullOrWhiteSpace(args.Pfa) ? null : args.Pfa;
            if (args.Name == null && _outputModelPath != null)
                _name = Path.GetFileNameWithoutExtension(_outputModelPath);
            else if (!string.IsNullOrWhiteSpace(args.Name))
                _name = args.Name;
            _allowSet = args.AllowSet;
            _keepInput = args.KeepInput;
            _loadPredictor = args.LoadPredictor;
 
            _inputsToDrop = CreateDropMap(args.InputsToDrop);
            _outputsToDrop = CreateDropMap(args.OutputsToDrop);
 
            Host.CheckUserArg(Enum.IsDefined(typeof(Formatting), args.Formatting), nameof(args.Formatting), "Undefined value");
            _formatting = args.Formatting;
        }
 
        private static HashSet<string> CreateDropMap(string toDrop)
        {
            if (string.IsNullOrWhiteSpace(toDrop))
                return new HashSet<string>();
            return new HashSet<string>(toDrop.Split(','));
        }
 
        public override void Run()
        {
            using (var ch = Host.Start("Run"))
            {
                Run(ch);
            }
        }
 
        private void GetPipe(IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList<ITransformCanSavePfa> transforms)
        {
            Host.AssertValue(end);
            source = trueEnd = (end as LegacyCompositeDataLoader)?.View ?? end;
            IDataTransform transform = source as IDataTransform;
            transforms = new LinkedList<ITransformCanSavePfa>();
            while (transform != null)
            {
                ITransformCanSavePfa pfaTransform = transform as ITransformCanSavePfa;
                if (pfaTransform == null || !pfaTransform.CanSavePfa)
                {
                    ch.Warning("Had to stop walkback of pipeline at {0} since it cannot save itself as PFA", transform.GetType().Name);
                    return;
                }
                transforms.AddFirst(pfaTransform);
                transform = (source = transform.Source) as IDataTransform;
            }
            Host.AssertValue(source);
        }
 
        private void Run(IChannel ch)
        {
            ILegacyDataLoader loader;
            IPredictor rawPred;
            RoleMappedSchema trainSchema;
 
            if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
            {
                loader = CreateLoader();
                rawPred = null;
                trainSchema = null;
                Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                    "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
            }
            else
                LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
 
            // Get the transform chain.
            IDataView source;
            IDataView end;
            LinkedList<ITransformCanSavePfa> transforms;
            GetPipe(ch, loader, out source, out end, out transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);
 
            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                    data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = new RoleMappedData(end, DefaultColumnNames.Label,
                        DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
                }
 
                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scorePfa = scorePipe as ITransformCanSavePfa;
                if (scorePfa?.CanSavePfa == true)
                {
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scorePfa);
                }
                else
                {
                    Contracts.CheckUserArg(_loadPredictor != true,
                        nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as PFA.");
                    ch.Warning("We do not know how to save the predictor as PFA. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                    nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }
 
            var ctx = new BoundPfaContext(Host, source.Schema, _inputsToDrop, allowSet: _allowSet);
            foreach (var trans in transforms)
            {
                Host.Assert(trans.CanSavePfa);
                trans.SaveAsPfa(ctx);
            }
 
            var toExport = new List<string>();
            for (int i = 0; i < end.Schema.Count; ++i)
            {
                if (end.Schema[i].IsHidden)
                    continue;
                var name = end.Schema[i].Name;
                if (_outputsToDrop.Contains(name))
                    continue;
                if (!ctx.IsInput(name) || _keepInput)
                    toExport.Add(name);
            }
            JObject pfaDoc = ctx.Finalize(end.Schema, toExport.ToArray());
            if (_name != null)
                pfaDoc["name"] = _name;
 
            if (_outputModelPath == null)
                ch.Info(MessageSensitivity.Schema, pfaDoc.ToString(_formatting));
            else
            {
                using (var file = Host.CreateOutputFile(_outputModelPath))
                using (var stream = file.CreateWriteStream())
                using (var writer = new StreamWriter(stream))
                    writer.Write(pfaDoc.ToString(_formatting));
            }
 
            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, ImplOptions.OutputModelFile);
            }
        }
    }
}