File: Trainer\EnsembleModelParametersBase.cs
Web Access
Project: src\src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj (Microsoft.ML.Ensemble)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.Ensemble
{
    internal abstract class EnsembleModelParametersBase<TOutput> : ModelParametersBase<TOutput>,
        IPredictorProducing<TOutput>, ICanSaveInTextFormat, ICanSaveSummary
    {
        private const string SubPredictorFmt = "SubPredictor_{0:000}";
 
        private protected readonly FeatureSubsetModel<TOutput>[] Models;
        private protected readonly IOutputCombiner<TOutput> Combiner;
        private protected readonly Single[] Weights;
 
        private const uint VerOld = 0x00010002;
 
        private protected EnsembleModelParametersBase(IHostEnvironment env, string name, FeatureSubsetModel<TOutput>[] models,
            IOutputCombiner<TOutput> combiner, Single[] weights)
            : base(env, name)
        {
 
            Host.Check(Utils.Size(models) > 0, "Ensemble was created with no models.");
            Host.Check(weights == null || weights.Length == models.Length);
 
            Models = models;
            Combiner = combiner;
            Weights = weights;
        }
 
        private protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
            : base(env, name, ctx)
        {
            // *** Binary format ***
            // int: model count
            // int: weight count (0 or model count)
            // Float[]: weights
            // for each model:
            //   int: number of SelectedFeatures (in bits)
            //   byte[]: selected features (as many as needed for number of bits == (numSelectedFeatures + 7) / 8)
            //   int: number of Metric values
            //   for each Metric:
            //     Float: metric value
            //     int: metric name (id of the metric name in the string table)
            //     in version 0x0001x0002:
            //       bool: is the metric averaged
 
            int count = ctx.Reader.ReadInt32();
            Host.CheckDecode(count > 0);
 
            int weightCount = ctx.Reader.ReadInt32();
            Host.CheckDecode(weightCount == 0 || weightCount == count);
            Weights = ctx.Reader.ReadFloatArray(weightCount);
 
            Models = new FeatureSubsetModel<TOutput>[count];
            var ver = ctx.Header.ModelVerWritten;
            for (int i = 0; i < count; i++)
            {
                ctx.LoadModel<IPredictor, SignatureLoadModel>(Host, out IPredictor p, string.Format(SubPredictorFmt, i));
                var predictor = p as IPredictorProducing<TOutput>;
                Host.Check(predictor != null, "Inner predictor type not compatible with the ensemble type.");
                var features = ctx.Reader.ReadBitArray();
                int numMetrics = ctx.Reader.ReadInt32();
                Host.CheckDecode(numMetrics >= 0);
                var metrics = new KeyValuePair<string, double>[numMetrics];
                for (int j = 0; j < numMetrics; j++)
                {
                    var metricValue = ctx.Reader.ReadFloat();
                    var metricName = ctx.LoadStringOrNull();
                    if (ver == VerOld)
                        ctx.Reader.ReadBoolByte();
                    metrics[j] = new KeyValuePair<string, double>(metricName, metricValue);
                }
                Models[i] = new FeatureSubsetModel<TOutput>(predictor, features, metrics);
            }
            ctx.LoadModel<IOutputCombiner<TOutput>, SignatureLoadModel>(Host, out Combiner, @"Combiner");
        }
 
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
 
            // *** Binary format ***
            // int: model count
            // int: weight count (0 or model count)
            // Single[]: weights
            // for each model:
            //   int: number of SelectedFeatures (in bits)
            //   byte[]: selected features (as many as needed for number of bits == (numSelectedFeatures + 7) / 8)
            //   int: number of Metric values
            //   for each Metric:
            //     Single: metric value
            //     int: metric name (id of the metric name in the string table)
 
            ctx.Writer.Write(Models.Length);
            ctx.Writer.WriteSingleArray(Weights);
 
            // Save other streams.
            for (int i = 0; i < Models.Length; i++)
            {
                var model = Models[i];
                ctx.SaveModel(model.Predictor, string.Format(SubPredictorFmt, i));
                Host.AssertValueOrNull(model.SelectedFeatures);
                ctx.Writer.WriteBitArray(model.SelectedFeatures);
                Host.AssertValueOrNull(model.Metrics);
                int numMetrics = Utils.Size(model.Metrics);
                ctx.Writer.Write(numMetrics);
                for (int j = 0; j < numMetrics; j++)
                {
                    var metric = model.Metrics[j];
                    ctx.Writer.Write((Single)metric.Value);
                    ctx.SaveStringOrNull(metric.Key);
                }
            }
            ctx.SaveModel(Combiner, @"Combiner");
        }
 
        /// <summary>
        /// Output the INI model to a given writer
        /// </summary>
        void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
        {
            using (var ch = Host.Start("SaveAsText"))
            {
                for (int i = 0; i < Models.Length; i++)
                {
                    writer.WriteLine(";; Partition model {0}", i);
                    writer.WriteLine(";; Weight={0}", (Weights != null ? Weights[i] : 1));
                    PredictorUtils.SaveText(ch, Models[i].Predictor, schema, writer);
                }
            }
        }
 
        /// <summary>
        /// Saves the model summary
        /// </summary>
        void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
        {
            for (int i = 0; i < Models.Length; i++)
            {
                writer.WriteLine(";; Partition model {0}", i);
                writer.WriteLine(";; Weight={0}", (Weights != null ? Weights[i] : 1));
 
                // REVIEW: The featureName Collection names may vary for different base learners.
                // How do we get the right collection for the base learners?
                if (Models[i].Predictor is ICanSaveSummary summaryModel)
                    summaryModel.SaveSummary(writer, schema);
                else
                    writer.WriteLine("The Model {0} does not support saving summaries", Models[i].GetType().Name);
            }
        }
    }
}