File: ModelOperations.cs
Web Access
Project: src\src\Microsoft.ML.EntryPoints\Microsoft.ML.EntryPoints.csproj (Microsoft.ML.EntryPoints)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
 
// Assembly-level LoadableClass registration of ModelOperations under the SignatureEntryPointModule
// signature, with the load name "ModelOperations".
// NOTE(review): presumably this is what lets the entry-point catalog discover the
// [TlcModule.EntryPoint] methods declared on ModelOperations - confirm against the loader.
[assembly: LoadableClass(typeof(void), typeof(ModelOperations), null, typeof(SignatureEntryPointModule), "ModelOperations")]
 
namespace Microsoft.ML.EntryPoints
{
    internal static class ModelOperations
    {
        /// <summary>
        /// Arguments of the "Transforms.ModelCombiner" entry point (<see cref="CombineTransformModels"/>).
        /// </summary>
        public sealed class CombineTransformModelsInput
        {
            /// <summary>
            /// The transform models to combine. The entry point rejects an empty array.
            /// </summary>
            [Argument(ArgumentType.Multiple, HelpText = "Input models", SortOrder = 1)]
            public TransformModel[] Models;
        }
 
        /// <summary>
        /// Result of the "Transforms.ModelCombiner" entry point (<see cref="CombineTransformModels"/>).
        /// </summary>
        public sealed class CombineTransformModelsOutput
        {
            /// <summary>
            /// The single transform model produced by combining all input models.
            /// </summary>
            [TlcModule.Output(Desc = "Combined model", SortOrder = 1)]
            public TransformModel OutputModel;
        }
 
        /// <summary>
        /// Arguments of the "Transforms.ManyHeterogeneousModelCombiner" entry point
        /// (<see cref="CombineModels"/>): several transform models plus one predictor model.
        /// </summary>
        public sealed class PredictorModelInput
        {
            /// <summary>
            /// The transform models to combine before attaching them to the predictor model.
            /// Required; the entry point rejects an empty array.
            /// </summary>
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Transform model", SortOrder = 1)]
            public TransformModel[] TransformModels;
 
            /// <summary>
            /// The predictor model that the combined transform model is applied to. Required.
            /// </summary>
            [Argument(ArgumentType.Required, HelpText = "Predictor model", SortOrder = 2)]
            public PredictorModel PredictorModel;
        }
 
        /// <summary>
        /// Arguments of the "Transforms.TwoHeterogeneousModelCombiner" entry point
        /// (<see cref="CombineTwoModels"/>): exactly one transform model and one predictor model.
        /// </summary>
        public sealed class SimplePredictorModelInput
        {
            /// <summary>
            /// The transform model to attach to the predictor model. Required.
            /// </summary>
            [Argument(ArgumentType.Required, HelpText = "Transform model", SortOrder = 1)]
            public TransformModel TransformModel;
 
            /// <summary>
            /// The predictor model that the transform model is applied to. Required.
            /// </summary>
            [Argument(ArgumentType.Required, HelpText = "Predictor model", SortOrder = 2)]
            public PredictorModel PredictorModel;
        }
 
        /// <summary>
        /// Result type shared by the entry points in this class that produce a predictor model
        /// (<see cref="CombineModels"/>, <see cref="CombineTwoModels"/> and <see cref="CombineOvaModels"/>).
        /// </summary>
        public sealed class PredictorModelOutput
        {
            /// <summary>
            /// The resulting predictor model.
            /// </summary>
            [TlcModule.Output(Desc = "Predictor model", SortOrder = 1)]
            public PredictorModel PredictorModel;
        }
 
        /// <summary>
        /// Arguments of the "Models.OvaModelCombiner" entry point (<see cref="CombineOvaModels"/>).
        /// The entry point also reads TrainingData and the label, feature and example-weight column
        /// names from this input; those members are not declared here, so they come from the
        /// <see cref="TrainerInputBaseWithWeight"/> base type hierarchy.
        /// </summary>
        public sealed class CombineOvaPredictorModelsInput : TrainerInputBaseWithWeight
        {
            /// <summary>
            /// The predictor models to combine. The entry point rejects an empty array and uses the
            /// transform model of the first element to prepare the training data.
            /// </summary>
            [Argument(ArgumentType.Multiple, HelpText = "Input models", SortOrder = 1)]
            public PredictorModel[] ModelArray;
 
