// File: Standard\MulticlassClassification\OneVersusAllTrainer.cs
// Web Access
// Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
using Newtonsoft.Json.Linq;
 
[assembly: LoadableClass(OneVersusAllTrainer.Summary, typeof(OneVersusAllTrainer), typeof(OneVersusAllTrainer.Options),
    new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) },
    OneVersusAllTrainer.UserNameValue,
    OneVersusAllTrainer.LoadNameValue)]
 
[assembly: LoadableClass(typeof(OneVersusAllModelParameters), null, typeof(SignatureLoadModel),
    "OVA Executor",
    OneVersusAllModelParameters.LoaderSignature)]
 
[assembly: EntryPointModule(typeof(OneVersusAllModelParameters))]
namespace Microsoft.ML.Trainers
{
    using CR = RoleMappedSchema.ColumnRole;
    using TDistPredictor = IDistPredictorProducing<float, float>;
    using TScalarPredictor = IPredictorProducing<float>;
    using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
    /// <summary>
    /// The <see cref="IEstimator{TTransformer}"/> for training a one-versus-all multi-class classifier that uses the specified binary classifier.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    /// To create this trainer, use [OneVersusAll](xref:Microsoft.ML.StandardTrainersCatalog.OneVersusAll``1(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,Microsoft.ML.Trainers.ITrainerEstimator{Microsoft.ML.Data.BinaryPredictionTransformer{``0},``0},System.String,System.Boolean,Microsoft.ML.IEstimator{Microsoft.ML.ISingleFeaturePredictionTransformer{Microsoft.ML.Calibrators.ICalibrator}},System.Int32,System.Boolean)).
    ///
    /// [!include[io](~/../docs/samples/docs/api-reference/io-columns-multiclass-classification.md)]
    ///
    /// ### Trainer Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Machine learning task | Multiclass classification |
    /// | Is normalization required? | Depends on the underlying binary classifier |
    /// | Is caching required? | Yes |
    /// | Required NuGet in addition to Microsoft.ML | None |
    /// | Exportable to ONNX | Yes |
    ///
    /// ### Training Algorithm Details
    /// In one-versus-all (OVA) strategy, a binary classification algorithm is used to train one classifier for each class,
    /// which distinguishes that class from all other classes. Prediction is then performed by running
    /// these binary classifiers and choosing the prediction with the highest confidence score.
    /// This algorithm can be used with any of the binary classifiers in ML.NET. A few binary classifiers
    /// already have implementation for multi-class problems, thus users can choose either one depending on the context.
    /// The OVA version of a binary classifier, such as wrapping a <xref:Microsoft.ML.Trainers.LightGbm.LightGbmBinaryTrainer>,
    /// can be different from <xref:Microsoft.ML.Trainers.LightGbm.LightGbmMulticlassTrainer>, which develops a multi-class classifier directly.
    /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always
    /// request caching, as it will be performing multiple passes over the data set.
    /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it.
    ///
    /// This can allow you to exploit trainers that do not naturally have a
    /// multiclass option, for example, using the <xref:Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer>
    /// to solve a multiclass problem.
    /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases
    /// where the trainer has a multiclass option, but using it directly is not
    /// practical due to, usually, memory constraints. For example, while a multiclass
    /// logistic regression is a more principled way to solve a multiclass problem, it
    /// requires that the trainer store a lot more intermediate state in the form of
    /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one
    /// as would be needed for a one-versus-all classification model.
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]>
    /// </format>
    /// </remarks>
    /// <seealso cref="StandardTrainersCatalog.OneVersusAll{TModel}(MulticlassClassificationCatalog.MulticlassClassificationTrainers, ITrainerEstimator{BinaryPredictionTransformer{TModel}, TModel}, string, bool, IEstimator{ISingleFeaturePredictionTransformer{ICalibrator}}, int, bool)" />
    public sealed class OneVersusAllTrainer : MetaMulticlassTrainer<MulticlassPredictionTransformer<OneVersusAllModelParameters>, OneVersusAllModelParameters>
    {
        internal const string LoadNameValue = "OVA";
        internal const string UserNameValue = "One-vs-All";
        internal const string Summary = "In this strategy, a binary classification algorithm is used to train one classifier for each class, "
            + "which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, "
            + "and choosing the prediction with the highest confidence score.";

        private readonly Options _options;

        /// <summary>
        /// Options passed to <see cref="OneVersusAllTrainer"/>
        /// </summary>
        internal sealed class Options : OptionsBase
        {
            /// <summary>
            /// Whether to use probabilities (vs. raw outputs) to identify top-score category.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Use probability or margins to determine max", ShortName = "useprob")]
            [TGUI(Label = "Use Probability", Description = "Use probabilities (vs. raw outputs) to identify top-score category")]
            public bool UseProbabilities = true;
        }

        /// <summary>
        /// Constructs a <see cref="OneVersusAllTrainer"/> trainer supplying a <see cref="Options"/>.
        /// </summary>
        /// <param name="env">The private <see cref="IHostEnvironment"/> for this estimator.</param>
        /// <param name="options">The legacy <see cref="Options"/></param>
        internal OneVersusAllTrainer(IHostEnvironment env, Options options)
            : base(env, options, LoadNameValue)
        {
            _options = options;
        }

        /// <summary>
        /// Initializes a new instance of <see cref="OneVersusAllTrainer"/>.
        /// </summary>
        /// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
        /// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
        /// <param name="calibrator">The calibrator. If a calibrator is not provided, it will default to <see cref="PlattCalibratorTrainer"/></param>
        /// <param name="labelColumnName">The name of the label column.</param>
        /// <param name="imputeMissingLabelsAsNegative">If true will treat missing labels as negative labels.</param>
        /// <param name="maximumCalibrationExampleCount">Number of instances to train the calibrator.</param>
        /// <param name="useProbabilities">Use probabilities (vs. raw outputs) to identify top-score category.</param>
        internal OneVersusAllTrainer(IHostEnvironment env,
            TScalarTrainer binaryEstimator,
            string labelColumnName = DefaultColumnNames.Label,
            bool imputeMissingLabelsAsNegative = false,
            ICalibratorTrainer calibrator = null,
            int maximumCalibrationExampleCount = 1000000000,
            bool useProbabilities = true)
         : base(env,
               new Options
               {
                   ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative,
                   MaxCalibrationExamples = maximumCalibrationExampleCount,
               },
               LoadNameValue, labelColumnName, binaryEstimator, calibrator)
        {
            Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null.");
            _options = (Options)Args;
            _options.UseProbabilities = useProbabilities;
        }

        private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count)
        {
            // Train one binary model per class and bundle them into a single OVA predictor.
            var predictors = new TScalarPredictor[count];
            for (int i = 0; i < predictors.Length; i++)
            {
                ch.Info($"Training learner {i}");
                predictors[i] = TrainOne(ch, Trainer, data, i).Model;
            }
            return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors);
        }

        // Trains the binary classifier for class index 'cls' on a view where the label
        // is remapped to "does this example belong to class cls?". When probabilities
        // are requested and the trained model cannot produce them itself, a calibrator
        // is trained over the same data.
        private ISingleFeaturePredictionTransformer<TScalarPredictor> TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls)
        {
            var view = MapLabels(data, cls);

            string trainerLabel = data.Schema.Label.Value.Name;

            // REVIEW: In principle we could support validation sets and the like via the train context, but
            // this is currently unsupported.
            var transformer = trainer.Fit(view);

            if (_options.UseProbabilities)
            {
                var calibratedModel = transformer.Model as TDistPredictor;

                // REVIEW: restoring the RoleMappedData, as much as we can.
                // not having the weight column on the data passed to the TrainCalibrator should be addressed.
                var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName);

                // Only run calibration when the model cannot already emit probabilities.
                if (calibratedModel == null)
                    calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor;

                Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface");
                return new BinaryPredictionTransformer<TScalarPredictor>(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName);
            }

            return new BinaryPredictionTransformer<TScalarPredictor>(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);
        }

        // Produces a view of the data whose label is true iff the example's key-typed
        // label equals class index 'cls' (key values are 1-based).
        private IDataView MapLabels(RoleMappedData data, int cls)
        {
            var label = data.Schema.Label.Value;
            Host.Assert(!label.IsHidden);
            // NOTE(review): the assert admits Single/Double label columns, but only
            // key-typed labels are handled below; floating-point labels fall through
            // to the not-supported exception — confirm this is intended.
            Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double);

            if (label.Type.GetKeyCount() > 0)
            {
                // Key values are 1-based.
                uint key = (uint)(cls + 1);
                return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => key == val, data);
            }

            throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {label.Type.RawType}");
        }

        /// <summary> Trains a <see cref="MulticlassPredictionTransformer{OneVersusAllModelParameters}"/> model.</summary>
        /// <param name="input">The input data.</param>
        /// <returns>A <see cref="MulticlassPredictionTransformer{OneVersusAllModelParameters}"/> model.</returns>
        public override MulticlassPredictionTransformer<OneVersusAllModelParameters> Fit(IDataView input)
        {
            var roles = new KeyValuePair<CR, string>[1];
            roles[0] = new KeyValuePair<CR, string>(new CR(DefaultColumnNames.Label), LabelColumn.Name);
            var td = new RoleMappedData(input, roles);

            td.CheckMulticlassLabel(out var numClasses);

            var predictors = new TScalarPredictor[numClasses];
            string featureColumn = null;

            using (var ch = Host.Start("Fitting"))
            {
                for (int i = 0; i < predictors.Length; i++)
                {
                    ch.Info($"Training learner {i}");

                    // Train each per-class learner exactly once. The previous code
                    // trained learner 0 twice: once only to read the feature column
                    // name, and a second time for the stored model.
                    var transformer = TrainOne(ch, Trainer, td, i);
                    if (i == 0)
                        featureColumn = transformer.FeatureColumnName;

                    predictors[i] = transformer.Model;
                }
            }

            return new MulticlassPredictionTransformer<OneVersusAllModelParameters>(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name);
        }
    }
 
    /// <summary>
    /// Model parameters for <see cref="OneVersusAllTrainer"/>.
    /// </summary>
    public sealed class OneVersusAllModelParameters :
        ModelParametersBase<VBuffer<float>>,
        IValueMapper,
        ICanSaveInSourceCode,
        ICanSaveInTextFormat,
        ISingleCanSavePfa,
        ISingleCanSaveOnnx
    {
        internal const string LoaderSignature = "OVAExec";
        internal const string RegistrationName = "OVAPredictor";

        // Version info recorded in / validated against the serialized model stream.
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "TLC OVA ",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName);
        }

        // Format string for the per-class sub-model stream names, e.g. "SubPredictor_000".
        private const string SubPredictorFmt = "SubPredictor_{0:000}";

        // Active implementation strategy (raw / probability-normalized / softmax output).
        private readonly ImplBase _impl;

        /// <summary>
        /// Retrieves the model parameters of the per-class sub-predictors.
        /// </summary>
        internal ImmutableArray<object> SubModelParameters => _impl.Predictors.Cast<object>().ToImmutableArray();

        /// <summary>
        /// The type of the prediction task.
        /// </summary>
        private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;

        /// <summary>
        /// Function applied to output of predictors. Assume that we have n predictors (one per class) and for the i-th predictor,
        /// y_i is its raw output and p_i is its probability output. Note that not all predictors are able to produce probability output.
        /// <para>
        /// <see cref="Raw"/>: output the result of predictors without post-processing. Output is [y_1, ..., y_n].
        /// <see cref="ProbabilityNormalization"/>: fetch the probability output of each class from the provided predictors and make sure the sum of class probabilities is one.
        /// Output is [p_1 / (p_1 + ... + p_n), ..., p_n / (p_1 + ... + p_n)].
        /// <see cref="Softmax"/>: Generate probability by feeding raw outputs to softmax function. Output is [z_1, ..., z_n], where z_i is exp(y_i) / (exp(y_1) + ... + exp(y_n)).
        /// </para>
        /// </summary>
        [BestFriend]
        internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 };

        // Output column type: one float score per class (vector length = number of sub-predictors).
        private DataViewType DistType { get; }

        // PFA export is delegated to the active implementation strategy.
        bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa;
 
        /// <summary>
        /// Builds the model parameters with the implementation strategy selected by
        /// <paramref name="outputFormula"/>. Falls back to raw output when probability
        /// normalization is requested but the predictors cannot produce probabilities.
        /// </summary>
        [BestFriend]
        internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors)
        {
            ImplBase impl;

            using (var ch = host.Start("Creating OVA predictor"))
            {
                // Softmax post-processing works on raw scores; no per-predictor
                // probability support is needed.
                if (outputFormula == OutputFormula.Softmax)
                    return new OneVersusAllModelParameters(host, new ImplSoftmax(predictors));

                // Probability normalization requires the predictors to expose a
                // probability (distribution) output of type float.
                IValueMapperDist probabilityMapper = null;
                if (outputFormula == OutputFormula.ProbabilityNormalization)
                {
                    probabilityMapper = predictors[0] as IValueMapperDist;
                    bool supportsProbability = probabilityMapper != null
                        && probabilityMapper.OutputType == NumberDataViewType.Single
                        && probabilityMapper.DistType == NumberDataViewType.Single;

                    if (!supportsProbability)
                    {
                        ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities.");
                        probabilityMapper = null;
                    }
                }

                if (probabilityMapper == null)
                {
                    // Either raw output was requested or probabilities are unavailable.
                    impl = new ImplRaw(predictors);
                }
                else
                {
                    var dists = new IValueMapperDist[predictors.Length];
                    for (int i = 0; i < predictors.Length; ++i)
                        dists[i] = (IValueMapperDist)predictors[i];
                    impl = new ImplDist(dists);
                }
            }

            return new OneVersusAllModelParameters(host, impl);
        }
 
        /// <summary>
        /// Builds the model parameters, mapping the boolean probability flag onto the
        /// equivalent <see cref="OutputFormula"/>.
        /// </summary>
        [BestFriend]
        internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors)
        {
            if (useProbability)
                return Create(host, OutputFormula.ProbabilityNormalization, predictors);
            return Create(host, OutputFormula.Raw, predictors);
        }
 
        /// <summary>
        /// Create a <see cref="OneVersusAllModelParameters"/> from an array of predictors.
        /// Probability normalization is the default output post-processing.
        /// </summary>
        [BestFriend]
        internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors)
        {
            Contracts.CheckValue(host, nameof(host));
            host.CheckNonEmpty(predictors, nameof(predictors));
            return Create(host, OutputFormula.ProbabilityNormalization, predictors);
        }
 
        // Wraps a concrete implementation strategy as model parameters. The output
        // type is a float vector with one slot per sub-predictor (one per class).
        private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl)
                : base(env, RegistrationName)
        {
            Host.AssertValue(impl, nameof(impl));
            Host.Assert(Utils.Size(impl.Predictors) > 0);

            _impl = impl;
            DistType = new VectorDataViewType(NumberDataViewType.Single, impl.Predictors.Length);
        }
 
        // Deserialization constructor: reads the output formula and per-class
        // sub-predictors from the model stream.
        private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx)
                : base(env, RegistrationName, ctx)
        {
            // *** Binary format ***
            // byte: OutputFormula as byte
            // int: predictor count
            OutputFormula outputFormula = (OutputFormula)ctx.Reader.ReadByte();
            int len = ctx.Reader.ReadInt32();
            Host.CheckDecode(len > 0);

            if (outputFormula == OutputFormula.Raw)
            {
                var predictors = new TScalarPredictor[len];
                LoadPredictors(Host, predictors, ctx);
                _impl = new ImplRaw(predictors);
            }
            else if (outputFormula == OutputFormula.ProbabilityNormalization)
            {
                // Probability normalization requires distribution-capable predictors.
                var predictors = new IValueMapperDist[len];
                LoadPredictors(Host, predictors, ctx);
                _impl = new ImplDist(predictors);
            }
            else if (outputFormula == OutputFormula.Softmax)
            {
                var predictors = new TScalarPredictor[len];
                LoadPredictors(Host, predictors, ctx);
                _impl = new ImplSoftmax(predictors);
            }
            else
            {
                // Previously an unrecognized formula byte left _impl null, causing a
                // NullReferenceException below; report a proper decode failure instead.
                throw Host.ExceptDecode($"Invalid OutputFormula value: {(byte)outputFormula}");
            }

            DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length);
        }
 
        /// <summary>
        /// Loads a <see cref="OneVersusAllModelParameters"/> from a serialized model;
        /// the entry point registered for <see cref="SignatureLoadModel"/>.
        /// </summary>
        internal static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            // Validate the stream against this loader's version info before reading.
            ctx.CheckAtModel(GetVersionInfo());
            return new OneVersusAllModelParameters(env, ctx);
        }
 
        // Fills 'predictors' by loading one sub-model per slot from the model archive,
        // using the SubPredictor_{NNN} stream naming convention.
        private static void LoadPredictors<TPredictor>(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx)
            where TPredictor : class
        {
            for (int slot = 0; slot < predictors.Length; slot++)
                ctx.LoadModel<TPredictor, SignatureLoadModel>(env, out predictors[slot], string.Format(SubPredictorFmt, slot));
        }
 
        // Serializes the output formula, the predictor count, and one sub-stream per
        // sub-predictor (mirrors the deserialization constructor).
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());

            var subPredictors = _impl.Predictors;

            // *** Binary format ***
            // byte: _impl.OutputFormula as byte
            // int: predictor count
            ctx.Writer.WriteBytesNoCount(new[] { (byte)_impl.OutputFormula }, 1);
            ctx.Writer.Write(subPredictors.Length);

            // Each sub-predictor is written to its own sub-stream of the archive.
            for (int i = 0; i < subPredictors.Length; i++)
                ctx.SaveModel(subPredictors[i], string.Format(SubPredictorFmt, i));
        }
 
        // PFA export: validate arguments, then delegate to the active implementation.
        JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
        {
            Host.CheckValue(input, nameof(input));
            Host.CheckValue(ctx, nameof(ctx));
            return _impl.SaveAsPfa(ctx, input);
        }
 
        // Feature vector type expected by the sub-predictors.
        DataViewType IValueMapper.InputType => _impl.InputType;

        // Output is a float vector holding one score per class.
        DataViewType IValueMapper.OutputType => DistType;

        ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
        {
            Host.Check(typeof(TIn) == typeof(VBuffer<float>));
            Host.Check(typeof(TOut) == typeof(VBuffer<float>));

            // The implementation always produces a VBuffer<float> -> VBuffer<float>
            // mapper; the cast through Delegate reinterprets it as the requested
            // generic delegate type (checked above).
            return (ValueMapper<TIn, TOut>)(Delegate)_impl.GetMapper();
        }
 
        // Emits the ensemble as C#-like source: each per-class scorer is written in
        // its own scope and its "output" local is captured into outputs[i].
        void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema)
        {
            Host.CheckValue(writer, nameof(writer));
            Host.CheckValue(schema, nameof(schema));

            var subPredictors = _impl.Predictors;
            writer.WriteLine("double[] outputs = new double[{0}];", subPredictors.Length);

            for (int classIndex = 0; classIndex < subPredictors.Length; classIndex++)
            {
                // Every sub-predictor must itself support source-code export.
                var codeSaver = subPredictors[classIndex] as ICanSaveInSourceCode;
                Host.Check(codeSaver != null, "Saving in code is not supported.");

                writer.WriteLine("{");
                codeSaver.SaveAsCode(writer, schema);
                writer.WriteLine("outputs[{0}] = output;", classIndex);
                writer.WriteLine("}");
            }
        }
 
        // Emits a human-readable dump with one region per per-class classifier.
        void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
        {
            Host.CheckValue(writer, nameof(writer));
            Host.CheckValue(schema, nameof(schema));

            var subPredictors = _impl.Predictors;

            for (int classIndex = 0; classIndex < subPredictors.Length; classIndex++)
            {
                // Every sub-predictor must itself support text export.
                var textSaver = subPredictors[classIndex] as ICanSaveInTextFormat;
                Host.Check(textSaver != null, "Saving in text is not supported.");

                writer.WriteLine("#region: class-{0} classifier", classIndex);
                textSaver.SaveAsText(writer, schema);

                writer.WriteLine("#endregion: class-{0} classifier", classIndex);
                writer.WriteLine();
            }
        }
 
        // ONNX export capability and execution are delegated to the active implementation.
        bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _impl.CanSaveOnnx(ctx);

        bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) => _impl.SaveAsOnnx(ctx, outputNames, featureColumn);
 
        // Strategy base for the per-formula behaviors (raw scores, normalized
        // probabilities, softmax) shared by OneVersusAllModelParameters.
        private abstract class ImplBase : ISingleCanSavePfa, ISingleCanSaveOnnx
        {
            // Post-processing formula applied to the sub-predictor outputs.
            public OutputFormula OutputFormula;
            // Feature vector type accepted by all sub-predictors.
            public abstract DataViewType InputType { get; }
            // One binary sub-predictor per class.
            public abstract IValueMapper[] Predictors { get; }
            public abstract bool CanSavePfa { get; }
            // Produces the delegate mapping a feature vector to the per-class score vector.
            public abstract ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper();
            public abstract JToken SaveAsPfa(BoundPfaContext ctx, JToken input);

            // ONNX export is possible only if every sub-predictor supports it.
            public bool CanSaveOnnx(OnnxContext ctx) => Predictors.All(pred => (pred as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true);

            public abstract bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn);

            // Validates that a sub-predictor is a float-output IValueMapper over a
            // float vector input, and reconciles its input size with the running
            // common inputType (a Size of 0 acts as a wildcard deferring to the other).
            protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType)
            {
                Contracts.AssertValueOrNull(mapper);
                Contracts.AssertValueOrNull(inputType);

                if (mapper == null)
                    return false;
                if (mapper.OutputType != NumberDataViewType.Single)
                    return false;
                if (!(mapper.InputType is VectorDataViewType mapperVectorType) || mapperVectorType.ItemType != NumberDataViewType.Single)
                    return false;
                if (inputType == null)
                    inputType = mapperVectorType;
                else if (inputType.Size != mapperVectorType.Size)
                {
                    if (inputType.Size == 0)
                        inputType = mapperVectorType;
                    else if (mapperVectorType.Size != 0)
                        return false;
                }
                return true;
            }

            // Emits the ONNX subgraph for each sub-predictor and returns one output
            // variable name per class: the zero-clipped probability when clipToZero is
            // set, otherwise the raw score.
            public string[] SaveAsOnnxPreProcess(OnnxContext ctx, string featureColumn, bool clipToZero)
            {
                string[] outputs = new string[Predictors.Length];

                string[] localOutputNames = { DefaultColumnNames.PredictedLabel, DefaultColumnNames.Score, DefaultColumnNames.Probability };

                for (int i = 0; i < Predictors.Length; i++)
                {
                    var predictorOutputNames = new string[localOutputNames.Length];

                    predictorOutputNames[0] = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, $"{DefaultColumnNames.PredictedLabel}_{i}", true);
                    predictorOutputNames[1] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"{DefaultColumnNames.Score}_{i}", true);
                    predictorOutputNames[2] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"{DefaultColumnNames.Probability}_{i}", true);

                    // Clip input is the per-class probability output (index 2).
                    string clipInput = predictorOutputNames[2];

                    var pred = Predictors[i] as ISingleCanSaveOnnx;
                    Contracts.AssertValue(pred);
                    pred.SaveAsOnnx(ctx, predictorOutputNames, featureColumn);

                    if (clipToZero)
                    {
                        var clipOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"ClipOutput_{i}", true);
                        outputs[i] = clipOutput;

                        // Clip(prob, min=0) guards against negative probability values.
                        string opType = "Clip";
                        var zeroVar = ctx.AddInitializer(0.0f, "Zero");
                        var clipNode = ctx.CreateNode(opType, new[] { clipInput, zeroVar }, new[] { outputs[i] }, ctx.GetNodeName(opType), "");
                    }
                    else
                        outputs[i] = predictorOutputNames[1];
                }
                return outputs;
            }

            // Appends the shared tail of the ONNX graph:
            //   PredictedLabel = Cast(ArgMax(scores) + 1)  (key labels are 1-based)
            //   Score          = Max node over the score input
            // NOTE(review): with a single input, ONNX Max is effectively a pass-through
            // of the score vector — confirm this matches the intended Score output.
            public void SaveAsOnnxPostProcess(OnnxContext ctx, string inputName, string[] outputNames)
            {
                Contracts.Assert(outputNames.Length >= 2);

                string opType;
                opType = "ArgMax";
                var argMaxOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ArgMaxOutput");
                var argMaxNode = ctx.CreateNode(opType, inputName, argMaxOutput, ctx.GetNodeName(opType), "");
                argMaxNode.AddAttribute("keepdims", 1);
                argMaxNode.AddAttribute("axis", 1);

                // Shift the zero-based argmax index to a 1-based key value.
                opType = "Add";
                var one = ctx.AddInitializer(1);
                var addOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "AddOutput");
                var addNode = ctx.CreateNode(opType, new[] { argMaxOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), "");

                opType = "Cast";
                var castToUint32Node = ctx.CreateNode(opType, addOutput, outputNames[0], ctx.GetNodeName(opType), "");
                var t2 = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
                castToUint32Node.AddAttribute("to", t2);

                opType = "Max";
                ctx.CreateNode(opType, inputName, outputNames[1], ctx.GetNodeName(opType), "");
            }
        }
 
        private sealed class ImplRaw : ImplBase
        {
            public override DataViewType InputType { get; }
            public override IValueMapper[] Predictors { get; }
            public override bool CanSavePfa { get; }
 
            internal ImplRaw(TScalarPredictor[] predictors)
            {
                Contracts.CheckNonEmpty(predictors, nameof(predictors));
 
                Predictors = new IValueMapper[predictors.Length];
                VectorDataViewType inputType = null;
                for (int i = 0; i < predictors.Length; i++)
                {
                    var vm = predictors[i] as IValueMapper;
                    Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface");
                    Predictors[i] = vm;
                }
                CanSavePfa = Predictors.All(m => (m as ISingleCanSavePfa)?.CanSavePfa == true);
                Contracts.AssertValue(inputType);
                InputType = inputType;
                OutputFormula = OutputFormula.Raw;
            }
 
            public override ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper()
            {
                var maps = new ValueMapper<VBuffer<float>, float>[Predictors.Length];
                for (int i = 0; i < Predictors.Length; i++)
                    maps[i] = Predictors[i].GetMapper<VBuffer<float>, float>();
 
                var buffer = new float[maps.Length];
                return
                    (in VBuffer<float> src, ref VBuffer<float> dst) =>
                    {
                        int inputSize = InputType.GetVectorSize();
                        if (inputSize > 0)
                            Contracts.Check(src.Length == inputSize);
 
                        var tmp = src;
                        Parallel.For(0, maps.Length, i => maps[i](in tmp, ref buffer[i]));
 
                        var editor = VBufferEditor.Create(ref dst, maps.Length);
                        buffer.CopyTo(editor.Values);
                        dst = editor.Commit();
                    };
            }
 
            /// <summary>
            /// Exports the raw-score OVA combination to PFA: one declared variable per per-class
            /// score, gathered into a PFA array-of-double literal.
            /// </summary>
            public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                Contracts.CheckValue(input, nameof(input));
                Contracts.Assert(CanSavePfa);

                var perClassScores = new JArray();
                foreach (var predictor in Predictors)
                {
                    var pfaPredictor = (ISingleCanSavePfa)predictor;
                    Contracts.Assert(pfaPredictor.CanSavePfa);
                    // DeclareVar registers the score expression in the context and yields a reference token.
                    perClassScores.Add(ctx.DeclareVar(null, pfaPredictor.SaveAsPfa(ctx, input)));
                }
                // AddReturn is an extension method that tolerates a null receiver and allocates the JObject.
                JObject result = null;
                return result.AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)).AddReturn("new", perClassScores);
            }
 
            /// <summary>
            /// Exports the raw-score OVA combination to ONNX: per-class scalar outputs from the
            /// underlying predictors are concatenated into a single score vector.
            /// </summary>
            public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
            {
                // The emitted graph requires at least opset 9.
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                // 'false' -> raw (uncalibrated) outputs from each binary predictor.
                var perClassOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false);

                // Concatenate the per-class scalars along axis 1 to form the score vector.
                const string concatOp = "Concat";
                var concatType = new VectorDataViewType(NumberDataViewType.Single, perClassOutputs.Length);
                var concatOutput = ctx.AddIntermediateVariable(concatType, "ConcatOutputRaw");
                var concatNode = ctx.CreateNode(concatOp, perClassOutputs, new[] { concatOutput }, ctx.GetNodeName(concatOp), "");
                concatNode.AddAttribute("axis", 1);

                base.SaveAsOnnxPostProcess(ctx, concatOutput, outputNames);

                return true;
            }
        }
 
        /// <summary>
        /// OVA combiner over calibrated binary predictors: each class's probability comes from
        /// its predictor's distribution mapper, and the resulting vector is normalized so the
        /// probabilities sum to one.
        /// </summary>
        private sealed class ImplDist : ImplBase
        {
            private readonly IValueMapperDist[] _mappers;
            public override DataViewType InputType { get; }
            public override IValueMapper[] Predictors => _mappers;
            public override bool CanSavePfa { get; }

            internal ImplDist(IValueMapperDist[] predictors)
            {
                Contracts.Check(Utils.Size(predictors) > 0);

                _mappers = new IValueMapperDist[predictors.Length];
                // All predictors must agree on a single vector input type; IsValid fills it in
                // from the first predictor and validates the rest against it.
                VectorDataViewType inputType = null;
                for (int i = 0; i < predictors.Length; i++)
                {
                    var vm = predictors[i];
                    Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface");
                    _mappers[i] = vm;
                }
                // PFA export is only supported when every underlying predictor supports it.
                CanSavePfa = Predictors.All(m => (m as IDistCanSavePfa)?.CanSavePfa == true);
                Contracts.AssertValue(inputType);
                InputType = inputType;
                OutputFormula = OutputFormula.ProbabilityNormalization;
            }

            // Beyond the base checks, the calibrated (distribution) output must be a float scalar.
            private bool IsValid(IValueMapperDist mapper, ref VectorDataViewType inputType)
            {
                return base.IsValid(mapper, ref inputType) && mapper.DistType == NumberDataViewType.Single;
            }

            /// <summary>
            /// Each predictor produces a probability of a class. All classes' probabilities are normalized so that
            /// their sum is one.
            /// </summary>
            public override ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper()
            {
                // One (score, probability) mapper per class.
                var maps = new ValueMapper<VBuffer<float>, float, float>[Predictors.Length];
                for (int i = 0; i < Predictors.Length; i++)
                    maps[i] = _mappers[i].GetMapper<VBuffer<float>, float, float>();

                // Scratch array reused across invocations of the returned delegate.
                var buffer = new float[maps.Length];
                return
                    (in VBuffer<float> src, ref VBuffer<float> dst) =>
                    {
                        // When the trained input type has a known size, enforce it on the incoming vector.
                        int inputSize = InputType.GetVectorSize();
                        if (inputSize > 0)
                            Contracts.Check(src.Length == inputSize);

                        // Copy to a local so the Parallel.For lambda can capture the 'in' parameter.
                        var tmp = src;
                        Parallel.For(0, maps.Length,
                            i =>
                            {
                                float score = 0;
                                // buffer[i] receives the probability of the i-th class;
                                // the raw score is discarded.
                                maps[i](in tmp, ref score, ref buffer[i]);
                            });

                        // Rescale the per-class probabilities so they sum to one.
                        NormalizeSumToOne(buffer, maps.Length);

                        var editor = VBufferEditor.Create(ref dst, maps.Length);
                        buffer.CopyTo(editor.Values);
                        dst = editor.Commit();
                    };
            }

            // Clamps negative entries to zero, then divides by the total so the entries sum to one.
            // NaN entries are excluded from the sum and left as NaN in the output.
            private void NormalizeSumToOne(float[] output, int count)
            {
                // Clamp to zero and normalize.
                Double sum = 0;
                for (int i = 0; i < count; i++)
                {
                    var value = output[i];
                    if (float.IsNaN(value))
                        continue;

                    if (value >= 0)
                        sum += value;
                    else
                        output[i] = 0;
                }

                // If the sum is zero (all entries clamped or NaN), leave the values untouched
                // to avoid dividing by zero.
                if (sum > 0)
                {
                    for (int i = 0; i < count; i++)
                        output[i] = (float)(output[i] / sum);
                }
            }

            /// <summary>
            /// PFA export: collect each class's probability token into an array, then scale the
            /// array by 1/sum so the probabilities sum to one.
            /// </summary>
            public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                Contracts.CheckValue(input, nameof(input));
                Contracts.Assert(CanSavePfa);

                JArray rootObjects = new JArray();
                for (int i = 0; i < Predictors.Length; ++i)
                {
                    var pred = (IDistCanSavePfa)Predictors[i];
                    Contracts.Assert(pred.CanSavePfa);
                    // Only the probability token is kept; the raw score token is discarded.
                    pred.SaveAsPfa(ctx, input, null, out JToken scoreToken, null, out JToken probToken);
                    rootObjects.Add(probToken);
                }
                // AddReturn is an extension method that tolerates a null receiver and allocates the JObject.
                JObject jobj = null;
                var rootResult = jobj.AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)).AddReturn("new", rootObjects);
                var resultVar = ctx.DeclareVar(null, rootResult);
                // factor = 1 / sum(probabilities); la.scale multiplies every element by it.
                var factorVar = ctx.DeclareVar(null, PfaUtils.Call("/", 1.0, PfaUtils.Call("a.sum", resultVar)));
                return PfaUtils.Call("la.scale", resultVar, factorVar);
            }

            /// <summary>
            /// ONNX export: divide each class's calibrated probability by the sum of all
            /// probabilities, guarding against a zero sum, then concatenate into one vector.
            /// </summary>
            public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
            {
                Contracts.Assert(outputNames.Length >= 2);
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                string opType;
                // 'true' -> calibrated probability outputs from each binary predictor.
                var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true);

                // Sum of all per-class probabilities (the normalization denominator).
                opType = "Sum";
                var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScores");
                ctx.CreateNode(opType, probabilityOutputs, new[] { sumOutput }, ctx.GetNodeName(opType), "");

                // The Cast -> Not -> Cast chain computes (sum == 0) as a float (1.0 when the sum
                // is exactly zero, 0.0 otherwise) so it can be added to the denominator below.
                opType = "Cast";
                var castOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "CastOutput");
                var castNode = ctx.CreateNode(opType, sumOutput, castOutput, ctx.GetNodeName(opType), "");
                var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
                castNode.AddAttribute("to", t);

                opType = "Not";
                var notOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZero");
                ctx.CreateNode(opType, castOutput, notOutput, ctx.GetNodeName(opType), "");

                opType = "Cast";
                var castIsZeroSumToFloat = ctx.AddIntermediateVariable(NumberDataViewType.Single, "IsSumZeroAsFloat");
                var castIsZeroSumToFloatNode = ctx.CreateNode(opType, notOutput, castIsZeroSumToFloat, ctx.GetNodeName(opType), "");
                var t1 = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
                castIsZeroSumToFloatNode.AddAttribute("to", t1);

                // Denominator guaranteed non-zero: sum + (sum == 0 ? 1 : 0).
                opType = "Sum";
                var sumOutputNonZero = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScoresNonZero");
                ctx.CreateNode(opType, new[] { sumOutput, castIsZeroSumToFloat },
                    new[] { sumOutputNonZero }, ctx.GetNodeName(opType), "");

                // Normalize each class's probability by the guarded sum.
                string[] divOutputs = new string[Predictors.Length];
                for (int i = 0; i < Predictors.Length; i++)
                {
                    opType = "Div";
                    divOutputs[i] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"DivOutput_{i}");
                    ctx.CreateNode(opType, new[] { probabilityOutputs[i], sumOutputNonZero }, new[] { divOutputs[i] }, ctx.GetNodeName(opType), "");
                }

                // Concatenate the normalized per-class scalars along axis 1.
                opType = "Concat";
                var type = new VectorDataViewType(NumberDataViewType.Single, divOutputs.Length);
                var concatOutput = ctx.AddIntermediateVariable(type, "ConcatOutputDist");
                var concatNode = ctx.CreateNode(opType, divOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), "");
                concatNode.AddAttribute("axis", 1);

                base.SaveAsOnnxPostProcess(ctx, concatOutput, outputNames);

                return true;
            }
        }
 
        /// <summary>
        /// OVA combiner that turns the underlying predictors' raw scores into class probabilities
        /// via softmax.
        /// </summary>
        private sealed class ImplSoftmax : ImplBase
        {
            public override DataViewType InputType { get; }
            public override IValueMapper[] Predictors { get; }
            // Softmax combination has no PFA export path, so this is always false.
            public override bool CanSavePfa { get; }

            internal ImplSoftmax(TScalarPredictor[] predictors)
            {
                Contracts.CheckNonEmpty(predictors, nameof(predictors));

                Predictors = new IValueMapper[predictors.Length];
                // All predictors must agree on a single vector input type; IsValid fills it in
                // from the first predictor and validates the rest against it.
                VectorDataViewType inputType = null;
                for (int i = 0; i < predictors.Length; i++)
                {
                    var vm = predictors[i] as IValueMapper;
                    Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface");
                    Predictors[i] = vm;
                }
                CanSavePfa = false;
                Contracts.AssertValue(inputType);
                InputType = inputType;
                OutputFormula = OutputFormula.Softmax;
            }

            /// <summary>
            /// Builds a mapper that evaluates every underlying binary predictor on the same
            /// feature vector and softmax-normalizes the raw scores into class probabilities.
            /// </summary>
            public override ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper()
            {
                // One scalar-score mapper per class.
                var maps = new ValueMapper<VBuffer<float>, float>[Predictors.Length];
                for (int i = 0; i < Predictors.Length; i++)
                    maps[i] = Predictors[i].GetMapper<VBuffer<float>, float>();

                // Scratch array reused across invocations of the returned delegate.
                var buffer = new float[maps.Length];
                return
                    (in VBuffer<float> src, ref VBuffer<float> dst) =>
                    {
                        // When the trained input type has a known size, enforce it on the incoming vector.
                        int inputSize = InputType.GetVectorSize();
                        if (inputSize > 0)
                            Contracts.Check(src.Length == inputSize);

                        // Copy to a local so the Parallel.For lambda can capture the 'in' parameter.
                        var tmp = src;
                        Parallel.For(0, maps.Length, i => maps[i](in tmp, ref buffer[i]));
                        NormalizeSoftmax(buffer, maps.Length);

                        var editor = VBufferEditor.Create(ref dst, maps.Length);
                        buffer.CopyTo(editor.Values);
                        dst = editor.Commit();
                    };
            }

            /// <summary>
            /// In-place softmax over the first <paramref name="count"/> entries of
            /// <paramref name="scores"/>: scores[i] = exp(scores[i]) / sum_j exp(scores[j]).
            /// </summary>
            private static void NormalizeSoftmax(float[] scores, int count)
            {
                // Subtract the maximum score before exponentiating. Softmax is invariant under a
                // constant shift, and the shift prevents Math.Exp from overflowing to infinity
                // for large raw scores (which would previously turn every probability into NaN or 0).
                double max = double.NegativeInfinity;
                for (int i = 0; i < count; i++)
                {
                    if (scores[i] > max)
                        max = scores[i];
                }

                double sum = 0;
                var exp = new double[count];
                for (int i = 0; i < count; i++)
                {
                    exp[i] = Math.Exp(scores[i] - max);
                    sum += exp[i];
                }

                for (int i = 0; i < count; i++)
                    scores[i] = (float)(exp[i] / sum);
            }

            // Softmax combination has no PFA representation; callers should check CanSavePfa first.
            public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
            {
                throw new NotImplementedException("Softmax's PFA exporter is not implemented yet.");
            }

            /// <summary>
            /// ONNX export: concatenate the raw per-class scores, then implement softmax with
            /// Exp -> ReduceSum(axis 1, keepdims) -> Div.
            /// </summary>
            public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
            {
                Contracts.Assert(outputNames.Length >= 2);
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                // 'false' -> raw (uncalibrated) outputs from each binary predictor.
                var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false);

                string opType;
                opType = "Concat";
                var type = new VectorDataViewType(NumberDataViewType.Single, probabilityOutputs.Length);
                var concatOutput = ctx.AddIntermediateVariable(type, "ConcatOutputSoftMax");
                var concatNode = ctx.CreateNode(opType, probabilityOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), "");
                concatNode.AddAttribute("axis", 1);

                // Element-wise exponentiation of the score vector.
                opType = "Exp";
                var expOutput = ctx.AddIntermediateVariable(type, "ExpOutput");
                ctx.CreateNode(opType, concatOutput, expOutput, ctx.GetNodeName(opType), "");

                // Sum the exponentials along the class axis, keeping dims so Div broadcasts.
                opType = "ReduceSum";
                var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOutput");
                var sumNode = ctx.CreateNode(opType, expOutput, sumOutput, ctx.GetNodeName(opType), "");
                sumNode.AddAttribute("keepdims", 1);
                long[] list = { 1 };
                sumNode.AddAttribute("axes", list);

                // Normalize: exp(score) / sum(exp(scores)).
                opType = "Div";
                var divOutput = ctx.AddIntermediateVariable(type, "DivOutput");
                ctx.CreateNode(opType, new[] { expOutput, sumOutput }, new[] { divOutput }, ctx.GetNodeName(opType), "");

                base.SaveAsOnnxPostProcess(ctx, divOutput, outputNames);

                return true;
            }
        }
    }
}