// File: JsonUtils\ExecuteGraphCommand.cs
// Web Access
// Project: src\src\Microsoft.ML.EntryPoints\Microsoft.ML.EntryPoints.csproj (Microsoft.ML.EntryPoints)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
 
// Exposes ExecuteGraphCommand as a SignatureCommand component built from ExecuteGraphCommand.Arguments;
// "ExecGraph" is presumably the name used to look the command up.
// NOTE(review): this differs from ExecuteGraphCommand.LoadName ("ExecuteGraph"), which in this file is only
// used as the host registration name — confirm both spellings are intended.
[assembly: LoadableClass(typeof(ExecuteGraphCommand), typeof(ExecuteGraphCommand.Arguments), typeof(SignatureCommand),
    "", "ExecGraph")]
 
namespace Microsoft.ML.EntryPoints
{
    /// <summary>
    /// Command that loads an entry-point experiment graph from a JSON file, binds the graph's declared
    /// inputs from files on disk, runs every node, and writes the graph's declared outputs to disk.
    /// </summary>
    internal sealed class ExecuteGraphCommand : ICommand
    {
        /// <summary>
        /// Command-line arguments for <see cref="ExecuteGraphCommand"/>.
        /// </summary>
        public sealed class Arguments
        {
            [DefaultArgument(ArgumentType.Required, HelpText = "Path to the graph to run")]
            public string GraphPath;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Random seed")]
            public int? Seed;
        }

        public const string LoadName = "ExecuteGraph";

        private readonly IHost _host;
        private readonly string _path;

        /// <summary>
        /// Creates the command.
        /// </summary>
        /// <param name="env">Host environment; a child host named <see cref="LoadName"/> is registered on it.</param>
        /// <param name="args">Arguments; <see cref="Arguments.GraphPath"/> must name an existing file.</param>
        public ExecuteGraphCommand(IHostEnvironment env, Arguments args)
        {
            Contracts.CheckValue(env, nameof(env));
            // args must be validated before args.Seed is read below.
            Contracts.CheckValue(args, nameof(args));
            _host = env.Register(LoadName, args.Seed);

            _host.CheckUserArg(args.GraphPath != null && File.Exists(args.GraphPath), nameof(args.GraphPath), "Graph path does not exist");
            _path = args.GraphPath;
        }

        /// <summary>
        /// Parses the graph file, binds its inputs, runs all nodes and saves its outputs.
        /// The optional fields named by <c>FieldNames.Inputs</c> and <c>FieldNames.Outputs</c> must be JSON
        /// objects mapping graph variable names to file paths; an absent or JSON-null field means "no ports".
        /// </summary>
        public void Run()
        {
            JObject graph;
            try
            {
                graph = JObject.Parse(File.ReadAllText(_path));
            }
            catch (JsonReaderException ex)
            {
                throw _host.Except(ex, "Failed to parse experiment graph: {0}", ex.Message);
            }

            var runner = new GraphRunner(_host, graph[FieldNames.Nodes] as JArray);

            // Inputs are bound before any node runs.
            var jInputs = GetPortMap(graph, FieldNames.Inputs);
            if (jInputs != null)
            {
                foreach (var kvp in jInputs)
                {
                    string path = GetPortPath(kvp.Value, "input");
                    var varName = kvp.Key;
                    var type = runner.GetPortDataKind(varName);

                    SetInputFromPath(runner, varName, path, type);
                }
            }

            runner.RunAll();

            // Outputs are read only after the whole graph has executed.
            var jOutputs = GetPortMap(graph, FieldNames.Outputs);
            if (jOutputs != null)
            {
                foreach (var kvp in jOutputs)
                {
                    string path = GetPortPath(kvp.Value, "output");
                    var varName = kvp.Key;
                    var type = runner.GetPortDataKind(varName);

                    GetOutputToPath(runner, varName, path, type);
                }
            }
        }

