File: Standard\MulticlassClassification\MulticlassNaiveBayesTrainer.cs
Web Access
Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
 
// Registers NaiveBayesMulticlassTrainer as a loadable class that takes NaiveBayesMulticlassTrainer.Options,
// under the multiclass-trainer and generic trainer signatures, using the names declared on the trainer.
[assembly: LoadableClass(NaiveBayesMulticlassTrainer.Summary, typeof(NaiveBayesMulticlassTrainer), typeof(NaiveBayesMulticlassTrainer.Options),
    new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) },
    NaiveBayesMulticlassTrainer.UserName,
    NaiveBayesMulticlassTrainer.LoadName,
    NaiveBayesMulticlassTrainer.ShortName)]

// Registers the model parameters under the model-loading signature, keyed by their LoaderSignature.
[assembly: LoadableClass(typeof(NaiveBayesMulticlassModelParameters), null, typeof(SignatureLoadModel),
    "Multi Class Naive Bayes predictor", NaiveBayesMulticlassModelParameters.LoaderSignature)]

// Registers the trainer type as an entry-point module (see the [TlcModule.EntryPoint] method on the trainer).
[assembly: LoadableClass(typeof(void), typeof(NaiveBayesMulticlassTrainer), null, typeof(SignatureEntryPointModule), NaiveBayesMulticlassTrainer.LoadName)]
 