            /// <summary>
            /// Forwarded to OneVersusAllModelParameters.Create as its use-probabilities flag.
            /// Defaults to <c>true</c>.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Use probabilities from learners instead of raw values.", SortOrder = 2)]
            public bool UseProbabilities = true;
        }
 
        /// <summary>
        /// Argument holder for an array of predictor models.
        /// NOTE(review): no entry point declared in this class takes this type - presumably it is
        /// consumed elsewhere; confirm before changing or removing it.
        /// </summary>
        public sealed class CombinePredictorModelsInput
        {
            /// <summary>
            /// The predictor models supplied by the caller.
            /// </summary>
            [Argument(ArgumentType.Multiple, HelpText = "Input models", SortOrder = 1)]
            public PredictorModel[] Models;
        }
 
        /// <summary>
        /// Arguments of the "Models.DatasetTransformer" entry point (<see cref="Apply"/>).
        /// The dataset to transform is read from the Data member, which is not declared here and
        /// so comes from <see cref="TransformInputBase"/>.
        /// </summary>
        public sealed class ApplyTransformModelInput : TransformInputBase
        {
            /// <summary>
            /// The transform model to apply to the input dataset. Required.
            /// </summary>
            [Argument(ArgumentType.Required, HelpText = "Transform model", SortOrder = 2)]
            public TransformModel TransformModel;
        }
 
        /// <summary>
        /// Result of the "Models.DatasetTransformer" entry point (<see cref="Apply"/>).
        /// </summary>
        public sealed class ApplyTransformModelOutput
        {
            /// <summary>
            /// The data view returned by applying the transform model to the input dataset.
            /// </summary>
            [TlcModule.Output(Desc = "Transformed dataset", SortOrder = 1)]
            public IDataView OutputData;
        }
 
        /// <summary>
        /// Entry point that collapses an array of <see cref="TransformModel"/>s into a single model.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Arguments carrying the models to combine; the array must be non-empty.</param>
        /// <returns>An output object whose OutputModel is the combined transform model.</returns>
        [TlcModule.EntryPoint(Name = "Transforms.ModelCombiner", Desc = "Combines a sequence of TransformModels into a single model")]
        public static CombineTransformModelsOutput CombineTransformModels(IHostEnvironment env, CombineTransformModelsInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineTransformModels");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            host.CheckNonEmpty(input.Models, nameof(input.Models));

            // Seed the accumulator with the last model, then walk toward the front of the array,
            // applying the accumulated model to each earlier one in turn.
            var models = input.Models;
            int index = models.Length - 1;
            TransformModel combined = models[index];
            while (--index >= 0)
                combined = combined.Apply(env, models[index]);

            var output = new CombineTransformModelsOutput();
            output.OutputModel = combined;
            return output;
        }
 
        /// <summary>
        /// Entry point that combines several <see cref="TransformModel"/>s into one and then applies
        /// the predictor model to that combined transform model.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Arguments carrying a non-empty transform model array and the predictor model.</param>
        /// <returns>An output object wrapping the resulting predictor model.</returns>
        [TlcModule.EntryPoint(Name = "Transforms.ManyHeterogeneousModelCombiner", Desc = "Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel.")]
        public static PredictorModelOutput CombineModels(IHostEnvironment env, PredictorModelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineModels");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            host.CheckNonEmpty(input.TransformModels, nameof(input.TransformModels));

            // Start from the last transform model and fold in the earlier ones, back to front.
            var transforms = input.TransformModels;
            int remaining = transforms.Length - 1;
            TransformModel combinedTransform = transforms[remaining];
            while (remaining > 0)
            {
                remaining--;
                combinedTransform = combinedTransform.Apply(env, transforms[remaining]);
            }

            var result = new PredictorModelOutput();
            result.PredictorModel = input.PredictorModel.Apply(env, combinedTransform);
            return result;
        }
 
