File: EntryPoints\SummarizePredictor.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.IO;
using System.Text;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
 
// Marks SummarizePredictor as an entry-point module; presumably this is how the entry-point
// catalog discovers its [TlcModule.EntryPoint] methods (confirm against EntryPointModuleAttribute).
[assembly: EntryPointModule(typeof(SummarizePredictor))]
 
namespace Microsoft.ML.EntryPoints
{
    /// <summary>
    /// Entry point and helpers that produce a tabular summary (and, when available, statistics)
    /// of a trained predictor.
    /// </summary>
    [BestFriend]
    internal static class SummarizePredictor
    {
        /// <summary>
        /// Common arguments for summarization entry points.
        /// </summary>
        public abstract class InputBase
        {
            [Argument(ArgumentType.Required, ShortName = "predictorModel", HelpText = "The predictor to summarize")]
            public PredictorModel PredictorModel;
        }

        /// <summary>
        /// Arguments for the <c>Models.Summarizer</c> entry point.
        /// </summary>
        public sealed class Input : InputBase
        {
        }

        /// <summary>
        /// Summarizes the predictor held by <paramref name="input"/>.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="input">The entry-point arguments. Must not be null.</param>
        /// <returns>
        /// An output whose <c>Summary</c> is a non-null view describing the predictor and whose
        /// <c>Stats</c> is a statistics view, or null when the predictor provides none.
        /// </returns>
        // NOTE(review): despite the description, this entry point accepts any predictor. The Desc
        // string is left unchanged because it is presumably part of the published entry-point
        // manifest; confirm before editing it.
        [TlcModule.EntryPoint(Name = "Models.Summarizer", Desc = "Summarize a linear regression predictor.")]
        public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, SummarizePredictor.Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            // The host name identifies this component; the entry point is not specific to linear regression.
            var host = env.Register(nameof(SummarizePredictor));
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // An empty view over the model's input schema is passed to PrepareData; presumably only
            // its schema is needed to recover the role-mapped schema and the predictor.
            RoleMappedData rmd;
            IPredictor predictor;
            var emptyData = new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema);
            input.PredictorModel.PrepareData(host, emptyData, out rmd, out predictor);

            var output = new CommonOutputs.SummaryOutput();
            output.Summary = GetSummaryAndStats(host, predictor, rmd.Schema, out output.Stats);
            return output;
        }

        /// <summary>
        /// Builds a summary data view for <paramref name="predictor"/> and, when the predictor
        /// exposes one, a statistics data view.
        /// </summary>
        /// <param name="env">The host environment used for checks and for building data views. Must not be null.</param>
        /// <param name="predictor">
        /// The predictor to summarize. Must not be null. Calibrator wrappers are stripped so the
        /// underlying model is summarized.
        /// </param>
        /// <param name="schema">The role-mapped schema handed to the predictor's summary methods.</param>
        /// <param name="stats">Receives the statistics view, or null when the predictor provides none.</param>
        /// <returns>A non-null view containing the summary.</returns>
        /// <remarks>
        /// Summary sources are tried in this order: <c>ICanGetSummaryAsIDataView</c>, then
        /// <c>ICanGetSummaryAsIRow</c>. If neither produces a summary (either because the predictor
        /// does not implement them or because they returned null), a one-row view is built from
        /// <c>ICanSaveSummary</c> text or, failing that, from the predictor's type name.
        /// </remarks>
        /// <exception cref="System.ArgumentNullException">
        /// <paramref name="env"/> or <paramref name="predictor"/> is null.
        /// </exception>
        [BestFriend]
        internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(predictor, nameof(predictor));

            // Unwrap any (possibly nested) calibrators so the summary describes the underlying model.
            var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
            while (calibrated != null)
            {
                predictor = calibrated.WeaklyTypedSubModel;
                calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
            }

            IDataView summary = null;
            stats = null;

            var dvGetter = predictor as ICanGetSummaryAsIDataView;
            if (dvGetter != null)
                summary = dvGetter.GetSummaryDataView(schema);

            var rowGetter = predictor as ICanGetSummaryAsIRow;
            if (rowGetter != null)
            {
                var row = rowGetter.GetSummaryIRowOrNull(schema);
                // Ambiguity only exists when both interfaces actually produced a summary.
                env.Check(summary == null || row == null,
                    "Predictor outputs two summary data views, don't know which one to choose");
                if (row != null)
                    summary = RowCursorUtils.RowAsDataView(env, row);
                var statsRow = rowGetter.GetStatsIRowOrNull(schema);
                if (statsRow != null)
                    stats = RowCursorUtils.RowAsDataView(env, statsRow);
            }

            // Fall back whenever no summary was produced. This includes predictors that implement a
            // summary interface but returned null (GetSummaryIRowOrNull may return null by contract).
            if (summary == null)
                summary = BuildFallbackSummary(env, predictor, schema);

            env.AssertValue(summary);
            return summary;
        }

        /// <summary>
        /// Builds a data view containing one row and one column describing <paramref name="predictor"/>.
        /// The column is "Summary" (the text written by <c>ICanSaveSummary.SaveSummary</c>) when the
        /// predictor supports it, otherwise "PredictorName" (the predictor's runtime type name).
        /// </summary>
        private static IDataView BuildFallbackSummary(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema)
        {
            var bldr = new ArrayDataViewBuilder(env);
            var summaryModel = predictor as ICanSaveSummary;
            if (summaryModel != null)
            {
                var sb = new StringBuilder();
                using (var sw = new StringWriter(sb))
                    summaryModel.SaveSummary(sw, schema);
                bldr.AddColumn("Summary", sb.ToString());
            }
            else
            {
                bldr.AddColumn("PredictorName", predictor.GetType().ToString());
            }
            return bldr.GetDataView();
        }
    }
}