        /// <summary>
        /// Returns the object stored under <paramref name="fieldName"/> in <paramref name="graph"/>,
        /// or null when the field is absent or JSON null.
        /// </summary>
        /// <exception>Thrown via <c>_host.Except</c> when the field holds a non-object value.</exception>
        private JObject GetPortMap(JObject graph, string fieldName)
        {
            JToken token = graph[fieldName];
            if (token == null || token.Type == JTokenType.Null)
                return null;

            var map = token as JObject;
            if (map == null)
                throw _host.Except("Unexpected value for '{0}': {1}", fieldName, token);
            return map;
        }

        /// <summary>
        /// Extracts the file path bound to a single port entry.
        /// </summary>
        /// <param name="value">The JSON value of the port entry; must be a primitive JSON value.</param>
        /// <param name="direction">"input" or "output"; used only in the error message.</param>
        private string GetPortPath(JToken value, string direction)
        {
            var path = value as JValue;
            if (path == null)
                throw _host.Except("Invalid value for {0}: {1}", direction, value);
            return path.Value<string>();
        }

        /// <summary>
        /// Loads the file at <paramref name="path"/> according to <paramref name="kind"/> and binds it to
        /// graph variable <paramref name="varName"/>. Data views are always read with <see cref="BinaryLoader"/>,
        /// regardless of the file extension.
        /// </summary>
        public void SetInputFromPath(GraphRunner runner, string varName, string path, TlcModule.DataKind kind)
        {
            _host.CheckUserArg(runner != null, nameof(runner), "Provide a GraphRunner instance.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(varName), nameof(varName), "Specify a graph variable name.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(path), nameof(path), "Specify a valid file path.");

            switch (kind)
            {
                case TlcModule.DataKind.FileHandle:
                    var fh = new SimpleFileHandle(_host, path, false, false);
                    runner.SetInput(varName, fh);
                    break;
                case TlcModule.DataKind.DataView:
                    IDataView loader = new BinaryLoader(_host, new BinaryLoader.Arguments(), path);
                    runner.SetInput(varName, loader);
                    break;
                case TlcModule.DataKind.PredictorModel:
                    PredictorModelImpl pm;
                    using (var fs = File.OpenRead(path))
                        pm = new PredictorModelImpl(_host, fs);
                    runner.SetInput(varName, pm);
                    break;
                case TlcModule.DataKind.TransformModel:
                    TransformModelImpl tm;
                    using (var fs = File.OpenRead(path))
                        tm = new TransformModelImpl(_host, fs);
                    runner.SetInput(varName, tm);
                    break;
                default:
                    throw _host.Except("Port type {0} not supported", kind);
            }
        }

        /// <summary>
        /// Reads graph variable <paramref name="varName"/> from an executed <paramref name="runner"/> and
        /// saves it to <paramref name="path"/> according to <paramref name="kind"/>.
        /// A null data-view output is skipped with a warning. Array outputs are written one file per
        /// element, named "&lt;path without extension&gt;_&lt;index&gt;&lt;extension&gt;".
        /// </summary>
        public void GetOutputToPath(GraphRunner runner, string varName, string path, TlcModule.DataKind kind)
        {
            _host.CheckUserArg(runner != null, nameof(runner), "Provide a GraphRunner instance.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(varName), nameof(varName), "Specify a graph variable name.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(path), nameof(path), "Specify a valid file path.");

            string extension = Path.GetExtension(path);
            switch (kind)
            {
                case TlcModule.DataKind.FileHandle:
                    throw _host.ExceptNotSupp("File handle outputs not yet supported.");
                case TlcModule.DataKind.DataView:
                    var idv = runner.GetOutput<IDataView>(varName);
                    if (idv != null)
                    {
                        SaveDataView(idv, path, extension);
                    }
                    else
                    {
                        using (var ch = _host.Start("Get outputs from executed graph"))
                        {
                            // Values are passed as format arguments rather than pre-formatted into the
                            // format string, so braces in the path cannot be misread as placeholders.
                            ch.Warning("Ignoring empty graph output (output name: {0}, type: {1}, expected output's file: {2})",
                                varName, kind, path);
                        }
                    }
                    break;
                case TlcModule.DataKind.PredictorModel:
                    var pm = runner.GetOutput<PredictorModel>(varName);
                    SavePredictorModel(pm, path);
                    break;
                case TlcModule.DataKind.TransformModel:
                    var tm = runner.GetOutput<TransformModel>(varName);
                    SaveTransformModel(tm, path);
                    break;
                case TlcModule.DataKind.Array:
                    string partialPath = path.Substring(0, path.Length - extension.Length);

                    var ipmArray = runner.GetOutputOrDefault<PredictorModel[]>(varName);
                    if (ipmArray != null)
                    {
                        SaveArrayToFile(ipmArray, partialPath, extension);
                        break;
                    }

                    var idvArray = runner.GetOutputOrDefault<IDataView[]>(varName);
                    if (idvArray != null)
                    {
                        SaveArrayToFile(idvArray, partialPath, extension);
                        break;
                    }

                    // Neither supported array element type matched.
                    goto default;
                default:
                    throw _host.Except("Port type {0} not supported", kind);
            }
        }

