// File: EntryPoints\PipelineEnsemble.cs
// Project: src\src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj (Microsoft.ML.Ensemble)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.Ensemble;
 
[assembly: EntryPointModule(typeof(PipelineEnsemble))]
 
namespace Microsoft.ML.Trainers.Ensemble
{
    /// <summary>
    /// Entry-point module that reports summaries and statistics for the member models of a
    /// pipeline ensemble predictor.
    /// </summary>
    internal static class PipelineEnsemble
    {
        /// <summary>
        /// Result of <see cref="Summarize"/>. Both arrays hold one entry per ensemble member, at the
        /// same index as that member in the ensemble's <c>PredictorModels</c> array.
        /// </summary>
        public sealed class SummaryOutput
        {
            [TlcModule.Output(Desc = "The summaries of the individual predictors")]
            public IDataView[] Summaries;

            [TlcModule.Output(Desc = "The model statistics of the individual predictors")]
            public IDataView[] Stats;
        }

        /// <summary>
        /// Produces a summary view and a statistics view for every member model of a pipeline ensemble.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">Entry-point arguments carrying the predictor model to summarize.</param>
        /// <returns>
        /// A <see cref="SummaryOutput"/> whose <c>Summaries</c> and <c>Stats</c> arrays are parallel to
        /// the ensemble's member models.
        /// </returns>
        /// <remarks>
        /// Any chain of calibrator wrappers around the predictor is peeled off before the type check.
        /// A predictor that is not a <see cref="SchemaBindablePipelineEnsembleBase"/> is rejected
        /// through <c>CheckUserArg</c>.
        /// </remarks>
        [TlcModule.EntryPoint(Name = "Models.EnsembleSummary", Desc = "Summarize a pipeline ensemble predictor.")]
        public static SummaryOutput Summarize(IHostEnvironment env, SummarizePredictor.Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PipelineEnsemblePredictor");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // No rows are needed here: only the bound predictor (and, per member below, the
            // role-mapped schema) are consumed, so an empty view over the input schema suffices.
            var model = input.PredictorModel;
            var emptyView = new EmptyDataView(host, model.TransformModel.InputSchema);
            model.PrepareData(host, emptyView, out RoleMappedData _, out IPredictor predictor);

            // Strip every calibrator layer so the underlying ensemble can be recognized.
            while (predictor is IWeaklyTypedCalibratedModelParameters wrapper)
            {
                predictor = wrapper.WeaklyTypedSubModel;
            }

            var ensemble = predictor as SchemaBindablePipelineEnsembleBase;
            host.CheckUserArg(ensemble != null, nameof(input.PredictorModel.Predictor), "Predictor is not a pipeline ensemble predictor");

            var members = ensemble.PredictorModels;
            int memberCount = members.Length;
            var summaries = new IDataView[memberCount];
            var stats = new IDataView[memberCount];

            for (int idx = 0; idx < memberCount; idx++)
            {
                var member = members[idx];
                var memberView = new EmptyDataView(host, member.TransformModel.InputSchema);
                member.PrepareData(host, memberView, out RoleMappedData memberData, out IPredictor memberPredictor);
                summaries[idx] = SummarizePredictor.GetSummaryAndStats(host, memberPredictor, memberData.Schema, out stats[idx]);
            }

            return new SummaryOutput { Summaries = summaries, Stats = stats };
        }
    }
}