File: LoadTransform.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
[assembly: LoadableClass(LoadTransform.Summary, typeof(IDataTransform), typeof(LoadTransform), typeof(LoadTransform.Arguments), typeof(SignatureDataTransform),
    "Load Transform", "LoadTransform", "Load")]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// Load specific transforms from the specified model file. Allows one to 'cherry pick' transforms from
    /// a serialized chain, or to apply a pre-trained transform to a different (but still compatible) data view.
    /// </summary>
    internal static class LoadTransform
    {
        /// <summary>
        /// Settings selecting the model file and the tags of the transforms to load from it.
        /// </summary>
        public class Arguments
        {
            // REVIEW: make it not required, and make commands fill in the missing model file with the default
            // input model file. This requires some hacking in DataDiagnosticCommand.
            [Argument(ArgumentType.Required, HelpText = "Model file to load the transforms from", ShortName = "in",
                SortOrder = 1, IsInputFileName = true)]
            public string ModelFile;

            [Argument(ArgumentType.Multiple, HelpText = "The tags (comma-separated) to be loaded (or omitted, if " + nameof(Complement) + "+)",
                Name = "Tag", SortOrder = 2)]
            public string[] Tags;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to load all transforms except those marked by tags", ShortName = "comp", SortOrder = 3)]
            public bool Complement = false;
        }

        internal const string Summary = "Loads specified transforms from the model file and applies them to current data.";

        /// <summary>
        /// A helper method to create <see cref="LoadTransform"/> for public facing API.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
        /// <param name="modelFile">Model file to load the transforms from.</param>
        /// <param name="tag">The tags (comma-separated) to be loaded (or omitted, if complement is true).</param>
        /// <param name="complement">Whether to load all transforms except those marked by tags.</param>
        /// <returns>The transform chain selected from the model file, applied on top of <paramref name="input"/>.</returns>
        public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] tag, bool complement = false)
        {
            var args = new Arguments()
            {
                ModelFile = modelFile,
                Tags = tag,
                Complement = complement
            };
            return Create(env, args, input);
        }

        /// <summary>
        /// Factory method for SignatureDataTransform. Validates the arguments, builds the tag filter and
        /// loads the matching transforms from the model file on top of <paramref name="input"/>.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="args">Model file, tags and complement flag.</param>
        /// <param name="input">The data view the loaded transforms are applied to.</param>
        /// <returns>The resulting data view, which is expected to be an <see cref="IDataTransform"/>.</returns>
        /// <remarks>
        /// Raises a user-argument error when the model file does not exist or when the selection
        /// leaves <paramref name="input"/> unchanged (no transform matched the criteria).
        /// </remarks>
        private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register("LoadTransform");
            h.CheckValue(args, nameof(args));
            h.CheckValue(input, nameof(input));
            h.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile), "File does not exist");

            // If there are no 'tag' parameters, we load everything, regardless of 'comp'.
            bool complement = args.Complement || Utils.Size(args.Tags) == 0;
            HashSet<string> allTags = ParseTags(args.Tags);

            // A transform is kept when its membership in the tag set disagrees with 'complement':
            // listed tags are kept in normal mode, unlisted tags are kept in complement mode.
            Func<string, bool> predicate = tag => allTags.Contains(tag) != complement;

            IDataView currentView;
            using (var file = h.OpenInputFile(args.ModelFile))
            using (var strm = file.OpenReadStream())
            using (var rep = RepositoryReader.Open(strm, h))
            using (var pipeLoaderEntry = rep.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
            using (var ctx = new ModelLoadContext(rep, pipeLoaderEntry, ModelFileUtils.DirDataLoaderModel))
            {
                currentView = LegacyCompositeDataLoader.LoadSelectedTransforms(ctx, input, h, predicate);

                if (currentView == input)
                {
                    // REVIEW: we are required to return an IDataTransform. Therefore, if we don't introduce a new transform
                    // on top of 'input', we must throw (since input may not be a data transform).
                    // We could of course introduce a 'no-op transform', or we could lift the requirement to always return an IDataTransform
                    // associated with SignatureDataTransform.

                    var criteria = string.Format(
                            complement
                                ? "transforms that don't have tags from the list: '{0}'"
                                : "transforms that have tags from the list: '{0}'",
                            string.Join(",", allTags));
                    throw h.ExceptUserArg(nameof(args.Tags), "No transforms were found that match the search criteria ({0})", criteria);
                }
            }

            h.Assert(currentView is IDataTransform);
            return (IDataTransform)currentView;
        }

        /// <summary>
        /// Flattens the (possibly null) array of comma-separated tag lists into a single set of tags.
        /// </summary>
        /// <param name="tagLists">Each element may hold several tags separated by commas.</param>
        /// <returns>
        /// A set that compares tags with <see cref="StringComparer.OrdinalIgnoreCase"/>, so matching is
        /// case-insensitive without depending on the current culture.
        /// </returns>
        private static HashSet<string> ParseTags(string[] tagLists)
        {
            var result = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
            for (int i = 0; i < Utils.Size(tagLists); i++)
            {
                var curList = tagLists[i];
                if (string.IsNullOrWhiteSpace(curList))
                    continue;

                foreach (var tag in curList.Split(','))
                {
                    // Trim so that a list written as "a, b" selects tag "b" rather than " b".
                    if (!string.IsNullOrWhiteSpace(tag))
                        result.Add(tag.Trim());
                }
            }
            return result;
        }
    }
}