// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
using Newtonsoft.Json.Linq;
[assembly: LoadableClass(typeof(LbfgsMaximumEntropyMulticlassTrainer), typeof(LbfgsMaximumEntropyMulticlassTrainer.Options),
new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) },
LbfgsMaximumEntropyMulticlassTrainer.UserNameValue,
LbfgsMaximumEntropyMulticlassTrainer.LoadNameValue,
"MulticlassLogisticRegressionPredictorNew",
LbfgsMaximumEntropyMulticlassTrainer.ShortName,
"multilr")]
[assembly: LoadableClass(typeof(MaximumEntropyModelParameters), null, typeof(SignatureLoadModel),
"Multiclass LR Executor",
MaximumEntropyModelParameters.LoaderSignature)]
[assembly: LoadableClass(typeof(LinearMulticlassModelParameters), null, typeof(SignatureLoadModel),
"Multiclass LR No Calib",
LinearMulticlassModelParameters.LoaderSignature)]
[assembly: LoadableClass(typeof(void), typeof(LbfgsMaximumEntropyMulticlassTrainer), null, typeof(SignatureEntryPointModule), LbfgsMaximumEntropyMulticlassTrainer.LoadNameValue)]
namespace Microsoft.ML.Trainers
{
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> to predict a target using a maximum entropy multiclass classifier trained with the L-BFGS method.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [LbfgsMaximumEntropy](xref:Microsoft.ML.StandardTrainersCatalog.LbfgsMaximumEntropy(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,System.Single,System.Single,System.Single,System.Int32,System.Boolean))
/// or [LbfgsMaximumEntropy(Options)](xref:Microsoft.ML.StandardTrainersCatalog.LbfgsMaximumEntropy(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,Microsoft.ML.Trainers.LbfgsMaximumEntropyMulticlassTrainer.Options)).
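///
/// A minimal end-to-end sketch (assuming an [MLContext](xref:Microsoft.ML.MLContext) named `mlContext` and an `IDataView` named `trainingData` with a text label column named "Label" and a feature vector column named "Features"; the trainer requires a key-typed label, hence the MapValueToKey step):
/// ```csharp
/// var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
///     .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy());
/// var model = pipeline.Fit(trainingData);
/// ```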
///
/// [!include[io](~/../docs/samples/docs/api-reference/io-columns-multiclass-classification.md)]
///
/// ### Trainer Characteristics
/// | | |
/// | -- | -- |
/// | Machine learning task | Multiclass classification |
/// | Is normalization required? | Yes |
/// | Is caching required? | No |
/// | Required NuGet in addition to Microsoft.ML | None |
/// | Exportable to ONNX | Yes |
///
/// ### Scoring Function
/// The [maximum entropy model](https://en.wikipedia.org/wiki/Multinomial_logistic_regression) is a generalization of linear [logistic regression](https://en.wikipedia.org/wiki/Logistic_regression).
/// The major difference between the maximum entropy model and logistic regression is the number of classes supported in the considered classification problem:
/// logistic regression is only for binary classification, while the maximum entropy model handles multiple classes.
/// See Section 1 in [this paper](https://www.csie.ntu.edu.tw/~cjlin/papers/maxent_dual.pdf) for a detailed introduction.
///
/// Assume that the number of classes is $m$ and the number of features is $n$.
/// The maximum entropy model assigns the $c$-th class a coefficient vector $\textbf{w}\_c \in {\mathbb R}^n$ and a bias $b_c \in {\mathbb R}$, for $c=1,\dots,m$.
/// Given a feature vector $\textbf{x} \in {\mathbb R}^n$, the $c$-th class's score is $\hat{y}^c = \textbf{w}\_c^T \textbf{x} + b_c$.
/// The probability of $\textbf{x}$ belonging to class $c$ is defined by $\tilde{P}(c | \textbf{x}) = \frac{ e^{\hat{y}^c} }{ \sum\_{c' = 1}^m e^{\hat{y}^{c'}} }$.
/// Let $P(c, \textbf{x})$ denote the joint probability of seeing $c$ and $\textbf{x}$.
/// The loss function minimized by this trainer is $-\sum\_{c = 1}^m P(c, \textbf{x}) \log \tilde{P}(c | \textbf{x})$, which is the negative [log-likelihood function](https://en.wikipedia.org/wiki/Likelihood_function#Log-likelihood).
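///
/// For example, when $m = 2$ the class-1 probability reduces to $\tilde{P}(1 | \textbf{x}) = \frac{ e^{\hat{y}^1} }{ e^{\hat{y}^1} + e^{\hat{y}^2} } = \frac{1}{1 + e^{-(\hat{y}^1 - \hat{y}^2)}}$, the familiar sigmoid of binary logistic regression applied to the score difference.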
///
/// ### Training Algorithm Details
/// The optimization technique implemented is based on [the limited memory Broyden-Fletcher-Goldfarb-Shanno method (L-BFGS)](https://en.wikipedia.org/wiki/Limited-memory_BFGS).
/// L-BFGS is a [quasi-Newton method](https://en.wikipedia.org/wiki/Quasi-Newton_method), which replaces the expensive computation of the Hessian matrix with an approximation but still enjoys a fast convergence rate like [Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method_in_optimization), where the full Hessian matrix is computed.
/// Since L-BFGS approximation uses only a limited amount of historical states to compute the next step direction, it is especially suited for problems with a high-dimensional feature vector.
/// The number of historical states is a user-specified parameter: using a larger number may lead to a better approximation of the Hessian matrix but also a higher computation cost per step.
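///
/// The number of historical states corresponds to the `HistorySize` option. A minimal configuration sketch (assuming an existing `MLContext` named `mlContext`; the values shown are illustrative, not recommendations):
/// ```csharp
/// var options = new LbfgsMaximumEntropyMulticlassTrainer.Options
/// {
///     HistorySize = 50,              // more states: closer Hessian approximation, higher per-step cost
///     OptimizationTolerance = 1e-7f,
///     L2Regularization = 1e-6f
/// };
/// var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(options);
/// ```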
///
/// [!include[io](~/../docs/samples/docs/api-reference/regularization-l1-l2.md)]
///
/// Check the See Also section for links to usage examples.
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.LbfgsMaximumEntropy(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, float, float, float, int, bool)"/>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.LbfgsMaximumEntropy(MulticlassClassificationCatalog.MulticlassClassificationTrainers, LbfgsMaximumEntropyMulticlassTrainer.Options)"/>
/// <seealso cref="Options"/>
public sealed class LbfgsMaximumEntropyMulticlassTrainer : LbfgsTrainerBase<LbfgsMaximumEntropyMulticlassTrainer.Options,
MulticlassPredictionTransformer<MaximumEntropyModelParameters>, MaximumEntropyModelParameters>
{
internal const string Summary = "Maximum entropy classification is a method in statistics used to predict the probabilities of parallel events. The model predicts the probabilities of parallel events by fitting data to a softmax function.";
internal const string LoadNameValue = "MultiClassLogisticRegression";
internal const string UserNameValue = "Multi-class Logistic Regression";
internal const string ShortName = "mlr";
/// <summary>
/// <see cref="Options"/> for <see cref="LbfgsMaximumEntropyMulticlassTrainer"/> as used in
/// <see cref="Microsoft.ML.StandardTrainersCatalog.LbfgsMaximumEntropy(MulticlassClassificationCatalog.MulticlassClassificationTrainers, LbfgsMaximumEntropyMulticlassTrainer.Options)"/>.
/// </summary>
public sealed class Options : OptionsBase
{
/// <summary>
/// If set to <see langword="true"/>, training statistics will be generated at the end of training.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Show statistics of training examples.", ShortName = "stat, ShowTrainingStats", SortOrder = 50)]
public bool ShowTrainingStatistics = false;
}
private int _numClasses;
// The names for each label class, indexed by zero-based class number.
// These label names are used for model saving in place of the class number
// to make the model summary more user friendly. These names are populated
// in the CheckLabel() method.
// It could be null if the label type is not a key type, or if the label
// name for some class is missing.
private string[] _labelNames;
// The prior distribution of data.
// This array is of length equal to the number of classes.
// After training, it stores the total weights of training examples in each class.
private Double[] _prior;
private ModelStatisticsBase _stats;
private protected override int ClassCount => _numClasses;
/// <summary>
/// Initializes a new instance of <see cref="LbfgsMaximumEntropyMulticlassTrainer"/>.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weights">The name for the example weight column.</param>
/// <param name="enforceNoNegativity">Enforce non-negative weights.</param>
/// <param name="l1Weight">Weight of L1 regularizer term.</param>
/// <param name="l2Weight">Weight of L2 regularizer term.</param>
/// <param name="memorySize">Memory size for <see cref="LbfgsLogisticRegressionBinaryTrainer"/>. Low=faster, less accurate.</param>
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
internal LbfgsMaximumEntropyMulticlassTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
float l1Weight = Options.Defaults.L1Regularization,
float l2Weight = Options.Defaults.L2Regularization,
float optimizationTolerance = Options.Defaults.OptimizationTolerance,
int memorySize = Options.Defaults.HistorySize,
bool enforceNoNegativity = Options.Defaults.EnforceNonNegativity)
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
ShowTrainingStats = LbfgsTrainerOptions.ShowTrainingStatistics;
}
/// <summary>
/// Initializes a new instance of <see cref="LbfgsMaximumEntropyMulticlassTrainer"/>.
/// </summary>
internal LbfgsMaximumEntropyMulticlassTrainer(IHostEnvironment env, Options options)
: base(env, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName))
{
ShowTrainingStats = LbfgsTrainerOptions.ShowTrainingStatistics;
}
private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;
private protected override void CheckLabel(RoleMappedData data)
{
Contracts.AssertValue(data);
// REVIEW: For floating point labels, this will make a pass over the data.
// Should we instead leverage the pass made by the LBFGS base class? Ideally, it wouldn't
// make a pass over the data...
data.CheckMulticlassLabel(out _numClasses);
// Initialize prior counts.
_prior = new Double[_numClasses];
// Try to get the label key values metadata.
var labelCol = data.Schema.Label.Value;
var labelMetadataType = labelCol.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;
if (!(labelMetadataType is VectorDataViewType vecType && vecType.ItemType == TextDataViewType.Instance && vecType.Size == _numClasses))
{
_labelNames = null;
return;
}
VBuffer<ReadOnlyMemory<char>> labelNames = default;
labelCol.GetKeyValues(ref labelNames);
// If the label names vector is not dense or contains NA or default values, then at least
// one class does not have a valid name for its label. If the label names we get from the
// metadata are not unique, we also cannot use them in the model summary.
// In both cases we set _labelNames to null and instead use "Class_n", where n is the class
// number, when saving the model summary.
if (!labelNames.IsDense)
{
_labelNames = null;
return;
}
_labelNames = new string[_numClasses];
ReadOnlySpan<ReadOnlyMemory<char>> values = labelNames.GetValues();
// This hashset is used to verify the uniqueness of label names.
HashSet<string> labelNamesSet = new HashSet<string>();
for (int i = 0; i < _numClasses; i++)
{
ReadOnlyMemory<char> value = values[i];
if (value.IsEmpty)
{
_labelNames = null;
break;
}
var vs = values[i].ToString();
if (!labelNamesSet.Add(vs))
{
_labelNames = null;
break;
}
_labelNames[i] = vs;
Contracts.Assert(!string.IsNullOrEmpty(_labelNames[i]));
}
Contracts.Assert(_labelNames == null || _labelNames.Length == _numClasses);
}
// Override the default termination criterion MeanRelativeImprovementCriterion with MeanImprovementCriterion.
private protected override Optimizer InitializeOptimizer(IChannel ch, FloatLabelCursor.Factory cursorFactory,
out VBuffer<float> init, out ITerminationCriterion terminationCriterion)
{
var opt = base.InitializeOptimizer(ch, cursorFactory, out init, out terminationCriterion);
// MeanImprovementCriterion:
// Terminates when the geometrically-weighted average improvement falls below the tolerance
terminationCriterion = new MeanImprovementCriterion(OptTol, 0.25f, MaxIterations);
return opt;
}
private protected override float AccumulateOneGradient(in VBuffer<float> feat, float label, float weight,
in VBuffer<float> x, ref VBuffer<float> grad, ref float[] scores)
{
if (Utils.Size(scores) < _numClasses)
scores = new float[_numClasses];
float bias = 0;
for (int c = 0, start = _numClasses; c < _numClasses; c++, start += NumFeatures)
{
x.GetItemOrDefault(c, ref bias);
scores[c] = bias + VectorUtils.DotProductWithOffset(in x, start, in feat);
}
float logZ = MathUtils.SoftMax(scores.AsSpan(0, _numClasses));
float datumLoss = logZ;
int lab = (int)label;
Contracts.Assert(0 <= lab && lab < _numClasses);
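// For the maximum entropy loss, the gradient of the per-example negative log-likelihood
// with respect to class c's weights is (P(c | x) - 1[c == label]) * x, with the same
// multiplier for the bias; "mult" below is that quantity times the example weight.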
for (int c = 0, start = _numClasses; c < _numClasses; c++, start += NumFeatures)
{
float probLabel = lab == c ? 1 : 0;
datumLoss -= probLabel * scores[c];
float modelProb = MathUtils.ExpSlow(scores[c] - logZ);
float mult = weight * (modelProb - probLabel);
VectorUtils.AddMultWithOffset(in feat, mult, ref grad, start);
// Due to the call to EnsureBiases, we know this region is dense.
var editor = VBufferEditor.CreateFromBuffer(ref grad);
Contracts.Assert(editor.Values.Length >= BiasCount && (grad.IsDense || editor.Indices[BiasCount - 1] == BiasCount - 1));
editor.Values[c] += mult;
}
Contracts.Check(FloatUtils.IsFinite(datumLoss), "Data contain bad values.");
return weight * datumLoss;
}
private protected override VBuffer<float> InitializeWeightsFromPredictor(IPredictor srcPredictor)
{
var pred = srcPredictor as MaximumEntropyModelParameters;
Contracts.AssertValue(pred);
Contracts.Assert(pred.InputType.GetVectorSize() > 0);
// REVIEW: Support initializing the weights of a superset of features.
if (pred.InputType.GetVectorSize() != NumFeatures)
throw Contracts.Except("The input training data must have the same features used to train the input predictor.");
return InitializeWeights(pred.DenseWeightsEnumerable(), pred.GetBiases());
}
private protected override MaximumEntropyModelParameters CreatePredictor()
{
if (_numClasses < 1)
throw Contracts.Except("Cannot create a multiclass predictor with {0} classes", _numClasses);
if (_numClasses == 1)
{
using (var ch = Host.Start("Creating Predictor"))
{
ch.Warning("Training resulted in a one class predictor");
}
}
return new MaximumEntropyModelParameters(Host, in CurrentWeights, _numClasses, NumFeatures, _labelNames, _stats);
}
private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory cursorFactory, float loss, int numParams)
{
Contracts.AssertValue(ch);
Contracts.AssertValue(cursorFactory);
Contracts.Assert(NumGoodRows > 0);
Contracts.Assert(WeightSum > 0);
Contracts.Assert(BiasCount == _numClasses);
Contracts.Assert(loss >= 0);
Contracts.Assert(numParams >= BiasCount);
Contracts.Assert(CurrentWeights.IsDense);
ch.Info("Model trained with {0} training examples.", NumGoodRows);
// Compute deviance: start with loss function.
float deviance = (float)(2 * loss * WeightSum);
if (L2Weight > 0)
{
// Need to subtract L2 regularization loss.
// The bias term is not regularized.
var regLoss = VectorUtils.NormSquared(CurrentWeights.GetValues().Slice(BiasCount)) * L2Weight;
deviance -= regLoss;
}
if (L1Weight > 0)
{
// Need to subtract L1 regularization loss.
// The bias term is not regularized.
Double regLoss = 0;
VBufferUtils.ForEachDefined(in CurrentWeights, (ind, value) => { if (ind >= BiasCount) regLoss += Math.Abs(value); });
deviance -= (float)regLoss * L1Weight * 2;
}
ch.Info("Residual Deviance: \t{0}", deviance);
// Compute null deviance, i.e., the deviance of the null hypothesis.
// Classes with zero prior are skipped to avoid taking the log of zero.
float nullDeviance = 0;
for (int iLabel = 0; iLabel < _numClasses; iLabel++)
{
Contracts.Assert(_prior[iLabel] >= 0);
if (_prior[iLabel] == 0)
continue;
nullDeviance -= (float)(2 * _prior[iLabel] * Math.Log(_prior[iLabel] / WeightSum));
}
ch.Info("Null Deviance: \t{0}", nullDeviance);
// Compute AIC.
ch.Info("AIC: \t{0}", 2 * numParams + deviance);
// REVIEW: Figure out how to compute the statistics for the coefficients.
_stats = new ModelStatisticsBase(Host, NumGoodRows, numParams, deviance, nullDeviance);
}
private protected override void ProcessPriorDistribution(float label, float weight)
{
int iLabel = (int)label;
Contracts.Assert(0 <= iLabel && iLabel < _numClasses);
_prior[iLabel] += weight;
}
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
var metadata = new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues)
.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));
return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol))),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, metadata)
};
}
private protected override MulticlassPredictionTransformer<MaximumEntropyModelParameters> MakeTransformer(MaximumEntropyModelParameters model, DataViewSchema trainSchema)
=> new MulticlassPredictionTransformer<MaximumEntropyModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
/// <summary>
/// Continues the training of a <see cref="LbfgsMaximumEntropyMulticlassTrainer"/> using an already trained <paramref name="modelParameters"/> and returns
/// a <see cref="MulticlassPredictionTransformer{MulticlassLogisticRegressionModelParameters}"/>.
/// </summary>
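/// <example>
/// A minimal sketch of warm-starting (assuming an `MLContext` named `mlContext`, a previously fitted
/// <see cref="MulticlassPredictionTransformer{TModel}"/> named `oldModel`, and an <see cref="IDataView"/> named `newData`):
/// <code>
/// var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy();
/// var retrained = trainer.Fit(newData, oldModel.Model);
/// </code>
/// </example>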
public MulticlassPredictionTransformer<MaximumEntropyModelParameters> Fit(IDataView trainData, MaximumEntropyModelParameters modelParameters)
=> TrainTransformer(trainData, initPredictor: modelParameters);
[TlcModule.EntryPoint(Name = "Trainers.LogisticRegressionClassifier",
Desc = LbfgsMaximumEntropyMulticlassTrainer.Summary,
UserName = LbfgsMaximumEntropyMulticlassTrainer.UserNameValue,
ShortName = LbfgsMaximumEntropyMulticlassTrainer.ShortName)]
internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, LbfgsMaximumEntropyMulticlassTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainLRMultiClass");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
return TrainerEntryPointsUtils.Train<LbfgsMaximumEntropyMulticlassTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
() => new LbfgsMaximumEntropyMulticlassTrainer(host, input),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
}
}
/// <summary>
/// Common linear model of multiclass classifiers. <see cref="LinearMulticlassModelParameters"/> contains a single
/// linear model per class.
/// </summary>
public abstract class LinearMulticlassModelParametersBase :
ModelParametersBase<VBuffer<float>>,
IValueMapper,
ICanSaveInTextFormat,
ICanSaveInSourceCode,
ICanSaveSummary,
ICanGetSummaryInKeyValuePairs,
ICanGetSummaryAsIDataView,
ICanGetSummaryAsIRow,
ISingleCanSavePfa,
ISingleCanSaveOnnx
{
private const string ModelStatsSubModelFilename = "ModelStats";
private const string LabelNamesSubModelFilename = "LabelNames";
private protected readonly int NumberOfClasses;
private protected readonly int NumberOfFeatures;
// The label names used to write the model summary. Either null or of length NumberOfClasses.
private readonly string[] _labelNames;
private protected readonly float[] Biases;
private protected readonly VBuffer<float>[] Weights;
public readonly ModelStatisticsBase Statistics;
// This stores the Weights matrix in dense format for performance.
// It is used to make efficient predictions when the instance is sparse, so we get
// dense-sparse dot products and avoid the sparse-sparse case.
// When the Weights matrix is dense to begin with, Weights == _weightsDense at all times after construction.
// When Weights is sparse, this remains null until we see the first sparse instance,
// at which point it is initialized.
private volatile VBuffer<float>[] _weightsDense;
private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;
internal readonly DataViewType InputType;
internal readonly DataViewType OutputType;
DataViewType IValueMapper.InputType => InputType;
DataViewType IValueMapper.OutputType => OutputType;
bool ICanSavePfa.CanSavePfa => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
internal LinearMulticlassModelParametersBase(IHostEnvironment env, string name, in VBuffer<float> weights, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null)
: base(env, name)
{
Contracts.Assert(weights.Length == numClasses + numClasses * numFeatures);
NumberOfClasses = numClasses;
NumberOfFeatures = numFeatures;
// weights contains both biases and feature weights in a flat vector:
// biases are stored in the first numClasses elements,
// followed by one weight vector for each class, in turn, all concatenated
// (i.e., in "row major" order, if we encode each weight vector as a row of a matrix).
Contracts.Assert(weights.Length == NumberOfClasses + NumberOfClasses * NumberOfFeatures);
Biases = new float[NumberOfClasses];
for (int i = 0; i < Biases.Length; i++)
weights.GetItemOrDefault(i, ref Biases[i]);
Weights = new VBuffer<float>[NumberOfClasses];
for (int i = 0; i < Weights.Length; i++)
weights.CopyTo(ref Weights[i], NumberOfClasses + i * NumberOfFeatures, NumberOfFeatures);
if (Weights.All(v => v.IsDense))
_weightsDense = Weights;
InputType = new VectorDataViewType(NumberDataViewType.Single, NumberOfFeatures);
OutputType = new VectorDataViewType(NumberDataViewType.Single, NumberOfClasses);
Contracts.Assert(labelNames == null || labelNames.Length == numClasses);
_labelNames = labelNames;
Contracts.AssertValueOrNull(stats);
Statistics = stats;
}
/// <summary>
/// Initializes a new instance of the <see cref="MaximumEntropyModelParameters"/> class.
/// This constructor is called by <see cref="SdcaMaximumEntropyMulticlassTrainer"/> to create the predictor.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="name">Registration name of this model's actual type.</param>
/// <param name="weights">The array of weights vectors. It should contain <paramref name="numClasses"/> weights.</param>
/// <param name="bias">The array of biases. It should contain contain <paramref name="numClasses"/> weights.</param>
/// <param name="numClasses">The number of classes for multi-class classification. Must be at least 2.</param>
/// <param name="numFeatures">The length of the feature vector.</param>
/// <param name="labelNames">The optional label names. If specified not null, it should have the same length as <paramref name="numClasses"/>.</param>
/// <param name="stats">The model statistics.</param>
internal LinearMulticlassModelParametersBase(IHostEnvironment env, string name, VBuffer<float>[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null)
: base(env, name)
{
Contracts.CheckValue(weights, nameof(weights));
Contracts.CheckValue(bias, nameof(bias));
Contracts.CheckParam(numClasses >= 2, nameof(numClasses), "Must be at least 2.");
NumberOfClasses = numClasses;
Contracts.CheckParam(numFeatures >= 1, nameof(numFeatures), "Must be positive.");
NumberOfFeatures = numFeatures;
Contracts.Check(Utils.Size(weights) == NumberOfClasses);
Contracts.Check(Utils.Size(bias) == NumberOfClasses);
Weights = new VBuffer<float>[NumberOfClasses];
Biases = new float[NumberOfClasses];
for (int iClass = 0; iClass < NumberOfClasses; iClass++)
{
Contracts.Assert(weights[iClass].Length == NumberOfFeatures);
weights[iClass].CopyTo(ref Weights[iClass]);
Biases[iClass] = bias[iClass];
}
if (Weights.All(v => v.IsDense))
_weightsDense = Weights;
InputType = new VectorDataViewType(NumberDataViewType.Single, NumberOfFeatures);
OutputType = new VectorDataViewType(NumberDataViewType.Single, NumberOfClasses);
Contracts.Assert(labelNames == null || labelNames.Length == numClasses);
_labelNames = labelNames;
Contracts.AssertValueOrNull(stats);
Statistics = stats;
}
private protected LinearMulticlassModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
: base(env, name, ctx)
{
// *** Binary format ***
// int: number of features
// int: number of classes = number of biases
// float[]: biases
// (weight matrix, in CSR if sparse)
// (see https://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000)
// int: number of row start indices (_numClasses + 1 if sparse, 0 if dense)
// int[]: row start indices
// int: total number of column indices (0 if dense)
// int[]: column index of each non-zero weight
// int: total number of non-zero weights (same as number of column indices if sparse, num of classes * num of features if dense)
// float[]: non-zero weights
// int[]: Id of label names (optional, in a separate stream)
// ModelStatisticsBase: model statistics (optional, in a separate stream)
NumberOfFeatures = ctx.Reader.ReadInt32();
Host.CheckDecode(NumberOfFeatures >= 1);
NumberOfClasses = ctx.Reader.ReadInt32();
Host.CheckDecode(NumberOfClasses >= 1);
Biases = ctx.Reader.ReadFloatArray(NumberOfClasses);
int numStarts = ctx.Reader.ReadInt32();
if (numStarts == 0)
{
// The weights are entirely dense.
int numIndices = ctx.Reader.ReadInt32();
Host.CheckDecode(numIndices == 0);
int numWeights = ctx.Reader.ReadInt32();
Host.CheckDecode(numWeights == NumberOfClasses * NumberOfFeatures);
Weights = new VBuffer<float>[NumberOfClasses];
for (int i = 0; i < Weights.Length; i++)
{
var w = ctx.Reader.ReadFloatArray(NumberOfFeatures);
Weights[i] = new VBuffer<float>(NumberOfFeatures, w);
}
_weightsDense = Weights;
}
else
{
// Read weight matrix as CSR.
Host.CheckDecode(numStarts == NumberOfClasses + 1);
int[] starts = ctx.Reader.ReadIntArray(numStarts);
Host.CheckDecode(starts[0] == 0);
Host.CheckDecode(Utils.IsMonotonicallyIncreasing(starts));
int numIndices = ctx.Reader.ReadInt32();
Host.CheckDecode(numIndices == starts[starts.Length - 1]);
var indices = new int[NumberOfClasses][];
for (int i = 0; i < indices.Length; i++)
{
indices[i] = ctx.Reader.ReadIntArray(starts[i + 1] - starts[i]);
Host.CheckDecode(Utils.IsIncreasing(0, indices[i], NumberOfFeatures));
}
int numValues = ctx.Reader.ReadInt32();
Host.CheckDecode(numValues == numIndices);
Weights = new VBuffer<float>[NumberOfClasses];
for (int i = 0; i < Weights.Length; i++)
{
float[] values = ctx.Reader.ReadFloatArray(starts[i + 1] - starts[i]);
Weights[i] = new VBuffer<float>(NumberOfFeatures, Utils.Size(values), values, indices[i]);
}
}
WarnOnOldNormalizer(ctx, GetType(), Host);
InputType = new VectorDataViewType(NumberDataViewType.Single, NumberOfFeatures);
OutputType = new VectorDataViewType(NumberDataViewType.Single, NumberOfClasses);
// REVIEW: Should not save the label names redundantly with the predictor.
// Get them from the label column schema metadata instead.
string[] labelNames = null;
if (ctx.TryLoadBinaryStream(LabelNamesSubModelFilename, r => labelNames = LoadLabelNames(ctx, r)))
_labelNames = labelNames;
// Backwards compatibility: MLR used to serialize a LinearModelStatistics object, before there
// existed two separate classes for ModelStatisticsBase and LinearModelParameterStatistics.
// It only ever populated the fields now found on ModelStatisticsBase.
ModelStatisticsBase stats;
ctx.LoadModelOrNull<ModelStatisticsBase, SignatureLoadModel>(Host, out stats, ModelStatsSubModelFilename);
Statistics = stats;
}
private protected abstract VersionInfo GetVersionInfo();
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Host.Assert(Biases.Length == NumberOfClasses);
Host.Assert(Biases.Length == Weights.Length);
#if DEBUG
foreach (var fw in Weights)
Host.Assert(fw.Length == NumberOfFeatures);
#endif
// *** Binary format ***
// int: number of features
// int: number of classes = number of biases
// float[]: biases
// (weight matrix, in CSR if sparse)
// (see https://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000)
// int: number of row start indices (_numClasses + 1 if sparse, 0 if dense)
// int[]: row start indices
// int: total number of column indices (0 if dense)
// int[]: column index of each non-zero weight
// int: total number of non-zero weights (same as number of column indices if sparse, num of classes * num of features if dense)
// float[]: non-zero weights
// bool: whether label names are present
// int[]: Id of label names (optional, in a separate stream)
// LinearModelParameterStatistics: model statistics (optional, in a separate stream)
ctx.Writer.Write(NumberOfFeatures);
ctx.Writer.Write(NumberOfClasses);
ctx.Writer.WriteSinglesNoCount(Biases.AsSpan(0, NumberOfClasses));
// Weights == _weightsDense means we checked that all vectors in Weights
// are actually dense, and so we assigned the same object, or the weights came
// dense from deserialization.
if (Weights == _weightsDense)
{
ctx.Writer.Write(0); // Number of starts.
ctx.Writer.Write(0); // Number of indices.
ctx.Writer.Write(NumberOfFeatures * Weights.Length);
foreach (var fv in Weights)
{
Host.Assert(fv.Length == NumberOfFeatures);
ctx.Writer.WriteSinglesNoCount(fv.GetValues());
}
}
else
{
// Number of starts.
ctx.Writer.Write(NumberOfClasses + 1);
// The starts array always begins with 0; write that first entry here.
int numIndices = 0;
ctx.Writer.Write(numIndices);
for (int i = 0; i < Weights.Length; i++)
{
// REVIEW: Assuming the presence of *any* zero justifies
// writing in sparse format seems stupid, but might be difficult
// to change without changing the format since the presence of
// any sparse vector means we're writing indices anyway. Revisit.
// This is actually a bug waiting to happen: sparse/dense vectors
// can have different dot products even if they are logically the
// same vector.
numIndices += NonZeroCount(in Weights[i]);
ctx.Writer.Write(numIndices);
}
ctx.Writer.Write(numIndices);
{
// Just scoping the count so we can declare another one further down.
int count = 0;
foreach (var fw in Weights)
{
var fwValues = fw.GetValues();
if (fw.IsDense)
{
for (int i = 0; i < fwValues.Length; i++)
{
if (fwValues[i] != 0)
{
ctx.Writer.Write(i);
count++;
}
}
}
else
{
var fwIndices = fw.GetIndices();
ctx.Writer.WriteIntsNoCount(fwIndices);
count += fwIndices.Length;
}
}
Host.Assert(count == numIndices);
}
ctx.Writer.Write(numIndices);
{
int count = 0;
foreach (var fw in Weights)
{
var fwValues = fw.GetValues();
if (fw.IsDense)
{
for (int i = 0; i < fwValues.Length; i++)
{
if (fwValues[i] != 0)
{
ctx.Writer.Write(fwValues[i]);
count++;
}
}
}
else
{
ctx.Writer.WriteSinglesNoCount(fwValues);
count += fwValues.Length;
}
}
Host.Assert(count == numIndices);
}
}
Contracts.AssertValueOrNull(_labelNames);
if (_labelNames != null)
ctx.SaveBinaryStream(LabelNamesSubModelFilename, w => SaveLabelNames(ctx, w));
Contracts.AssertValueOrNull(Statistics);
if (Statistics != null)
ctx.SaveModel(Statistics, ModelStatsSubModelFilename);
}
// REVIEW: Destroy.
private static int NonZeroCount(in VBuffer<float> vector)
{
int count = 0;
var values = vector.GetValues();
for (int i = 0; i < values.Length; i++)
{
if (values[i] != 0)
count++;
}
return count;
}
ValueMapper<TSrc, TDst> IValueMapper.GetMapper<TSrc, TDst>()
{
Host.Check(typeof(TSrc) == typeof(VBuffer<float>), "Invalid source type in GetMapper");
Host.Check(typeof(TDst) == typeof(VBuffer<float>), "Invalid destination type in GetMapper");
ValueMapper<VBuffer<float>, VBuffer<float>> del =
(in VBuffer<float> src, ref VBuffer<float> dst) =>
{
Host.Check(src.Length == NumberOfFeatures);
PredictCore(in src, ref dst);
};
return (ValueMapper<TSrc, TDst>)(Delegate)del;
}
private void PredictCore(in VBuffer<float> src, ref VBuffer<float> dst)
{
Host.Check(src.Length == NumberOfFeatures, "src length should equal the number of features");
var weights = Weights;
if (!src.IsDense)
weights = DensifyWeights();
var editor = VBufferEditor.Create(ref dst, NumberOfClasses);
for (int i = 0; i < Biases.Length; i++)
editor.Values[i] = Biases[i] + VectorUtils.DotProduct(in weights[i], in src);
Calibrate(editor.Values);
dst = editor.Commit();
}
private VBuffer<float>[] DensifyWeights()
{
if (_weightsDense == null)
{
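// Double-checked locking: re-test _weightsDense under the lock so only one thread densifies the weights.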
lock (Weights)
{
if (_weightsDense == null)
{
var weightsDense = new VBuffer<float>[NumberOfClasses];
for (int i = 0; i < Weights.Length; i++)
{
// Haven't yet created dense version of the weights.
// REVIEW: Should we always expand to full weights or should this be subject to an option?
var w = Weights[i];
if (w.IsDense)
weightsDense[i] = w;
else
w.CopyToDense(ref weightsDense[i]);
}
_weightsDense = weightsDense;
}
}
Host.AssertValue(_weightsDense);
}
return _weightsDense;
}
/// <summary>
/// Post-processing function applied to the raw score of each class's linear model.
/// In <see cref="PredictCore(in VBuffer{float}, ref VBuffer{float})"/> we compute the i-th class's score
/// as the inner product of the i-th coefficient vector <see cref="Weights"/>[i] and the input feature vector, plus the bias.
/// Then, <see cref="Calibrate(Span{float})"/> is called to adjust those raw scores.
/// </summary>
private protected abstract void Calibrate(Span<float> dst);
IList<KeyValuePair<string, object>> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema)
{
Host.CheckValueOrNull(schema);
List<KeyValuePair<string, object>> results = new List<KeyValuePair<string, object>>();
var names = default(VBuffer<ReadOnlyMemory<char>>);
AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumberOfFeatures, ref names);
for (int classNumber = 0; classNumber < Biases.Length; classNumber++)
{
results.Add(new KeyValuePair<string, object>(
string.Format("{0}+(Bias)", GetLabelName(classNumber)),
Biases[classNumber]
));
}
for (int classNumber = 0; classNumber < Weights.Length; classNumber++)
{
var orderedWeights = Weights[classNumber].Items().OrderByDescending(kv => Math.Abs(kv.Value));
foreach (var weight in orderedWeights)
{
var value = weight.Value;
if (value == 0)
break;
int index = weight.Key;
var name = names.GetItemOrDefault(index);
results.Add(new KeyValuePair<string, object>(
string.Format("{0}+{1}", GetLabelName(classNumber), name.IsEmpty ? $"f{index}" : name.ToString()),
value
));
}
}
return results;
}
/// <summary>
/// Actual implementation of <see cref="ICanSaveInTextFormat.SaveAsText(TextWriter, RoleMappedSchema)"/> should happen in derived classes.
/// </summary>
private void SaveAsTextCore(TextWriter writer, RoleMappedSchema schema)
{
writer.WriteLine(GetTrainerName() + " bias and non-zero weights");
foreach (var namedValues in ((ICanGetSummaryInKeyValuePairs)this).GetSummaryInKeyValuePairs(schema))
{
Host.Assert(namedValues.Value is float);
writer.WriteLine("\t{0}\t{1}", namedValues.Key, (float)namedValues.Value);
}
if (Statistics != null)
Statistics.SaveText(writer, schema.Feature.Value, 20);
}
private protected abstract string GetTrainerName();
/// <summary>
/// Redirect <see cref="ICanSaveInTextFormat.SaveAsText(TextWriter, RoleMappedSchema)"/> call to the right function.
/// </summary>
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) => SaveAsTextCore(writer, schema);
/// <summary>
/// Summary is equivalent to its information in text format.
/// </summary>
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
{
((ICanSaveInTextFormat)this).SaveAsText(writer, schema);
}
/// <summary>
/// Actual implementation of <see cref="ICanSaveInSourceCode.SaveAsCode(TextWriter, RoleMappedSchema)"/> should happen in derived classes.
/// </summary>
private void SaveAsCodeCore(TextWriter writer, RoleMappedSchema schema)
{
Host.CheckValue(writer, nameof(writer));
Host.CheckValueOrNull(schema);
writer.WriteLine(string.Format("var scores = new float[{0}];", NumberOfClasses));
for (int i = 0; i < Biases.Length; i++)
{
LinearPredictorUtils.SaveAsCode(writer,
in Weights[i],
Biases[i],
schema,
"scores[" + i.ToString() + "]");
}
}
/// <summary>
/// The raw scores of all linear classifiers are stored in <see langword="float"/>[] <paramref name="scoresName"/>.
/// Derived classes can use this function to add C# code for post-transformation.
/// </summary>
private protected abstract void SavePostTransformAsCode(TextWriter writer, string scoresName);
/// <summary>
/// Redirect <see cref="ICanSaveInSourceCode.SaveAsCode(TextWriter, RoleMappedSchema)"/> call to the right function.
/// </summary>
void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) => SaveAsCodeCore(writer, schema);
/// <summary>
/// Actual implementation of <see cref="ISingleCanSavePfa.SaveAsPfa(BoundPfaContext, JToken)"/> should happen in derived classes.
/// </summary>
private JToken SaveAsPfaCore(BoundPfaContext ctx, JToken input)
{
Host.CheckValue(ctx, nameof(ctx));
Host.CheckValue(input, nameof(input));
const string typeName = "MCLinearPredictor";
JToken typeDecl = typeName;
if (ctx.Pfa.RegisterType(typeName))
{
JObject type = new JObject();
type["type"] = "record";
type["name"] = typeName;
JArray fields = new JArray();
JObject jobj = null;
fields.Add(jobj.AddReturn("name", "coeff").AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Array(PfaUtils.Type.Double))));
fields.Add(jobj.AddReturn("name", "const").AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)));
type["fields"] = fields;
typeDecl = type;
}
JObject predictor = new JObject();
predictor["coeff"] = new JArray(Weights.Select(w => new JArray(w.DenseValues())));
predictor["const"] = new JArray(Biases);
var cell = ctx.DeclareCell("MCLinearPredictor", typeDecl, predictor);
var cellRef = PfaUtils.Cell(cell);
return ApplyPfaPostTransform(PfaUtils.Call("model.reg.linear", input, cellRef));
}
/// <summary>
/// This is called at the end of <see cref="SaveAsPfaCore(BoundPfaContext, JToken)"/> to adjust the final outputs of all linear models.
/// </summary>
private protected abstract JToken ApplyPfaPostTransform(JToken input);
/// <summary>
/// Redirect <see cref="ISingleCanSavePfa.SaveAsPfa(BoundPfaContext, JToken)"/> call to the right function.
/// </summary>
JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) => SaveAsPfaCore(ctx, input);
/// <summary>
/// Actual implementation of <see cref="ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext, string[], string)"/> should happen in derived classes.
/// It's ok to make <see cref="SaveAsOnnxCore(OnnxContext, string[], string)"/> a <see langword="private protected"/> method in the future
/// if any derived class wants to override.
/// </summary>
private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureColumn)
{
Host.CheckValue(ctx, nameof(ctx));
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "MultiClassLogisticRegression");
Host.Assert(outputs[0] == DefaultColumnNames.PredictedLabel);
Host.Assert(outputs[1] == DefaultColumnNames.Score);
string classifierLabelOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ClassifierLabelOutput", true);
string opType = "LinearClassifier";
var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { classifierLabelOutput, outputs[1] }, ctx.GetNodeName(opType));
node.AddAttribute("post_transform", GetOnnxPostTransform());
node.AddAttribute("multi_class", true);
node.AddAttribute("coefficients", Weights.SelectMany(w => w.DenseValues()));
node.AddAttribute("intercepts", Biases);
node.AddAttribute("classlabels_ints", Enumerable.Range(1, NumberOfClasses).Select(x => (long)x));
opType = "Unsqueeze";
var unsqueezeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastNodeOutput");
var unsqueezeNode = ctx.CreateNode(opType, classifierLabelOutput, unsqueezeOutput, ctx.GetNodeName(opType), "");
unsqueezeNode.AddAttribute("axes", new long[] { 1 });
// Onnx outputs an Int64, but ML.NET outputs UInt32. So cast the Onnx output here
opType = "Cast";
var castNode = ctx.CreateNode(opType, unsqueezeOutput, outputs[0], ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
castNode.AddAttribute("to", t);
return true;
}
/// <summary>
/// Post-transform applied to the raw scores produced by those linear models of all classes. For maximum entropy classification, it should be
/// a softmax function. This function is used only in <see cref="SaveAsOnnxCore(OnnxContext, string[], string)"/>.
/// </summary>
private protected abstract string GetOnnxPostTransform();
/// <summary>
/// Redirect <see cref="ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext, string[], string)"/> call to the right function.
/// </summary>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) => SaveAsOnnxCore(ctx, outputs, featureColumn);
/// <summary>
/// Copies the weight vector for each class into a set of buffers.
/// </summary>
/// <param name="weights">A possibly reusable set of vectors, which will
/// be expanded as necessary to accommodate the data.</param>
/// <param name="numClasses">Set to the rank, which is also the logical length
/// of <paramref name="weights"/>.</param>
public void GetWeights(ref VBuffer<float>[] weights, out int numClasses)
{
numClasses = NumberOfClasses;
Utils.EnsureSize(ref weights, NumberOfClasses, NumberOfClasses);
for (int i = 0; i < NumberOfClasses; i++)
Weights[i].CopyTo(ref weights[i]);
}
/// <summary>
/// Gets the biases for the logistic regression predictor.
/// </summary>
public IEnumerable<float> GetBiases()
{
return Biases;
}
internal IEnumerable<float> DenseWeightsEnumerable()
{
Contracts.Assert(Weights.Length == Biases.Length);
int featuresCount = Weights[0].Length;
for (var i = 0; i < Weights.Length; i++)
{
Host.Assert(featuresCount == Weights[i].Length);
foreach (var weight in Weights[i].Items(all: true))
yield return weight.Value;
}
}
internal string GetLabelName(int classNumber)
{
const string classNumberFormat = "Class_{0}";
Contracts.Assert(0 <= classNumber && classNumber < NumberOfClasses);
return _labelNames == null ? string.Format(classNumberFormat, classNumber) : _labelNames[classNumber];
}
private string[] LoadLabelNames(ModelLoadContext ctx, BinaryReader reader)
{
Contracts.AssertValue(ctx);
Contracts.AssertValue(reader);
string[] labelNames = new string[NumberOfClasses];
for (int i = 0; i < NumberOfClasses; i++)
{
int id = reader.ReadInt32();
Host.CheckDecode(0 <= id && id < Utils.Size(ctx.Strings));
var str = ctx.Strings[id];
Host.CheckDecode(str.Length > 0);
labelNames[i] = str;
}
return labelNames;
}
private void SaveLabelNames(ModelSaveContext ctx, BinaryWriter writer)
{
Contracts.AssertValue(ctx);
Contracts.AssertValue(writer);
Contracts.Assert(Utils.Size(_labelNames) == NumberOfClasses);
for (int i = 0; i < NumberOfClasses; i++)
{
Host.AssertValue(_labelNames[i]);
writer.Write(ctx.Strings.Add(_labelNames[i]).Id);
}
}
IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema)
{
var bldr = new ArrayDataViewBuilder(Host);
ValueGetter<VBuffer<ReadOnlyMemory<char>>> getSlotNames =
(ref VBuffer<ReadOnlyMemory<char>> dst) =>
AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumberOfFeatures, ref dst);
// Add the bias and the weight columns.
bldr.AddColumn("Bias", NumberDataViewType.Single, Biases);
bldr.AddColumn("Weights", getSlotNames, NumberDataViewType.Single, Weights);
bldr.AddColumn("ClassNames", Enumerable.Range(0, NumberOfClasses).Select(i => GetLabelName(i)).ToArray());
return bldr.GetDataView();
}
DataViewRow ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema)
{
return null;
}
DataViewRow ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema)
{
if (Statistics == null)
return null;
var names = default(VBuffer<ReadOnlyMemory<char>>);
AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weights.Length, ref names);
var meta = Statistics.MakeStatisticsMetadata(schema, in names);
return AnnotationUtils.AnnotationsAsRow(meta);
}
}
/// <summary>
/// Linear model of multiclass classifiers. It outputs the raw scores of all its linear models; no probabilistic output is provided.
/// </summary>
public sealed class LinearMulticlassModelParameters : LinearMulticlassModelParametersBase
{
internal const string LoaderSignature = "MulticlassLinear";
internal const string RegistrationName = "MulticlassLinearPredictor";
private static VersionInfo VersionInfo =>
new VersionInfo(
modelSignature: "MCLINEAR",
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(LinearMulticlassModelParameters).Assembly.FullName);
/// <summary>
/// Function used to pass <see cref="VersionInfo"/> into parent class. It may be used when saving the model.
/// </summary>
private protected override VersionInfo GetVersionInfo() => VersionInfo;
internal LinearMulticlassModelParameters(IHostEnvironment env, in VBuffer<float> weights, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null)
: base(env, RegistrationName, weights, numClasses, numFeatures, labelNames, stats)
{
}
internal LinearMulticlassModelParameters(IHostEnvironment env, VBuffer<float>[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null)
: base(env, RegistrationName, weights, bias, numClasses, numFeatures, labelNames, stats)
{
}
private LinearMulticlassModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
}
/// <summary>
/// This function does not do any calibration. This is common in multiclass support vector machines, where probabilistic outputs are not provided.
/// </summary>
/// <param name="dst">Score vector should be calibrated.</param>
private protected override void Calibrate(Span<float> dst)
{
}
internal static LinearMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(VersionInfo);
return new LinearMulticlassModelParameters(env, ctx);
}
private protected override void SavePostTransformAsCode(TextWriter writer, string scoresName) { }
/// <summary>
/// No post-transform is needed for a non-calibrated classifier.
/// </summary>
private protected override string GetOnnxPostTransform() => "NONE";
/// <summary>
/// No post-transform is needed for a non-calibrated classifier.
/// </summary>
private protected override JToken ApplyPfaPostTransform(JToken input) => input;
private protected override string GetTrainerName() => nameof(LinearMulticlassModelParameters);
}
/// <summary>
/// Linear maximum entropy model of multiclass classifiers. It outputs class probabilities.
/// This model is also known as multinomial logistic regression.
/// Please see https://en.wikipedia.org/wiki/Multinomial_logistic_regression for details.
/// </summary>
public sealed class MaximumEntropyModelParameters : LinearMulticlassModelParametersBase
{
internal const string LoaderSignature = "MultiClassLRExec";
internal const string RegistrationName = "MulticlassLogisticRegressionPredictor";
private static VersionInfo VersionInfo =>
new VersionInfo(
modelSignature: "MULTI LR",
// verWrittenCur: 0x00010001, // Initial
// verWrittenCur: 0x00010002, // Added class names
verWrittenCur: 0x00010003, // Added model stats
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(MaximumEntropyModelParameters).Assembly.FullName);
/// <summary>
/// Function used to pass <see cref="VersionInfo"/> into parent class. It may be used when saving the model.
/// </summary>
private protected override VersionInfo GetVersionInfo() => VersionInfo;
internal MaximumEntropyModelParameters(IHostEnvironment env, in VBuffer<float> weights, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null)
: base(env, RegistrationName, weights, numClasses, numFeatures, labelNames, stats)
{
}
internal MaximumEntropyModelParameters(IHostEnvironment env, VBuffer<float>[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null)
: base(env, RegistrationName, weights, bias, numClasses, numFeatures, labelNames, stats)
{
}
private MaximumEntropyModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
}
/// <summary>
/// This function applies softmax to <paramref name="dst"/>. For details about softmax, see https://en.wikipedia.org/wiki/Softmax_function.
/// </summary>
/// <param name="dst">Score vector should be calibrated.</param>
private protected override void Calibrate(Span<float> dst)
{
Host.Assert(dst.Length == NumberOfClasses);
// scores are in log-space; convert and fix underflow/overflow
// TODO: re-normalize probabilities to account for underflow/overflow?
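// Note that MathUtils.SoftMax actually returns the log of the sum of exponentials (the
// log-partition term), so exp(dst[i] - softmax) below is the normalized class probability.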
float softmax = MathUtils.SoftMax(dst.Slice(0, NumberOfClasses));
for (int i = 0; i < NumberOfClasses; ++i)
dst[i] = MathUtils.ExpSlow(dst[i] - softmax);
}
/// <summary>
/// Apply softmax function to <paramref name="scoresName"/>, which contains raw scores from all linear models.
/// </summary>
private protected override void SavePostTransformAsCode(TextWriter writer, string scoresName)
{
writer.WriteLine(string.Format("var softmax = MathUtils.SoftMax({0}.AsSpan(0, {1}));", scoresName, NumberOfClasses));
for (int c = 0; c < Biases.Length; c++)
writer.WriteLine("{1}[{0}] = Math.Exp({1}[{0}] - softmax);", c, scoresName);
}
/// <summary>
/// Apply softmax to the raw scores produced by the linear models of all classes.
/// </summary>
private protected override string GetOnnxPostTransform() => "SOFTMAX";
/// <summary>
/// Apply softmax to the raw scores produced by the linear models of all classes.
/// </summary>
private protected override JToken ApplyPfaPostTransform(JToken input) => PfaUtils.Call("m.link.softmax", input);
private protected override string GetTrainerName() => nameof(LbfgsMaximumEntropyMulticlassTrainer);
}
}