File: Commands\DataCommand.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// This holds useful base classes for commands that ingest a primary dataset and deal with associated model files.
    /// </summary>
    [BestFriend]
    internal static class DataCommand
    {
        public abstract class ArgumentsBase
        {
            [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>", SignatureType = typeof(SignatureDataLoader))]
            public IComponentFactory<IMultiStreamSource, ILegacyDataLoader> Loader;
 
            [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)]
            public string DataFile;
 
            [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Model file to save", ShortName = "out")]
            public string OutputModelFile;
 
            [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)]
            public string InputModelFile;
 
            [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)]
            public bool? LoadTransforms;
 
            [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)]
            public int? RandomSeed;
 
            [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Verbose?", ShortName = "v", Hide = true)]
            public bool? Verbose;
 
            [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The web server to publish the RESTful API", Hide = true)]
            public ServerChannel.IServerFactory Server;
 
            // This is actually an advisory value. The implementations themselves are responsible for
            // determining what they consider appropriate, and the actual heuristics is a bit more
            // complex than just this.
            [Argument(ArgumentType.LastOccurrenceWins, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly,
                HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")]
            public int? Parallel;
 
            [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly,
                HelpText = "Transform", Name = "Transform", ShortName = "xf", SignatureType = typeof(SignatureDataTransform))]
            public KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] Transforms;
        }
 
        [BestFriend]
        internal abstract class ImplBase<TOptions> : ICommand
            where TOptions : ArgumentsBase
        {
            protected readonly IHost Host;
            protected readonly TOptions ImplOptions;
            private readonly ServerChannel.IServerFactory _serverFactory;
 
            protected ServerChannel.IServer InitServer(IChannel ch)
            {
                Host.CheckValue(ch, nameof(ch));
                ch.Check(Host != null, nameof(InitServer) + " called prematurely");
                return _serverFactory?.CreateComponent(Host, ch);
            }
 
            protected ImplBase(IHostEnvironment env, TOptions options, string name)
            {
                Contracts.CheckValue(env, nameof(env));
 
                // Note that env may be null here, which is OK since the CheckXxx methods are extension
                // methods designed to allow null.
                env.CheckValue(options, nameof(options));
 
                env.CheckUserArg(!(options.Parallel < 0), nameof(options.Parallel), "Degree of parallelism must be non-negative (or null)");
 
                // Capture the environment options from args.
                env = env.Register(name, options.RandomSeed, options.Verbose);
 
                env.CheckNonWhiteSpace(name, nameof(name));
                Host = env.Register(name);
                ImplOptions = options;
                _serverFactory = options.Server;
                Utils.CheckOptionalUserDirectory(options.OutputModelFile, nameof(options.OutputModelFile));
            }
 
            protected ImplBase(ImplBase<TOptions> impl, string name)
            {
                Contracts.CheckValue(impl, nameof(impl));
                Contracts.AssertValue(impl.Host);
                impl.Host.AssertValue(impl.ImplOptions);
                impl.Host.AssertValue(name);
                ImplOptions = impl.ImplOptions;
                Host = impl.Host.Register(name);
            }
 
            public abstract void Run();
 
            protected virtual void SendTelemetry(IChannelProvider prov)
            {
                using (var pipe = prov.StartPipe<TelemetryMessage>("TelemetryPipe"))
                {
                    SendTelemetryCore(pipe);
                }
            }
 
            protected void SendTelemetryComponent(IPipe<TelemetryMessage> pipe, IComponentFactory factory)
            {
                Host.AssertValue(pipe);
                Host.AssertValueOrNull(factory);
 
                if (factory is ICommandLineComponentFactory commandLineFactory)
                    pipe.Send(TelemetryMessage.CreateTrainer(commandLineFactory.Name, commandLineFactory.GetSettingsString()));
                else
                    pipe.Send(TelemetryMessage.CreateTrainer("Unknown", "Non-ICommandLineComponentFactory object"));
            }
 
            protected virtual void SendTelemetryCore(IPipe<TelemetryMessage> pipe)
            {
                Contracts.AssertValue(pipe);
 
                if (ImplOptions.Transforms != null)
                {
                    foreach (var transform in ImplOptions.Transforms)
                        SendTelemetryComponent(pipe, transform.Value);
                }
            }
 
            protected void SendTelemetryMetricCore(IChannelProvider prov, Dictionary<string, double> averageMetric)
            {
                using (var pipe = prov.StartPipe<TelemetryMessage>("TelemetryPipe"))
                {
                    Contracts.AssertValue(pipe);
 
                    if (averageMetric != null)
                    {
                        foreach (var pair in averageMetric)
                            pipe.Send(TelemetryMessage.CreateMetric(pair.Key, pair.Value, null));
                    }
                }
            }
 
            protected void SendTelemetryMetric(Dictionary<string, IDataView>[] metricValues)
            {
                Dictionary<string, double> averageMetric = new Dictionary<string, double>();
                foreach (Dictionary<string, IDataView> mValue in metricValues)
                {
                    var data = mValue.First().Value;
                    using (var cursor = data.GetRowCursorForAllColumns())
                    {
                        while (cursor.MoveNext())
                        {
                            for (int currentIndex = 0; currentIndex < cursor.Schema.Count; currentIndex++)
                            {
                                var nameOfMetric = "TLC_" + cursor.Schema[currentIndex].Name;
                                var type = cursor.Schema[currentIndex].Type;
                                if (type is NumberDataViewType)
                                {
                                    var getter = RowCursorUtils.GetGetterAs<double>(NumberDataViewType.Double, cursor, currentIndex);
                                    double metricValue = 0;
                                    getter(ref metricValue);
                                    if (averageMetric.ContainsKey(nameOfMetric))
                                    {
                                        averageMetric[nameOfMetric] += metricValue;
                                    }
                                    else
                                    {
                                        averageMetric.Add(nameOfMetric, metricValue);
                                    }
                                }
                                else
                                {
                                    if (averageMetric.ContainsKey(nameOfMetric))
                                    {
                                        averageMetric[nameOfMetric] = Double.NaN;
                                    }
                                    else
                                    {
                                        averageMetric.Add(nameOfMetric, Double.NaN);
                                    }
                                }
                            }
                        }
                    }
                }
                Dictionary<string, double> newAverageMetric = new Dictionary<string, double>();
                foreach (var pair in averageMetric)
                {
                    newAverageMetric.Add(pair.Key, pair.Value / metricValues.Length);
                }
                SendTelemetryMetricCore(Host, newAverageMetric);
            }
 
            protected ILegacyDataLoader LoadLoader(RepositoryReader rep, string path, bool loadTransforms)
            {
                return ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(path), loadTransforms);
            }
 
            protected void SaveLoader(ILegacyDataLoader loader, string path)
            {
                using (var file = Host.CreateOutputFile(path))
                    LoaderUtils.SaveLoader(loader, file, Host);
            }
 
            protected ILegacyDataLoader CreateAndSaveLoader(Func<IHostEnvironment, IMultiStreamSource, ILegacyDataLoader> defaultLoaderFactory = null)
            {
                var loader = CreateLoader(defaultLoaderFactory);
                if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
                {
                    using (var file = Host.CreateOutputFile(ImplOptions.OutputModelFile))
                        LoaderUtils.SaveLoader(loader, file, Host);
                }
                return loader;
            }
 
            /// <summary>
            /// Loads multiple artifacts of interest from the input model file, given the context
            /// established by the command line arguments.
            /// </summary>
            /// <param name="ch">The channel to which to provide output.</param>
            /// <param name="wantPredictor">Whether we want a predictor from the model file. If
            /// <c>false</c> we will not even attempt to load a predictor. If <c>null</c> we will
            /// load the predictor, if present. If <c>true</c> we will load the predictor, or fail
            /// noisily if we cannot.</param>
            /// <param name="predictor">The predictor in the model, or <c>null</c> if
            /// <paramref name="wantPredictor"/> was false, or <paramref name="wantPredictor"/> was
            /// <c>null</c> and no predictor was present.</param>
            /// <param name="wantTrainSchema">Whether we want the training schema. Unlike
            /// <paramref name="wantPredictor"/>, this has no "hard fail if not present" option. If
            /// this is <c>true</c>, it is still possible for <paramref name="trainSchema"/> to remain
            /// <c>null</c> if there were no role mappings, or pipeline.</param>
            /// <param name="trainSchema">The training schema if <paramref name="wantTrainSchema"/>
            /// is true, and there were role mappings stored in the model.</param>
            /// <param name="pipe">The data pipe constructed from the combination of the
            /// model and command line arguments.</param>
            protected void LoadModelObjects(
                IChannel ch,
                bool? wantPredictor, out IPredictor predictor,
                bool wantTrainSchema, out RoleMappedSchema trainSchema,
                out ILegacyDataLoader pipe)
            {
                // First handle the case where there is no input model file.
                // Everything must come from the command line.
 
                using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
                using (var strm = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(strm, Host))
                {
                    // First consider loading the predictor.
                    if (wantPredictor == false)
                        predictor = null;
                    else
                    {
                        ch.Trace("Loading predictor");
                        predictor = ModelFileUtils.LoadPredictorOrNull(Host, rep);
                        if (wantPredictor == true)
                            Host.Check(predictor != null, "Could not load predictor from model file");
                    }
 
                    // Next create the loader.
                    var loaderFactory = ImplOptions.Loader;
                    ILegacyDataLoader trainPipe = null;
                    if (loaderFactory != null)
                    {
                        // The loader is overridden from the command line.
                        pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(ImplOptions.DataFile));
                        if (ImplOptions.LoadTransforms == true)
                        {
                            Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
                            pipe = LoadTransformChain(pipe);
                        }
                    }
                    else
                    {
                        var loadTrans = ImplOptions.LoadTransforms ?? true;
                        pipe = LoadLoader(rep, ImplOptions.DataFile, loadTrans);
                        if (loadTrans)
                            trainPipe = pipe;
                    }
 
                    if (Utils.Size(ImplOptions.Transforms) > 0)
                        pipe = LegacyCompositeDataLoader.Create(Host, pipe, ImplOptions.Transforms);
 
                    // Next consider loading the training data's role mapped schema.
                    trainSchema = null;
                    if (wantTrainSchema)
                    {
                        // First try to get the role mappings.
                        var trainRoleMappings = ModelFileUtils.LoadRoleMappingsOrNull(Host, rep);
                        if (trainRoleMappings != null)
                        {
                            // Next create the training schema. In the event that the loaded pipeline happens
                            // to be the training pipe, we can just use that. If it differs, then we need to
                            // load the full pipeline from the model, relying upon the fact that all loaders
                            // can be loaded with no data at all, to get their schemas.
                            if (trainPipe == null)
                                trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true);
                            trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings);
                        }
                        // If the role mappings are null, an alternative would be to fail. However the idea
                        // is that the scorer should always still succeed, although perhaps with reduced
                        // functionality, even when the training schema is null, since not all versions of
                        // TLC models will have the role mappings preserved, I believe. And, we do want to
                        // maintain backwards compatibility.
                    }
                }
            }
 
            protected ILegacyDataLoader CreateLoader(Func<IHostEnvironment, IMultiStreamSource, ILegacyDataLoader> defaultLoaderFactory = null)
            {
                var loader = CreateRawLoader(defaultLoaderFactory);
                loader = CreateTransformChain(loader);
                return loader;
            }
 
            private ILegacyDataLoader CreateTransformChain(ILegacyDataLoader loader)
            {
                return LegacyCompositeDataLoader.Create(Host, loader, ImplOptions.Transforms);
            }
 
            protected ILegacyDataLoader CreateRawLoader(
                Func<IHostEnvironment, IMultiStreamSource, ILegacyDataLoader> defaultLoaderFactory = null,
                string dataFile = null)
            {
                if (string.IsNullOrWhiteSpace(dataFile))
                    dataFile = ImplOptions.DataFile;
 
                ILegacyDataLoader loader;
                if (!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile) && ImplOptions.Loader == null)
                {
                    // Load the loader from the data model.
                    using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
                    using (var strm = file.OpenReadStream())
                    using (var rep = RepositoryReader.Open(strm, Host))
                        loader = LoadLoader(rep, dataFile, ImplOptions.LoadTransforms ?? true);
                }
                else
                {
                    // Either there is no input model file, or there is, but the loader is overridden.
                    IMultiStreamSource fileSource = new MultiFileSource(dataFile);
                    var loaderFactory = ImplOptions.Loader;
                    if (loaderFactory == null)
                    {
                        var ext = Path.GetExtension(dataFile);
                        var isText =
                            string.Equals(ext, ".txt", StringComparison.OrdinalIgnoreCase) ||
                            string.Equals(ext, ".tlc", StringComparison.OrdinalIgnoreCase);
                        var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
                        var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
 
                        return isText ? TextLoader.Create(Host, new TextLoader.Options(), fileSource) :
                               isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource) :
                               isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource) :
                               defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource) :
                               TextLoader.Create(Host, new TextLoader.Options(), fileSource);
                    }
                    else
                    {
                        loader = loaderFactory.CreateComponent(Host, fileSource);
                    }
 
                    if (ImplOptions.LoadTransforms == true)
                    {
                        Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
                        loader = LoadTransformChain(loader);
                    }
                }
                return loader;
            }
 
            private ILegacyDataLoader LoadTransformChain(ILegacyDataLoader srcData)
            {
                Host.Assert(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile));
 
                using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
                using (var strm = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(strm, Host))
                using (var pipeLoaderEntry = rep.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
                using (var ctx = new ModelLoadContext(rep, pipeLoaderEntry, ModelFileUtils.DirDataLoaderModel))
                    return LegacyCompositeDataLoader.Create(Host, ctx, srcData, x => true);
            }
        }
    }
 
    [BestFriend]
    internal static class LoaderUtils
    {
        /// <summary>
        /// Saves <paramref name="loader"/> to the specified <paramref name="file"/>.
        /// </summary>
        public static void SaveLoader(ILegacyDataLoader loader, IFileHandle file, IHostEnvironment host)
        {
            Contracts.CheckValue(loader, nameof(loader));
            Contracts.CheckValue(file, nameof(file));
            Contracts.CheckParam(file.CanWrite, nameof(file), "Must be writable");
 
            using (var stream = file.CreateWriteStream())
            {
                SaveLoader(loader, stream, host);
            }
        }
 
        /// <summary>
        /// Saves <paramref name="loader"/> to the specified <paramref name="stream"/>.
        /// </summary>
        public static void SaveLoader(ILegacyDataLoader loader, Stream stream, IHostEnvironment host)
        {
            Contracts.CheckValue(loader, nameof(loader));
            Contracts.CheckValue(stream, nameof(stream));
            Contracts.CheckParam(stream.CanWrite, nameof(stream), "Must be writable");
 
            using (var rep = RepositoryWriter.CreateNew(stream, host))
            {
                ModelSaveContext.SaveModel(rep, loader, ModelFileUtils.DirDataLoaderModel);
                rep.Commit();
            }
        }
    }
}