        /// <summary>
        /// Saves each predictor model to "<paramref name="partialPath"/>_&lt;index&gt;<paramref name="extension"/>".
        /// </summary>
        private void SaveArrayToFile(PredictorModel[] array, string partialPath, string extension)
        {
            for (int i = 0; i < array.Length; i++)
            {
                string path = $"{partialPath}_{i}{extension}";
                SavePredictorModel(array[i], path);
            }
        }

        /// <summary>
        /// Saves each data view to "<paramref name="partialPath"/>_&lt;index&gt;<paramref name="extension"/>",
        /// picking the saver from <paramref name="extension"/>.
        /// </summary>
        private void SaveArrayToFile(IDataView[] array, string partialPath, string extension)
        {
            for (int i = 0; i < array.Length; i++)
            {
                string path = $"{partialPath}_{i}{extension}";
                SaveDataView(array[i], path, extension);
            }
        }

        /// <summary>
        /// Saves the PredictorModel to the given path.
        /// </summary>
        private void SavePredictorModel(PredictorModel pm, string path)
        {
            _host.CheckValue(pm, nameof(pm));

            using (var handle = _host.CreateOutputFile(path))
            using (var fs = handle.CreateWriteStream())
                pm.Save(_host, fs);
        }

        /// <summary>
        /// Saves the TransformModel to the given path. The value is checked before the output file is created.
        /// </summary>
        private void SaveTransformModel(TransformModel tm, string path)
        {
            _host.CheckValue(tm, nameof(tm));

            using (var handle = _host.CreateOutputFile(path))
            using (var fs = handle.CreateWriteStream())
                tm.Save(_host, fs);
        }

        /// <summary>
        /// Saves the IDV to file based on its extension: ".csv" (comma separated) and ".tsv"/".txt"
        /// (tab separated) are written as text with a header row; any other extension is written in binary
        /// format. Extensions are compared case-insensitively.
        /// </summary>
        private void SaveDataView(IDataView idv, string path, string extension)
        {
            _host.CheckValue(idv, nameof(idv));

            bool isCsv = string.Equals(extension, ".csv", System.StringComparison.OrdinalIgnoreCase);
            bool isText = isCsv
                || string.Equals(extension, ".tsv", System.StringComparison.OrdinalIgnoreCase)
                || string.Equals(extension, ".txt", System.StringComparison.OrdinalIgnoreCase);

            IDataSaver saver;
            if (isText)
            {
                var saverArgs = new TextSaver.Arguments
                {
                    OutputHeader = true,
                    OutputSchema = false,
                    Separator = isCsv ? "comma" : "tab"
                };
                saver = new TextSaver(_host, saverArgs);
            }
            else
            {
                saver = new BinarySaver(_host, new BinarySaver.Arguments());
            }

            using (var handle = _host.CreateOutputFile(path))
            using (var fs = handle.CreateWriteStream())
            {
                // Only columns whose type the chosen saver reports as savable are written.
                saver.SaveData(fs, idv, Utils.GetIdentityPermutation(idv.Schema.Count)
                    .Where(x => saver.IsColumnSavable(idv.Schema[x].Type)).ToArray());
            }
        }
    }

}