// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.IO;
using System.IO.Compression;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.Ensemble;
[assembly: LoadableClass(typeof(void), typeof(EnsembleCreator), null, typeof(SignatureEntryPointModule), "CreateEnsemble")]
namespace Microsoft.ML.Trainers.Ensemble
/// <summary>
/// A component to combine given models into an ensemble model.
/// </summary>
internal static class EnsembleCreator
/// <summary>
/// The combiner options for binary and multiclass classifiers. The classifier entry points of
/// <see cref="EnsembleCreator"/> switch on this value to pick the output combiner; the values they
/// handle are Median, Average and Vote, and any other value makes them throw.
/// </summary>
public enum ClassifierCombiner
/// <summary>
/// The combiner options for regression and anomaly detection. The regression and anomaly entry
/// points of <see cref="EnsembleCreator"/> switch on this value to pick the output combiner; the
/// values they handle are Median and Average, and any other value makes them throw.
/// </summary>
public enum ScoreCombiner
/// <summary>
/// Base class for the inputs of the "pipeline ensemble" entry points. It declares the models
/// that are to be combined.
/// </summary>
public abstract class PipelineInputBase
/// <summary>
/// The models to combine. Required argument; the entry points additionally check that it is non-empty.
/// </summary>
[Argument(ArgumentType.Required, ShortName = "models", HelpText = "The models to combine into an ensemble", SortOrder = 1)]
public PredictorModel[] Models;
/// <summary>
/// Base class for the inputs of the ensemble entry points that go through GetPipeline: it declares
/// the models to combine and whether their pipelines must be validated as identical.
/// </summary>
public abstract class InputBase
/// <summary>
/// The models to combine. Required argument; the entry points additionally check that it is non-empty.
/// </summary>
[Argument(ArgumentType.Required, ShortName = "models", HelpText = "The models to combine into an ensemble", SortOrder = 1)]
public PredictorModel[] Models;
/// <summary>
/// When true (the default), GetPipeline compares the pipeline of every model after the first
/// against the first model's pipeline and fails if they differ.
/// </summary>
[Argument(ArgumentType.AtMostOnce, ShortName = "validate", HelpText = "Whether to validate that all the pipelines are identical", SortOrder = 5)]
public bool ValidatePipelines = true;
/// <summary>
/// Input of <see cref="CreateBinaryEnsemble"/>: the members of <see cref="InputBase"/> plus the
/// classifier combiner to use.
/// </summary>
public sealed class ClassifierInput : InputBase
/// <summary>
/// Selects how the individual model scores are combined. Defaults to <see cref="ClassifierCombiner.Median"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
public ClassifierCombiner ModelCombiner = ClassifierCombiner.Median;
/// <summary>
/// Input of <see cref="CreateBinaryPipelineEnsemble"/> and <see cref="CreateMulticlassPipelineEnsemble"/>:
/// the members of <see cref="PipelineInputBase"/> plus the classifier combiner to use.
/// </summary>
public sealed class PipelineClassifierInput : PipelineInputBase
/// <summary>
/// Selects how the individual model scores are combined. Defaults to <see cref="ClassifierCombiner.Median"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
public ClassifierCombiner ModelCombiner = ClassifierCombiner.Median;
/// <summary>
/// Input of <see cref="CreateRegressionEnsemble"/>: the members of <see cref="InputBase"/> plus the
/// score combiner to use.
/// </summary>
public sealed class RegressionInput : InputBase
/// <summary>
/// Selects how the individual model scores are combined. Defaults to <see cref="ScoreCombiner.Median"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
public ScoreCombiner ModelCombiner = ScoreCombiner.Median;
/// <summary>
/// Input of <see cref="CreateRegressionPipelineEnsemble"/>: the members of
/// <see cref="PipelineInputBase"/> plus the score combiner to use.
/// </summary>
public sealed class PipelineRegressionInput : PipelineInputBase
/// <summary>
/// Selects how the individual model scores are combined. Defaults to <see cref="ScoreCombiner.Median"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
public ScoreCombiner ModelCombiner = ScoreCombiner.Median;
/// <summary>
/// Input of <see cref="CreateAnomalyPipelineEnsemble"/>: the members of
/// <see cref="PipelineInputBase"/> plus the score combiner to use.
/// </summary>
public sealed class PipelineAnomalyInput : PipelineInputBase
/// <summary>
/// Selects how the individual model scores are combined. Unlike the regression input, this one
/// defaults to <see cref="ScoreCombiner.Average"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, ShortName = "combiner", HelpText = "The combiner used to combine the scores", SortOrder = 2)]
public ScoreCombiner ModelCombiner = ScoreCombiner.Average;
/// <summary>
/// Prepares the data of every model in <paramref name="input"/> on an empty data view built from
/// that model's input schema, and returns the results obtained for the first model. When
/// <see cref="InputBase.ValidatePipelines"/> is set, the prepared data of each later model is
/// compared with that of the first model through <see cref="CheckSamePipeline"/>.
/// </summary>
/// <param name="env">The host environment used to create data views and channels.</param>
/// <param name="input">The entry point input holding the models and the validation flag.</param>
/// <param name="startingData">Receives the empty data view created for the first model.</param>
/// <param name="transformedData">Receives the role-mapped data prepared by the first model.</param>
private static void GetPipeline(IHostEnvironment env, InputBase input, out IDataView startingData, out RoleMappedData transformedData)
{
    startingData = null;
    transformedData = null;

    DataViewSchema firstSchema = null;
    // Serialized form of the first model's pipeline; produced lazily, only once a comparison is needed.
    byte[][] referenceBytes = null;
    string[] referenceEntryNames = null;

    for (int index = 0; index < input.Models.Length; index++)
    {
        var current = input.Models[index];
        var emptyInput = new EmptyDataView(env, current.TransformModel.InputSchema);
        current.PrepareData(env, emptyInput, out RoleMappedData currentData, out IPredictor _);

        if (firstSchema == null)
        {
            // Only the first model may reach this branch; it defines the reference pipeline.
            env.Assert(index == 0);
            firstSchema = current.TransformModel.InputSchema;
            startingData = emptyInput;
            transformedData = currentData;
            continue;
        }

        if (!input.ValidatePipelines)
            continue;

        using (var ch = env.Start("Validating pipeline"))
        {
            if (referenceBytes == null)
            {
                ch.Assert(referenceEntryNames == null);
                SerializeRoleMappedData(env, ch, transformedData, out referenceBytes, out referenceEntryNames);
            }

            CheckSamePipeline(env, ch, currentData, referenceBytes, referenceEntryNames);
        }
    }
}
/// <summary>
/// Entry point that combines binary classifiers into a single ensemble model. It validates the
/// arguments, obtains the pipeline of the first model through GetPipeline, picks the output
/// combiner factory from <c>input.ModelCombiner</c>, lets an <see cref="EnsembleTrainer"/> combine
/// the models' predictors and wraps the result in a predictor model.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="input">The models to combine and the combiner selection; must contain at least one model.</param>
/// <returns>An output object whose predictor model is the created ensemble.</returns>
[TlcModule.EntryPoint(Name = "Models.BinaryEnsemble", Desc = "Combine binary classifiers into an ensemble", UserName = EnsembleTrainer.UserNameValue)]
public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, ClassifierInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("CombineModels");
    host.CheckValue(input, nameof(input));
    host.CheckNonEmpty(input.Models, nameof(input.Models));

    GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);

    var trainerArgs = new EnsembleTrainer.Arguments();
    var combinerKind = input.ModelCombiner;
    if (combinerKind == ClassifierCombiner.Median)
        trainerArgs.OutputCombiner = new MedianFactory();
    else if (combinerKind == ClassifierCombiner.Average)
        trainerArgs.OutputCombiner = new AverageFactory();
    else if (combinerKind == ClassifierCombiner.Vote)
        trainerArgs.OutputCombiner = new VotingFactory();
    else
        throw host.Except("Unknown combiner kind");

    var ensembleTrainer = new EnsembleTrainer(host, trainerArgs);
    var combined = ensembleTrainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing<float>));
    var model = new PredictorModelImpl(host, transformedData, startingData, combined);
    return new CommonOutputs.BinaryClassificationOutput { PredictorModel = model };
}
/// <summary>
/// Entry point that combines regression models into a single ensemble model. It validates the
/// arguments, obtains the pipeline of the first model through GetPipeline, picks the output
/// combiner factory from <c>input.ModelCombiner</c>, lets a <see cref="RegressionEnsembleTrainer"/>
/// combine the models' predictors and wraps the result in a predictor model.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="input">The models to combine and the combiner selection; must contain at least one model.</param>
/// <returns>An output object whose predictor model is the created ensemble.</returns>
[TlcModule.EntryPoint(Name = "Models.RegressionEnsemble", Desc = "Combine regression models into an ensemble", UserName = RegressionEnsembleTrainer.UserNameValue)]
public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvironment env, RegressionInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("CombineModels");
    host.CheckValue(input, nameof(input));
    host.CheckNonEmpty(input.Models, nameof(input.Models));

    GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);

    var trainerArgs = new RegressionEnsembleTrainer.Arguments();
    var combinerKind = input.ModelCombiner;
    if (combinerKind == ScoreCombiner.Median)
        trainerArgs.OutputCombiner = new MedianFactory();
    else if (combinerKind == ScoreCombiner.Average)
        trainerArgs.OutputCombiner = new AverageFactory();
    else
        throw host.Except("Unknown combiner kind");

    var ensembleTrainer = new RegressionEnsembleTrainer(host, trainerArgs);
    var combined = ensembleTrainer.CombineModels(input.Models.Select(pm => pm.Predictor as IPredictorProducing<float>));
    var model = new PredictorModelImpl(host, transformedData, startingData, combined);
    return new CommonOutputs.RegressionOutput { PredictorModel = model };
}
/// <summary>
/// Entry point that combines binary classification models into a pipeline ensemble. It validates
/// the arguments, instantiates the binary output combiner selected by <c>input.ModelCombiner</c>,
/// creates the ensemble over the given models and wraps it through CreatePipelineEnsemble.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="input">The models to combine and the combiner selection; must contain at least one model.</param>
/// <returns>An output object whose predictor model is the created ensemble.</returns>
[TlcModule.EntryPoint(Name = "Models.BinaryPipelineEnsemble", Desc = "Combine binary classification models into an ensemble")]
public static CommonOutputs.BinaryClassificationOutput CreateBinaryPipelineEnsemble(IHostEnvironment env, PipelineClassifierInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("CombineModels");
    host.CheckValue(input, nameof(input));
    host.CheckNonEmpty(input.Models, nameof(input.Models));

    IBinaryOutputCombiner combiner;
    var combinerKind = input.ModelCombiner;
    if (combinerKind == ClassifierCombiner.Median)
        combiner = new Median(host);
    else if (combinerKind == ClassifierCombiner.Average)
        combiner = new Average(host);
    else if (combinerKind == ClassifierCombiner.Vote)
        combiner = new Voting(host);
    else
        throw host.Except("Unknown combiner kind");

    var pipelineEnsemble = SchemaBindablePipelineEnsembleBase.Create(
        host, input.Models, combiner, AnnotationUtils.Const.ScoreColumnKind.BinaryClassification);
    return CreatePipelineEnsemble<CommonOutputs.BinaryClassificationOutput>(host, input.Models, pipelineEnsemble);
}
/// <summary>
/// Entry point that combines regression models into a pipeline ensemble. It validates the
/// arguments, instantiates the regression output combiner selected by <c>input.ModelCombiner</c>,
/// creates the ensemble over the given models and wraps it through CreatePipelineEnsemble.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="input">The models to combine and the combiner selection; must contain at least one model.</param>
/// <returns>An output object whose predictor model is the created ensemble.</returns>
[TlcModule.EntryPoint(Name = "Models.RegressionPipelineEnsemble", Desc = "Combine regression models into an ensemble")]
public static CommonOutputs.RegressionOutput CreateRegressionPipelineEnsemble(IHostEnvironment env, PipelineRegressionInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("CombineModels");
    host.CheckValue(input, nameof(input));
    host.CheckNonEmpty(input.Models, nameof(input.Models));

    IRegressionOutputCombiner combiner;
    var combinerKind = input.ModelCombiner;
    if (combinerKind == ScoreCombiner.Median)
        combiner = new Median(host);
    else if (combinerKind == ScoreCombiner.Average)
        combiner = new Average(host);
    else
        throw host.Except("Unknown combiner kind");

    var pipelineEnsemble = SchemaBindablePipelineEnsembleBase.Create(
        host, input.Models, combiner, AnnotationUtils.Const.ScoreColumnKind.Regression);
    return CreatePipelineEnsemble<CommonOutputs.RegressionOutput>(host, input.Models, pipelineEnsemble);
}
/// <summary>
/// Entry point that combines multiclass classifiers into a pipeline ensemble. It validates the
/// arguments, instantiates the vector-score combiner selected by <c>input.ModelCombiner</c> (the
/// median and average combiners are created with <c>Normalize = true</c>), creates the ensemble
/// over the given models and wraps it through CreatePipelineEnsemble.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="input">The models to combine and the combiner selection; must contain at least one model.</param>
/// <returns>An output object whose predictor model is the created ensemble.</returns>
[TlcModule.EntryPoint(Name = "Models.MultiClassPipelineEnsemble", Desc = "Combine multiclass classifiers into an ensemble")]
public static CommonOutputs.MulticlassClassificationOutput CreateMulticlassPipelineEnsemble(IHostEnvironment env, PipelineClassifierInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("CombineModels");
    host.CheckValue(input, nameof(input));
    host.CheckNonEmpty(input.Models, nameof(input.Models));

    IOutputCombiner<VBuffer<Single>> combiner;
    var combinerKind = input.ModelCombiner;
    if (combinerKind == ClassifierCombiner.Median)
        combiner = new MultiMedian(host, new MultiMedian.Options() { Normalize = true });
    else if (combinerKind == ClassifierCombiner.Average)
        combiner = new MultiAverage(host, new MultiAverage.Options() { Normalize = true });
    else if (combinerKind == ClassifierCombiner.Vote)
        combiner = new MultiVoting(host);
    else
        throw host.Except("Unknown combiner kind");

    var pipelineEnsemble = SchemaBindablePipelineEnsembleBase.Create(
        host, input.Models, combiner, AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification);
    return CreatePipelineEnsemble<CommonOutputs.MulticlassClassificationOutput>(host, input.Models, pipelineEnsemble);
}
/// <summary>
/// Entry point that combines anomaly detection models into a pipeline ensemble. It validates the
/// arguments, instantiates the regression output combiner selected by <c>input.ModelCombiner</c>,
/// creates the ensemble over the given models and wraps it through CreatePipelineEnsemble.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="input">The models to combine and the combiner selection; must contain at least one model.</param>
/// <returns>An output object whose predictor model is the created ensemble.</returns>
[TlcModule.EntryPoint(Name = "Models.AnomalyPipelineEnsemble", Desc = "Combine anomaly detection models into an ensemble")]
public static CommonOutputs.AnomalyDetectionOutput CreateAnomalyPipelineEnsemble(IHostEnvironment env, PipelineAnomalyInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("CombineModels");
    host.CheckValue(input, nameof(input));
    host.CheckNonEmpty(input.Models, nameof(input.Models));

    IRegressionOutputCombiner combiner;
    var combinerKind = input.ModelCombiner;
    if (combinerKind == ScoreCombiner.Median)
        combiner = new Median(host);
    else if (combinerKind == ScoreCombiner.Average)
        combiner = new Average(host);
    else
        throw host.Except("Unknown combiner kind");

    var pipelineEnsemble = SchemaBindablePipelineEnsembleBase.Create(
        host, input.Models, combiner, AnnotationUtils.Const.ScoreColumnKind.AnomalyDetection);
    return CreatePipelineEnsemble<CommonOutputs.AnomalyDetectionOutput>(host, input.Models, pipelineEnsemble);
}
/// <summary>
/// Wraps a pipeline ensemble in a predictor model and returns it inside a new
/// <typeparamref name="TOut"/>. The predictor model is built on an empty data view that uses the
/// input schema of the first predictor.
/// </summary>
/// <typeparam name="TOut">The trainer output type to create.</typeparam>
/// <param name="env">The host environment.</param>
/// <param name="predictors">The combined models; the first element supplies the input schema.</param>
/// <param name="ensemble">The ensemble to expose as the predictor.</param>
/// <returns>A new <typeparamref name="TOut"/> whose predictor model holds <paramref name="ensemble"/>.</returns>
private static TOut CreatePipelineEnsemble<TOut>(IHostEnvironment env, PredictorModel[] predictors, SchemaBindablePipelineEnsembleBase ensemble)
    where TOut : CommonOutputs.TrainerOutput, new()
{
    var emptyInput = new EmptyDataView(env, predictors[0].TransformModel.InputSchema);

    // The role mappings are specific to the individual predictors.
    var roleMapped = new RoleMappedData(emptyInput);

    var model = new PredictorModelImpl(env, roleMapped, emptyInput, ensemble);
    return new TOut { PredictorModel = model };
}
/// <summary>
/// This method takes a <see cref="RoleMappedData"/> as input, saves it as an in-memory <see cref="ZipArchive"/>
/// and returns two arrays indexed by the entries in the zip (ordered by entry name):
/// 1. An array of byte arrays, containing the byte sequences of each entry.
/// 2. An array of strings, containing the name of each entry.
/// This method is used for comparing pipelines. Its outputs can be passed to <see cref="CheckSamePipeline"/>
/// to check if this pipeline is identical to another pipeline.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="ch">The channel used for checks and passed on to the model saver; must not be null.</param>
/// <param name="data">The data whose pipeline is serialized; must not be null.</param>
/// <param name="dataSerialized">Receives the content of every zip entry.</param>
/// <param name="dataZipEntryNames">Receives the full name of every zip entry, in the same order as <paramref name="dataSerialized"/>.</param>
internal static void SerializeRoleMappedData(IHostEnvironment env, IChannel ch, RoleMappedData data,
    out byte[][] dataSerialized, out string[] dataZipEntryNames)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));

    using (var ms = new MemoryStream())
    {
        TrainUtils.SaveModel(env, ch, ms, null, data);

        // ZipArchive is disposable. The stream is left open because the enclosing using owns it.
        using (var zip = new ZipArchive(ms, ZipArchiveMode.Read, leaveOpen: true))
        {
            // Sort by name so that the result does not depend on the order entries were written in.
            var entries = zip.Entries.OrderBy(e => e.FullName).ToArray();
            dataSerialized = new byte[entries.Length][];
            dataZipEntryNames = new string[entries.Length];

            for (int i = 0; i < entries.Length; i++)
            {
                var entry = entries[i];
                dataZipEntryNames[i] = entry.FullName;

                var content = new byte[entry.Length];
                using (var s = entry.Open())
                {
                    // ReadExactly keeps reading until the whole buffer is filled, so a partial
                    // read from the decompression stream cannot leave the tail of the array unset.
                    s.ReadExactly(content, 0, content.Length);
                }

                dataSerialized[i] = content;
            }
        }
    }
}
/// <summary>
/// This method compares two pipelines to make sure they are identical. The first pipeline is passed
/// as a <see cref="RoleMappedData"/>, and the second as a jagged byte array and a string array. The jagged
/// byte array and the string array are obtained by calling <see cref="SerializeRoleMappedData"/> on the
/// second pipeline.
/// The comparison is done by saving <paramref name="dataToCompare"/> as an in-memory <see cref="ZipArchive"/>,
/// and for each entry in it (ordered by name), comparing its name and its byte sequence to the corresponding
/// entries in <paramref name="dataZipEntryNames"/> and <paramref name="dataSerialized"/>.
/// This method throws if the number of entries differs, or if for any of the entries the name/byte sequence
/// are not identical.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="ch">The channel used for checks and passed on to the model saver; must not be null.</param>
/// <param name="dataToCompare">The data whose pipeline is compared against the serialized one; must not be null.</param>
/// <param name="dataSerialized">The expected content of every zip entry; must not be null.</param>
/// <param name="dataZipEntryNames">The expected name of every zip entry; must have the same length as <paramref name="dataSerialized"/>.</param>
internal static void CheckSamePipeline(IHostEnvironment env, IChannel ch,
    RoleMappedData dataToCompare, byte[][] dataSerialized, string[] dataZipEntryNames)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(dataToCompare, nameof(dataToCompare));
    ch.CheckValue(dataSerialized, nameof(dataSerialized));
    ch.CheckValue(dataZipEntryNames, nameof(dataZipEntryNames));
    if (dataZipEntryNames.Length != dataSerialized.Length)
    {
        throw ch.ExceptParam(nameof(dataSerialized),
            $"The length of {nameof(dataSerialized)} must be equal to the length of {nameof(dataZipEntryNames)}");
    }

    const string errorMsg = "Models contain different pipelines, cannot ensemble them.";

    using (var ms = new MemoryStream())
    {
        // REVIEW: This can be done more efficiently by adding a custom type of repository that
        // doesn't actually save the data, but upon stream closure compares the results to the given repository
        // and then discards it. Currently, however, this cannot be done because ModelSaveContext does not use
        // an abstract class/interface, but rather the RepositoryWriter class.
        TrainUtils.SaveModel(env, ch, ms, null, dataToCompare);

        // ZipArchive is disposable. The stream is left open because the enclosing using owns it.
        using (var zip = new ZipArchive(ms, ZipArchiveMode.Read, leaveOpen: true))
        {
            var entries = zip.Entries.OrderBy(e => e.FullName).ToArray();
            // A different number of entries also means the pipelines differ.
            ch.Check(dataSerialized.Length == entries.Length, errorMsg);

            byte[] buffer = null;
            for (int i = 0; i < dataSerialized.Length; i++)
            {
                ch.Check(dataZipEntryNames[i] == entries[i].FullName, errorMsg);

                byte[] expected = dataSerialized[i];
                int len = expected.Length;

                // Keep at least one byte of room: the end-of-stream probe below reads a single byte
                // even when the expected entry is empty, and Stream.Read rejects a null buffer.
                int required = Math.Max(len, 1);
                if (Utils.Size(buffer) < required)
                    buffer = new byte[required];

                using (var s = entries[i].Open())
                {
                    // Stream.Read may return fewer bytes than requested without being at the end of
                    // the stream (decompression streams do this), so read until we have len bytes or
                    // the stream is exhausted.
                    int bytesRead = 0;
                    while (bytesRead < len)
                    {
                        int chunk = s.Read(buffer, bytesRead, len - bytesRead);
                        if (chunk == 0)
                            break;
                        bytesRead += chunk;
                    }

                    // Fewer bytes than expected: the entry is shorter than the reference.
                    ch.Check(bytesRead == len, errorMsg);

                    for (int j = 0; j < len; j++)
                        ch.Check(buffer[j] == expected[j], errorMsg);

                    // Any remaining byte means the entry is longer than the reference.
                    if (s.Read(buffer, 0, 1) > 0)
                        throw env.Except(errorMsg);
                }
            }
        }
    }
}