namespace Microsoft.ML.Trainers
{
    /// <summary>
    /// The <see cref="IEstimator{TTransformer}"/> for training a multiclass Naive Bayes model that supports binary feature values.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    /// To create this trainer, use [NaiveBayes](xref:Microsoft.ML.StandardTrainersCatalog.NaiveBayes(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String)).
    ///
    /// [!include[io](~/../docs/samples/docs/api-reference/io-columns-multiclass-classification.md)]
    ///
    /// ### Trainer Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Machine learning task | Multiclass classification |
    /// | Is normalization required? | Yes |
    /// | Is caching required? | No |
    /// | Required NuGet in addition to Microsoft.ML | None |
    /// | Exportable to ONNX | Yes |
    ///
    /// ### Training Algorithm Details
    /// [Naive Bayes](https://en.wikipedia.org/wiki/Naive_Bayes_classifier)
    /// is a probabilistic classifier that can be used for multiclass problems.
    /// Using Bayes' theorem, the conditional probability for a sample belonging to a class
    /// can be calculated based on the sample counts for each feature combination group.
    /// However, a Naive Bayes classifier is feasible only if the number of features and
    /// the number of values each feature can take are relatively small.
    /// It assumes independence among the presence of features in a class even though
    /// they may be dependent on each other.
    /// This multi-class trainer accepts "binary" feature values of type float:
    /// feature values that are greater than zero are treated as `true` and feature values
    /// that are less than or equal to 0 are treated as `false`.
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]>
    /// </format>
    /// </remarks>
    /// <seealso cref="StandardTrainersCatalog.NaiveBayes(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String)"/>
    public sealed class NaiveBayesMulticlassTrainer : TrainerEstimatorBase<MulticlassPredictionTransformer<NaiveBayesMulticlassModelParameters>, NaiveBayesMulticlassModelParameters>
    {
        // Names and description used by the assembly-level LoadableClass registration and by the entry point.
        // LoadName is also the name under which the trainer registers its host.
        internal const string LoadName = "MultiClassNaiveBayes";
        internal const string UserName = "Multiclass Naive Bayes";
        internal const string ShortName = "MNB";
        internal const string Summary = "Trains a multiclass Naive Bayes predictor that supports binary feature values.";
 
        /// <summary>
        /// Options for <see cref="NaiveBayesMulticlassTrainer"/>. The trainer declares no settings of its own;
        /// everything it reads (feature column, label column, training data) is inherited from
        /// <see cref="TrainerInputBaseWithLabel"/>.
        /// </summary>
        internal sealed class Options : TrainerInputBaseWithLabel
        {
        }
 
        /// <summary>The kind of prediction task this trainer solves: multiclass classification.</summary>
        private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;
 
        // Shared across all instances; constructed with normalization: false and caching: false.
        private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);

        /// <summary>
        /// Auxiliary information about the trainer in terms of its capabilities
        /// and requirements. Always returns the same shared <see cref="TrainerInfo"/> instance.
        /// </summary>
        public override TrainerInfo Info => _info;
 
        /// <summary>
        /// Initializes a new instance of <see cref="NaiveBayesMulticlassTrainer"/> from explicit column names.
        /// </summary>
        /// <param name="env">The environment to use. Must not be null.</param>
        /// <param name="labelColumn">The name of the label column. Must not be empty.</param>
        /// <param name="featureColumn">The name of the feature column. Must not be empty.</param>
        internal NaiveBayesMulticlassTrainer(IHostEnvironment env,
            string labelColumn = DefaultColumnNames.Label,
            string featureColumn = DefaultColumnNames.Features)
            : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(featureColumn),
                  TrainerUtils.MakeU4ScalarColumn(labelColumn))
        {
            // The base constructor has already received both names by the time these checks run;
            // they exist to reject empty names with a parameter-specific error.
            Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
            Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
        }
 
        /// <summary>
        /// Initializes a new instance of <see cref="NaiveBayesMulticlassTrainer"/> from an <see cref="Options"/> object.
        /// </summary>
        /// <param name="env">The environment to use. Must not be null.</param>
        /// <param name="options">
        /// The trainer options; its feature and label column names select the input columns. Must not be null.
        /// </param>
        internal NaiveBayesMulticlassTrainer(IHostEnvironment env, Options options)
            : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName),
                  // Validate options before its first dereference. Base-constructor arguments are evaluated
                  // before the constructor body, so a check placed only in the body would come too late and
                  // a null options would surface as a NullReferenceException.
                  TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(options, nameof(options)).FeatureColumnName),
                  TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName))
        {
            Host.CheckValue(options, nameof(options));
        }
 
        /// <summary>
        /// Describes the two columns produced by the trained transformer: a Single vector Score column and a
        /// UInt32 key PredictedLabel column. The PredictedLabel column carries over the label column's
        /// key-value annotation (when present) together with the trainer output annotations.
        /// </summary>
        /// <param name="inputSchema">The schema shape of the training input; expected to contain the label column.</param>
        /// <returns>The Score column followed by the PredictedLabel column.</returns>
        private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            bool found = inputSchema.TryFindColumn(LabelColumn.Name, out var labelColumn);
            Contracts.Assert(found);

            // Only the key-values annotation is inherited from the label column.
            var inheritedKeyValues = labelColumn.Annotations.Where(annotation => annotation.Name == AnnotationUtils.Kinds.KeyValues);
            var predictedLabelAnnotations = new SchemaShape(inheritedKeyValues.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));

            var scoreAnnotations = new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelColumn));
            var score = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, scoreAnnotations);
            var predictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, predictedLabelAnnotations);

            return new[] { score, predictedLabel };
        }
 
        /// <summary>
        /// Wraps trained model parameters in a <see cref="MulticlassPredictionTransformer{TModel}"/> bound to this
        /// trainer's feature and label column names.
        /// </summary>
        /// <param name="model">The trained model parameters.</param>
        /// <param name="trainSchema">The schema of the data the model was trained on.</param>
        private protected override MulticlassPredictionTransformer<NaiveBayesMulticlassModelParameters> MakeTransformer(NaiveBayesMulticlassModelParameters model, DataViewSchema trainSchema)
        {
            return new MulticlassPredictionTransformer<NaiveBayesMulticlassModelParameters>(
                Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
        }
 
        /// <summary>
        /// Trains the model with a single pass over the training set. For every label it counts the examples
        /// seen, and for every (label, feature slot) pair it counts the examples whose feature value was
        /// greater than zero.
        /// </summary>
        /// <param name="context">The training context; only its training set is read.</param>
        /// <returns>Model parameters built from the collected label and feature histograms.</returns>
        private protected override NaiveBayesMulticlassModelParameters TrainModelCore(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
            var trainingData = context.TrainingSet;
            Host.Check(trainingData.Schema.Label.HasValue, "Missing Label column");
            var labelColumn = trainingData.Schema.Label.Value;
            Host.Check(labelColumn.Type == NumberDataViewType.Single || labelColumn.Type is KeyDataViewType,
                "Invalid type for Label column, only floats and known-size keys are supported");

            Host.Check(trainingData.Schema.Feature.HasValue, "Missing Feature column");
            trainingData.CheckFeatureFloatVector(out int featureCount);

            // Key-typed labels give the class count up front; otherwise start empty and grow while reading.
            int labelCount = labelColumn.Type is KeyDataViewType keyType ? keyType.GetCountAsInt32(Host) : 0;

            long[] labelHistogram = new long[labelCount];
            long[][] featureHistogram = new long[labelCount][];
            using (var pch = Host.StartProgressChannel("Multi Class Naive Bayes training"))
            using (var ch = Host.Start("Training"))
            using (var cursor = new MulticlassLabelCursor(labelCount, trainingData, CursOpt.Features | CursOpt.Label))
            {
                int examplesProcessed = 0;
                pch.SetHeader(new ProgressHeader(new[] { "Examples Processed" }, new[] { "count" }),
                    entry => entry.SetProgress(0, examplesProcessed, int.MaxValue));

                while (cursor.MoveNext())
                {
                    if (cursor.Row.Position > int.MaxValue)
                    {
                        ch.Warning("Stopping training because maximum number of rows have been traversed");
                        break;
                    }

                    // Grow both histograms so that the current label is addressable.
                    int label = cursor.Label;
                    int requiredSize = label + 1;
                    Utils.EnsureSize(ref labelHistogram, requiredSize);
                    Utils.EnsureSize(ref featureHistogram, requiredSize);

                    // Per-label feature counts are allocated lazily, the first time the label is seen.
                    long[] presentCounts = featureHistogram[label] ?? (featureHistogram[label] = new long[featureCount]);
                    labelHistogram[label]++;
                    labelCount = Math.Max(labelCount, requiredSize);

                    var values = cursor.Features.GetValues();
                    if (cursor.Features.IsDense)
                    {
                        for (int slot = 0; slot < values.Length; slot++)
                        {
                            if (values[slot] > 0)
                                presentCounts[slot]++;
                        }
                    }
                    else
                    {
                        // Sparse representation: values[k] belongs to feature slot indices[k].
                        var indices = cursor.Features.GetIndices();
                        for (int k = 0; k < values.Length; k++)
                        {
                            if (values[k] > 0)
                                presentCounts[indices[k]]++;
                        }
                    }

                    examplesProcessed++;
                }
            }

            // Trim any growth slack so both histograms have exactly labelCount entries.
            Array.Resize(ref labelHistogram, labelCount);
            Array.Resize(ref featureHistogram, labelCount);
            return new NaiveBayesMulticlassModelParameters(Host, labelHistogram, featureHistogram, featureCount);
        }
 
        /// <summary>
        /// Entry point that trains a <see cref="NaiveBayesMulticlassTrainer"/> from <paramref name="input"/>.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="input">The trainer options, including the training data and column names. Must not be null.</param>
        /// <returns>The multiclass classification output produced by the shared entry-point training helper.</returns>
        [TlcModule.EntryPoint(Name = "Trainers.NaiveBayesClassifier",
            Desc = "Train a MulticlassNaiveBayesTrainer.",
            UserName = UserName,
            ShortName = ShortName)]
        internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlassNaiveBayesTrainer(IHostEnvironment env, Options input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("TrainMultiClassNaiveBayes");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // The first delegate creates the trainer; the second locates the label column in the training data.
            return TrainerEntryPointsUtils.Train<Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
                () => new NaiveBayesMulticlassTrainer(host, input),
                () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName));
        }
    }
 
    /// <summary>
    /// Model parameters for <see cref="NaiveBayesMulticlassTrainer"/>.
    /// </summary>
    public sealed class NaiveBayesMulticlassModelParameters :
        ModelParametersBase<VBuffer<float>>,
        IValueMapper,
        ISingleCanSaveOnnx
    {
        // Signature under which this model type is registered for loading (see the assembly-level LoadableClass attribute).
        internal const string LoaderSignature = "MultiClassNaiveBayesPred";

        /// <summary>
        /// Returns the version information used both when saving and when validating a model being loaded.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            // NOTE(review): verReadableCur is lower than verWrittenCur although 0x00010002 changed the histogram
            // element type from int to long. Confirm this is intended.
            return new VersionInfo(
                modelSignature: "MNABYPRD",
                //verWrittenCur: 0x00010001, // Initial
                verWrittenCur: 0x00010002, // Histograms are of type long
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(NaiveBayesMulticlassModelParameters).Assembly.FullName);
        }
 
        // Number of training examples seen for each label, indexed by label.
        private readonly long[] _labelHistogram;
        // For each label, the number of training examples per feature slot whose value was greater than zero.
        // The row of a label that was never observed may be null.
        private readonly long[][] _featureHistogram;
        // Per-label log-probability of every feature being absent; see CalculateAbsentFeatureLogProbabilities.
        private readonly double[] _absentFeaturesLogProb;
        // Sum of _labelHistogram.
        private readonly long _totalTrainingCount;
        // Length of _labelHistogram.
        private readonly int _labelCount;
        // Number of feature slots the model expects as input.
        private readonly int _featureCount;
        // Single vector of size _featureCount.
        private readonly VectorDataViewType _inputType;
        // Single vector of size _labelCount.
        private readonly VectorDataViewType _outputType;
 
        /// <summary>The kind of prediction task this model solves: multiclass classification.</summary>
        private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;

        // Input of the value mapper: a Single vector with one slot per feature.
        DataViewType IValueMapper.InputType => _inputType;

        // Output of the value mapper: a Single vector with one score per label.
        DataViewType IValueMapper.OutputType => _outputType;

        // ONNX export is reported as supported unconditionally.
        bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
 
        /// <summary>
        /// Get the label histogram as 32-bit counts. Obsolete, and marked as a compile-time error:
        /// each count is narrowed from <see cref="long"/> to <see cref="int"/>.
        /// </summary>
        [Obsolete("This API is deprecated, please use GetLabelHistogramLong() which returns _labelHistogram " +
            "with type IReadOnlyList<long> to avoid overflow errors with large datasets.", true)]
        public IReadOnlyList<int> GetLabelHistogram() => Array.ConvertAll(_labelHistogram, x => (int)x);

        /// <summary>
        /// Get the label histogram: the number of training examples per label, as 64-bit counts.
        /// </summary>
        /// <returns>A read-only view over the model's own histogram (no copy is made).</returns>
        public IReadOnlyList<long> GetLabelHistogramLong() => _labelHistogram;

        /// <summary>
        /// Get the feature histogram as 32-bit counts. Obsolete, and marked as a compile-time error:
        /// each count is narrowed from <see cref="long"/> to <see cref="int"/>.
        /// </summary>
        [Obsolete("This API is deprecated, please use GetFeatureHistogramLong() which returns _featureHistogram " +
            "with type IReadOnlyList<long> to avoid overflow errors with large datasets.", true)]
        public IReadOnlyList<IReadOnlyList<int>> GetFeatureHistogram() => Array.ConvertAll(_featureHistogram, x => Array.ConvertAll(x, y => (int)y));

        /// <summary>
        /// Get the feature histogram: for each label, the per-feature count of training examples whose
        /// feature value was greater than zero, as 64-bit counts.
        /// </summary>
        /// <returns>A read-only view over the model's own histogram (no copy is made).</returns>
        public IReadOnlyList<IReadOnlyList<long>> GetFeatureHistogramLong() => _featureHistogram;
 
        /// <summary>
        /// Builds model parameters from the histograms collected during training.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="labelHistogram">Number of training examples per label.</param>
        /// <param name="featureHistogram">
        /// Per label, the number of examples per feature slot with a value greater than zero. A row may be null;
        /// every non-null row is expected to have <paramref name="featureCount"/> entries.
        /// </param>
        /// <param name="featureCount">The number of features.</param>
        internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, long[] labelHistogram, long[][] featureHistogram, int featureCount)
            : base(env, LoaderSignature)
        {
            Host.AssertValue(labelHistogram);
            Host.AssertValue(featureHistogram);
            Host.Assert(labelHistogram.Length == featureHistogram.Length);
            Host.Assert(featureHistogram.All(row => row == null || row.Length == featureCount));

            // The arrays are stored as given, not copied.
            _featureCount = featureCount;
            _labelHistogram = labelHistogram;
            _featureHistogram = featureHistogram;

            // Derived values.
            _totalTrainingCount = labelHistogram.Sum();
            _labelCount = labelHistogram.Length;
            _absentFeaturesLogProb = CalculateAbsentFeatureLogProbabilities(labelHistogram, featureHistogram, featureCount);

            _inputType = new VectorDataViewType(NumberDataViewType.Single, featureCount);
            _outputType = new VectorDataViewType(NumberDataViewType.Single, _labelCount);
        }
 
        /// <summary>
        /// Deserializes model parameters from <paramref name="ctx"/>. Models written before version 0x00010002
        /// store their histograms as 32-bit integers, which are widened to 64-bit on load.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="ctx">The load context positioned at this model.</param>
        /// <remarks>
        /// The unit test TestEntryPoints.LoadEntryPointModel() exercises the ReadIntArray(int size) code path below,
        /// as its ctx.Header.ModelVerWritten is 0x00010001; the persisted model that gets loaded and executed
        /// for this unit test is located at test\data\backcompat\ep_model3.zip.
        /// </remarks>
        private NaiveBayesMulticlassModelParameters(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, LoaderSignature, ctx)
        {
            // *** Binary format ***
            // int: _labelCount (consumed as the length prefix when reading _labelHistogram)
            // long[_labelCount]: _labelHistogram (int[_labelCount] before version 0x00010002)
            // int: _featureCount
            // long[_labelCount][_featureCount]: _featureHistogram, only for labels with a positive count
            //     (int before version 0x00010002)
            // double[_labelCount]: _absentFeaturesLogProb
            bool histogramsAreLong = ctx.Header.ModelVerWritten >= 0x00010002;

            if (histogramsAreLong)
                _labelHistogram = ctx.Reader.ReadLongArray() ?? new long[0];
            else
                _labelHistogram = Array.ConvertAll(ctx.Reader.ReadIntArray() ?? new int[0], x => (long)x);
            _labelCount = _labelHistogram.Length;

            // The iteration variable must be long: an int iteration variable would silently truncate each
            // count before validating it, so an out-of-range value could pass or fail the check incorrectly.
            foreach (long labelCount in _labelHistogram)
                Host.CheckDecode(labelCount >= 0);

            _featureCount = ctx.Reader.ReadInt32();
            Host.CheckDecode(_featureCount >= 0);
            _featureHistogram = new long[_labelCount][];
            for (int iLabel = 0; iLabel < _labelCount; iLabel += 1)
            {
                // Rows are only persisted for labels that were observed; the others stay null.
                if (_labelHistogram[iLabel] > 0)
                {
                    // Both paths fall back to an empty row, matching the legacy int path, so that a present
                    // label never ends up with a null row after loading.
                    if (histogramsAreLong)
                        _featureHistogram[iLabel] = ctx.Reader.ReadLongArray(_featureCount) ?? new long[0];
                    else
                        _featureHistogram[iLabel] = Array.ConvertAll(ctx.Reader.ReadIntArray(_featureCount) ?? new int[0], x => (long)x);
                    for (int iFeature = 0; iFeature < _featureCount; iFeature += 1)
                        Host.CheckDecode(_featureHistogram[iLabel][iFeature] >= 0);
                }
            }

            _absentFeaturesLogProb = ctx.Reader.ReadDoubleArray(_labelCount);
            _totalTrainingCount = _labelHistogram.Sum();
            _inputType = new VectorDataViewType(NumberDataViewType.Single, _featureCount);
            _outputType = new VectorDataViewType(NumberDataViewType.Single, _labelCount);
        }
 
        /// <summary>
        /// Loader entry point: validates the arguments, checks that <paramref name="ctx"/> is positioned at a
        /// model matching this type's version info, and deserializes the model parameters.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="ctx">The load context. Must not be null.</param>
        internal static NaiveBayesMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            return new NaiveBayesMulticlassModelParameters(env, ctx);
        }
 
        /// <summary>
        /// Serializes the model parameters. The write order below must mirror the read order of the loading constructor.
        /// </summary>
        /// <param name="ctx">The save context to write to.</param>
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: _labelCount
            // long[_labelCount]: _labelHistogram
            // int: _featureCount
            // long[_labelCount][_featureCount]: _featureHistogram (only for labels with a positive count)
            // double[_labelCount]: _absentFeaturesLogProb
            ctx.Writer.Write(_labelCount);
            ctx.Writer.WriteLongStream(_labelHistogram);
            ctx.Writer.Write(_featureCount);
            for (int i = 0; i < _labelCount; i += 1)
            {
                // Labels that were never observed have no feature row to persist.
                if (_labelHistogram[i] > 0)
                    ctx.Writer.WriteLongStream(_featureHistogram[i]);
            }

            ctx.Writer.WriteDoublesNoCount(_absentFeaturesLogProb.AsSpan(0, _labelCount));
        }
 
        /// <summary>
        /// For every label, sums over all features the log of
        /// (1 + examples of that label without the feature) / (examples of that label + number of labels).
        /// Labels with no examples keep a value of zero.
        /// </summary>
        /// <param name="labelHistogram">Number of training examples per label.</param>
        /// <param name="featureHistogram">Per label, the number of examples in which each feature was present.</param>
        /// <param name="featureCount">The number of features.</param>
        /// <returns>An array with one log-probability per label.</returns>
        private static double[] CalculateAbsentFeatureLogProbabilities(long[] labelHistogram, long[][] featureHistogram, int featureCount)
        {
            int labelCount = labelHistogram.Length;
            var result = new double[labelCount];
            for (int label = 0; label < labelCount; label++)
            {
                long occurrences = labelHistogram[label];
                if (occurrences <= 0)
                    continue;

                // The denominator does not depend on the feature, so take its log once per label.
                double logDenominator = Math.Log(occurrences + labelCount);
                double total = 0;
                for (int feature = 0; feature < featureCount; feature++)
                    total += Math.Log(1 + ((double)occurrences - featureHistogram[label][feature])) - logDenominator;

                result[label] = total;
            }

            return result;
        }
 
        /// <summary>
        /// Returns the scoring delegate. Both <typeparamref name="TIn"/> and <typeparamref name="TOut"/> must be
        /// <see cref="VBuffer{T}"/> of <see cref="float"/>; any other combination fails the checks below.
        /// </summary>
        ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
        {
            Host.Check(typeof(TIn) == typeof(VBuffer<float>));
            Host.Check(typeof(TOut) == typeof(VBuffer<float>));

            // Go through Delegate because the compiler cannot prove the generic delegate types are the same.
            ValueMapper<VBuffer<float>, VBuffer<float>> del = Map;
            return (ValueMapper<TIn, TOut>)(Delegate)del;
        }
 
        /// <summary>
        /// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>
        /// </summary>
        /// <param name="ctx">The ONNX context that receives the initializers, intermediate variables and nodes.</param>
        /// <param name="outputNames">
        /// Output variable names: element 0 receives the predicted label (cast to UInt32), element 1 receives the scores.
        /// </param>
        /// <param name="featureColumn">Name of the ONNX variable holding the input features.</param>
        /// <returns>Always <see langword="true"/>.</returns>
        bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            const int minimumOpSetVersion = 9;
            ctx.CheckOpSetVersion(minimumOpSetVersion, "MulticlassNaiveBayes");

            // NOTE(review): the code below dereferences _featureHistogram[0] and every _featureHistogram[i].
            // Rows of labels that were never observed can be null, in which case this would throw. Confirm
            // whether such models can reach the export path.

            // Flattened (label x feature) buffers; _featureHistogram[0].Length is used as the feature count.
            float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
            float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

            // Row i of the flattened feature histogram is the feature histogram of label i.
            for (int i = 0; i < _featureHistogram.Length; i++)
            {
                Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
            }
            // The label histogram is repeated once per feature.
            for (int i = 0; i < _featureHistogram[0].Length; i++)
            {
                Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
            }

            // Constants and learned parameters.
            var one = ctx.AddInitializer(1.0f, "one");
            var oneInt = ctx.AddInitializer(1, typeof(int), "oneInt");
            var zero = ctx.AddInitializer(0.0f, "zero");
            var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
            var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
            // The first _labelHistogram.Length entries of the expanded buffer are one copy of the label histogram.
            var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");

            var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
            var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
            var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");

            // Types reused for the intermediate variables below.
            var typeOne = new VectorDataViewType(NumberDataViewType.Single, 1);
            var typeFea = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
            var typeLabelByFea = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, _featureHistogram[0].Length);
            var typeLabelByOne = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, 1);

            // isFeaturePresent = ExpandDims(Cast<Single>(features > 0)): the vectorized form of the
            // "feature value greater than zero" test.
            var greaterOutput = ctx.AddIntermediateVariable(new VectorDataViewType(BooleanDataViewType.Instance, _featureHistogram[0].Length), "greaterOutput");
            var opType = "Greater";
            ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");

            opType = "Cast";
            var castOutput = ctx.AddIntermediateVariable(typeFea, "CastOutput");
            var node = ctx.CreateNode(opType, greaterOutput, castOutput, ctx.GetNodeName(opType), "");
            var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            opType = "ExpandDims";
            var isFeaturePresent = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, _featureHistogram[0].Length), "isFeaturePresent");
            ctx.CreateNode(opType, new[] { castOutput, oneInt }, new[] { isFeaturePresent }, ctx.GetNodeName(opType), "com.microsoft");

            // Initial logProb per label: Log(labelHistogram / totalTrainingCount), as at the start of Map's label loop.
            opType = "Div";
            var divOutput = ctx.AddIntermediateVariable(typeOne, "DivOutput");
            ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");

            opType = "Log";
            var logOutput = ctx.AddIntermediateVariable(typeOne, "LogOutput");
            ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");

            // log1 = Log(featureHistogram + 1) * isFeaturePresent
            opType = "Sum";
            var sumOutput = ctx.AddIntermediateVariable(_inputType, "SumOutput");
            ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput1 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
            LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);

            // log2 = Log(Transpose(labelHistogramExpanded) + labelCount) * isFeaturePresent
            opType = "Transpose";
            var labelHistogramTrans = ctx.AddIntermediateVariable(typeFea, "Transpose");
            ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");

            // absentFeatureCount = label occurrences - feature occurrences; consumed by log3 below.
            opType = "Sub";
            var absentFeatureCount = ctx.AddIntermediateVariable(typeFea, "AbsentFeatureCounts");
            ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");

            opType = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
            ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput2 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
            LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);

            // log3 = Log(absentFeatureCount + 1) * isFeaturePresent
            opType = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
            ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput3 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
            LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);

            // Per-feature contributions: logProb = log1 - log2 and absentFeatureLogProb = log3 - log2.
            opType = "Sub";
            var logProb = ctx.AddIntermediateVariable(typeLabelByFea, "LogProb");
            ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");

            opType = "Sub";
            var absentFeatureLogProb = ctx.AddIntermediateVariable(typeLabelByFea, "AbsentFeatureLogProb");
            ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");

            // Sum both contributions over axis 2.
            opType = "ReduceSum";
            var logProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
            node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
            long[] list = { 2 };
            node.AddAttribute("axes", list);

            opType = "ReduceSum";
            var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
            node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axes", list);

            // The learned absent-feature log-probabilities are doubles; cast them to Single before combining.
            opType = "Cast";
            castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastOutput");
            node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
            t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // score = (learnedAbsent - absentFeatureLogProb) + logProb + Log(labelHistogram / totalTrainingCount)
            opType = "Sub";
            var subOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SubOutput");
            ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");

            opType = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SumOutput");
            ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            // Scores output: the summed scores with axis 2 squeezed away.
            opType = "Squeeze";
            var squeezeNode = ctx.CreateNode(opType, sumOutput, outputNames[1], ctx.GetNodeName(opType), "");
            squeezeNode.AddAttribute("axes", new long[] { 2 });

            // Predicted label: index of the highest score along axis 1.
            opType = "ArgMax";
            var scoreIndex = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, 1), "ScoreIndex");
            node = ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axis", 1);
            node.AddAttribute("keepdims", 0);

            opType = "Cast";
            castOutput = ctx.AddIntermediateVariable(typeOne, "CastOutput");
            node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
            t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // Add one to the ArgMax index before casting it to UInt32 (presumably to produce a one-based
            // key value; confirm against the key type convention).
            opType = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeOne, "SumOutput");
            ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            opType = "Cast";
            node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
            t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
            node.AddAttribute("to", t);

            return true;
        }
 
        /// <summary>
        /// Appends two ONNX nodes computing <c>output = Log(input) * isFeaturePresent</c>.
        /// </summary>
        /// <param name="ctx">The ONNX context the nodes are added to.</param>
        /// <param name="input">Name of the variable whose element-wise logarithm is taken.</param>
        /// <param name="isFeaturePresent">Name of the variable the logarithm is multiplied with.</param>
        /// <param name="output">Name of the variable that receives the product.</param>
        private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output)
        {
            const string logOp = "Log";
            const string mulOp = "Mul";

            // The intermediate holding Log(input) is declared with one Single slot per feature.
            var logType = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
            var logged = ctx.AddIntermediateVariable(logType, "LogOutput");
            ctx.CreateNode(logOp, input, logged, ctx.GetNodeName(logOp), "");

            ctx.CreateNode(mulOp, new[] { logged, isFeaturePresent }, new[] { output }, ctx.GetNodeName(mulOp), "");
        }
 
        /// <summary>
        /// Adds the contribution of a single feature slot to a label's running log-probabilities.
        /// </summary>
        /// <remarks>
        /// A feature counts as present only when its value is strictly greater than zero. This is the same test
        /// the trainer applies while building the histograms and the one the ONNX export encodes with its
        /// "Greater" node. Every other value, including NaN, leaves both accumulators unchanged.
        /// </remarks>
        /// <param name="labelOccurrenceCount">Number of training examples that had this label.</param>
        /// <param name="labelIndex">Index of the label being scored.</param>
        /// <param name="featureIndex">Index of the feature slot.</param>
        /// <param name="featureValue">Value of the feature in the example being scored.</param>
        /// <param name="logProb">Accumulator for the log-probability of the present features.</param>
        /// <param name="absentFeatureLogProb">
        /// Accumulator for the "absent" log-probability of the same features, which the caller subtracts from
        /// the precomputed all-features-absent total.
        /// </param>
        private void ComputeLabelProbabilityFromFeature(double labelOccurrenceCount, int labelIndex, int featureIndex,
            float featureValue, ref double logProb, ref double absentFeatureLogProb)
        {
            // Written as a negated "> 0" rather than "<= 0" so that NaN is treated as absent, consistently
            // with training and with the exported ONNX graph.
            if (!(featureValue > 0))
                return;

            double featureCount = _featureHistogram[labelIndex][featureIndex];
            double absentFeatureCount = labelOccurrenceCount - featureCount;
            Host.Assert(featureCount >= 0);

            // Both terms share the same denominator: label occurrences plus the number of labels.
            double logDenominator = Math.Log(labelOccurrenceCount + _labelCount);
            logProb += Math.Log(featureCount + 1) - logDenominator;
            absentFeatureLogProb += Math.Log(absentFeatureCount + 1) - logDenominator;
        }
 
        /// <summary>
        /// Scores one example: writes one log-probability score per label into <paramref name="dst"/>.
        /// </summary>
        /// <param name="src">The feature vector; its length must equal the model's feature count.</param>
        /// <param name="dst">The buffer that receives the per-label scores.</param>
        private void Map(in VBuffer<float> src, ref VBuffer<float> dst)
        {
            Host.Check(src.Length == _featureCount, "Invalid number of features passed.");

            var values = src.GetValues();
            var indices = src.GetIndices();
            bool isDense = src.IsDense;

            var editor = VBufferEditor.Create(ref dst, _labelCount);
            Span<float> scores = editor.Values;
            for (int label = 0; label < _labelCount; label++)
            {
                double occurrences = _labelHistogram[label];

                // Start from the log of the label's share of the training examples.
                double presentLogProb = Math.Log(occurrences / _totalTrainingCount);
                double absentLogProb = 0;
                if (_labelHistogram[label] > 0)
                {
                    for (int k = 0; k < values.Length; k++)
                    {
                        // In a dense buffer the position is the feature index; in a sparse one it is looked up.
                        int featureIndex = isDense ? k : indices[k];
                        ComputeLabelProbabilityFromFeature(occurrences, label, featureIndex,
                            values[k], ref presentLogProb, ref absentLogProb);
                    }
                }

                // The precomputed total assumes every feature is absent; take back the share of the present ones.
                double absentRemainder = _absentFeaturesLogProb[label] - absentLogProb;
                scores[label] = (float)(presentLogProb + absentRemainder);
            }

            dst = editor.Commit();
        }
    }
}