        /// <summary>
        /// Entry point that applies a predictor model to a single <see cref="TransformModel"/>,
        /// yielding one combined predictor model.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Arguments carrying the transform model and the predictor model.</param>
        /// <returns>An output object wrapping the resulting predictor model.</returns>
        [TlcModule.EntryPoint(Name = "Transforms.TwoHeterogeneousModelCombiner", Desc = "Combines a TransformModel and a PredictorModel into a single PredictorModel.")]
        public static PredictorModelOutput CombineTwoModels(IHostEnvironment env, SimplePredictorModelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineTwoModels");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var combined = input.PredictorModel.Apply(env, input.TransformModel);
            return new PredictorModelOutput { PredictorModel = combined };
        }
 
        /// <summary>
        /// Entry point that applies a <see cref="TransformModel"/> to a dataset.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Arguments carrying the transform model and the dataset to transform.</param>
        /// <returns>An output object whose OutputData is the transformed data view.</returns>
        [TlcModule.EntryPoint(Name = "Models.DatasetTransformer", Desc = "Applies a TransformModel to a dataset.", UserName = "Apply Transform Model Output")]
        public static ApplyTransformModelOutput Apply(IHostEnvironment env, ApplyTransformModelInput input)
        {
            // Same validation preamble as the other entry points in this class, so that a missing
            // environment, input object or required argument is reported as an argument error
            // rather than surfacing as a NullReferenceException from the dereferences below.
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ApplyTransformModel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            return new ApplyTransformModelOutput() { OutputData = input.TransformModel.Apply(env, input.Data) };
        }
 
        /// <summary>
        /// Entry point that combines several <see cref="PredictorModel"/>s into a single
        /// one-versus-all predictor model.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">
        /// Arguments carrying the models to combine, the training data and the column names. Every
        /// element of the model array must be non-null and expose a predictor that implements
        /// IPredictorProducing&lt;float&gt;.
        /// </param>
        /// <returns>An output object wrapping the combined predictor model.</returns>
        [TlcModule.EntryPoint(Name = "Models.OvaModelCombiner", Desc = "Combines a sequence of PredictorModels into a single model")]
        public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, CombineOvaPredictorModelsInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineOvaModels");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            host.CheckNonEmpty(input.ModelArray, nameof(input.ModelArray));

            // Validate the models up front: an unchecked 'as' cast would silently turn a predictor of
            // the wrong kind into a null array element, and a null first model would be dereferenced
            // below. Report the offending index instead.
            var predictors = input.ModelArray
                .Select(m => m?.Predictor as IPredictorProducing<float>)
                .ToArray();
            for (int i = 0; i < predictors.Length; i++)
                host.CheckUserArg(predictors[i] != null, nameof(input.ModelArray),
                    "Model at index {0} is null or its predictor is not an IPredictorProducing<float>.", i);

            // Arguably normalization should be added during macro expansion, but since we receive a
            // subgraph instead of a learner it is tricky to get hold of the learner and decide whether
            // a normalization node is needed; besides, everywhere else in the code that responsibility
            // is left to the TransformModel.
            var normalizedView = input.ModelArray[0].TransformModel.Apply(host, input.TrainingData);
            using (var ch = host.Start("CombineOvaModels"))
            {
                // Resolve the label/feature/weight columns against the transformed schema, passing
                // the default column names alongside the user-supplied ones.
                var schema = normalizedView.Schema;
                var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.LabelColumnName),
                    input.LabelColumnName,
                    DefaultColumnNames.Label);
                var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.FeatureColumnName),
                    input.FeatureColumnName, DefaultColumnNames.Features);
                var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.ExampleWeightColumnName),
                    input.ExampleWeightColumnName, DefaultColumnNames.Weight);
                var data = new RoleMappedData(normalizedView, label, feature, null, weight);

                return new PredictorModelOutput
                {
                    PredictorModel = new PredictorModelImpl(env, data, input.TrainingData,
                    OneVersusAllModelParameters.Create(host, input.UseProbabilities, predictors))
                };
            }
        }
    }
}