// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Linq;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
[assembly: LoadableClass(SdcaMaximumEntropyMulticlassTrainer.Summary, typeof(SdcaMaximumEntropyMulticlassTrainer), typeof(SdcaMaximumEntropyMulticlassTrainer.Options),
new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SdcaMaximumEntropyMulticlassTrainer.UserNameValue,
SdcaMaximumEntropyMulticlassTrainer.LoadNameValue,
SdcaMaximumEntropyMulticlassTrainer.ShortName)]
namespace Microsoft.ML.Trainers
{
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> to predict a target using a linear multiclass classifier model trained with a coordinate descent method.
/// Depending on the loss function used, the trained model can be, for example, a maximum entropy classifier or a multiclass support vector machine.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer for maximum entropy classifier, use [SdcaMaximumEntropy](xref:Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,System.Nullable{System.Single},System.Nullable{System.Single},System.Nullable{System.Int32})) or
/// [SdcaMaximumEntropy(Options)](xref:Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,Microsoft.ML.Trainers.SdcaMaximumEntropyMulticlassTrainer.Options)).
/// To create this trainer for a [loss function](xref:Microsoft.ML.Trainers.ISupportSdcaClassificationLoss) (such as support vector machine's [hinge loss](xref:Microsoft.ML.Trainers.HingeLoss)) of your choice,
/// use [SdcaNonCalibrated](xref:Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Microsoft.ML.Trainers.ISupportSdcaClassificationLoss,System.Nullable{System.Single},System.Nullable{System.Single},System.Nullable{System.Int32})) or
/// [SdcaNonCalibrated(Options)](xref:Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,Microsoft.ML.Trainers.SdcaNonCalibratedMulticlassTrainer.Options)).
///
/// [!include[io](~/../docs/samples/docs/api-reference/io-columns-multiclass-classification.md)]
///
/// ### Trainer Characteristics
/// | | |
/// | -- | -- |
/// | Machine learning task | Multiclass classification |
/// | Is normalization required? | Yes |
/// | Is caching required? | No |
/// | Required NuGet in addition to Microsoft.ML | None |
/// | Exportable to ONNX | Yes |
///
/// ### Scoring Function
/// This trains a linear model to solve multiclass classification problems.
/// Assume that the number of classes is $m$ and the number of features is $n$.
/// It assigns the $c$-th class a coefficient vector $\textbf{w}_c \in {\mathbb R}^n$ and a bias $b_c \in {\mathbb R}$, for $c=1,\dots,m$.
/// Given a feature vector $\textbf{x} \in {\mathbb R}^n$, the $c$-th class's score would be $\hat{y}^c = \textbf{w}_c^T \textbf{x} + b_c$.
/// If $\textbf{x}$ belongs to class $c$, then $\hat{y}^c$ should be much larger than 0.
/// In contrast, a $\hat{y}^c$ much smaller than 0 means the desired label should not be $c$.
///
/// If and only if the trained model is a maximum entropy classifier, you can interpret the output score vector as the predicted class probabilities, because the [softmax function](https://en.wikipedia.org/wiki/Softmax_function) is applied to post-process all classes' scores.
/// More specifically, the probability of $\textbf{x}$ belonging to class $c$ is computed by $\tilde{P}( c | \textbf{x} ) = \frac{ e^{\hat{y}^c} }{ \sum_{c' = 1}^m e^{\hat{y}^{c'}} }$ and stored as the $c$-th element in the score vector.
/// In other cases, the output score vector is just $[\hat{y}^1, \dots, \hat{y}^m]$.
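///
/// As a minimal usage sketch (assuming an existing <xref:Microsoft.ML.MLContext> named `mlContext` and an `IDataView` named `trainData` with a key-typed `Label` column and a `float` vector `Features` column):
///
/// ```csharp
/// // Maximum entropy flavor: the output scores are post-processed into class probabilities.
/// var pipeline = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(
///     labelColumnName: "Label", featureColumnName: "Features");
/// var model = pipeline.Fit(trainData);
/// ```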
///
/// ### Training Algorithm Details
/// The optimization algorithm is an extension of [a coordinate descent method](http://jmlr.org/papers/volume14/shalev-shwartz13a/shalev-shwartz13a.pdf)
/// following a similar path proposed in an earlier [paper](https://www.csie.ntu.edu.tw/~cjlin/papers/maxent_dual.pdf).
/// It is usually much faster than [L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) and
/// [truncated Newton methods](https://en.wikipedia.org/wiki/Truncated_Newton_method) for large-scale and sparse data sets.
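///
/// As implemented in `CheckConvergence` below, training tracks the duality gap as its stopping criterion: if $P$ denotes the primal objective and $D$ the dual objective, the algorithm is considered converged once $(P - D) / P$ falls below the configured convergence tolerance.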
///
/// [!include[regularization](~/../docs/samples/docs/api-reference/regularization-l1-l2.md)]
///
/// Check the See Also section for links to usage examples.
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(MulticlassClassificationCatalog.MulticlassClassificationTrainers, SdcaMaximumEntropyMulticlassTrainer.Options)"/>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, float?, float?, int?)"/>
/// <seealso cref="Microsoft.ML.Trainers.SdcaMaximumEntropyMulticlassTrainer.Options"/>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(MulticlassClassificationCatalog.MulticlassClassificationTrainers, SdcaNonCalibratedMulticlassTrainer.Options)"/>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, ISupportSdcaClassificationLoss, float?, float?, int?)"/>
/// <seealso cref="Microsoft.ML.Trainers.SdcaNonCalibratedMulticlassTrainer.Options"/>
public abstract class SdcaMulticlassTrainerBase<TModel> : SdcaTrainerBase<SdcaMulticlassTrainerBase<TModel>.MulticlassOptions, MulticlassPredictionTransformer<TModel>, TModel>
where TModel : class
{
internal const string LoadNameValue = "SDCAMC";
internal const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)";
internal const string ShortName = "sasdcamc";
internal const string Summary = "The SDCA linear multi-class classification trainer.";
/// <summary>
/// Options for the <see cref="SdcaMulticlassTrainerBase{TModel}"/>.
/// </summary>
public class MulticlassOptions : OptionsBase
{
/// <summary>
/// The custom <a href="https://en.wikipedia.org/wiki/Loss_function">loss</a>.
/// </summary>
/// <value>
/// If unspecified, <see cref="LogLoss"/> will be used.
/// </value>
[Argument(ArgumentType.Multiple, Name = "LossFunction", HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
internal ISupportSdcaClassificationLossFactory LossFunctionFactory = new LogLossFactory();
/// <summary>
/// Internal state of <see cref="SdcaNonCalibratedMulticlassTrainer.Options.Loss"/> or storage of
/// a customized loss passed in. <see cref="SdcaMaximumEntropyMulticlassTrainer.Options"/> cannot set this field because its
/// loss function is always <see cref="LogLoss"/>. In addition, <see cref="InternalLoss"/> and <see cref="LossFunctionFactory"/> are
/// the two fields used to determine the actual loss function inside the training framework of <see cref="SdcaMulticlassTrainerBase{TModel}"/>.
/// </summary>
internal ISupportSdcaClassificationLoss InternalLoss;
}
private readonly ISupportSdcaClassificationLoss _loss;
private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;
/// <summary>
/// Initializes a new instance of <see cref="SdcaMulticlassTrainerBase{TModel}"/>.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="labelColumn">The label, or dependent variable.</param>
/// <param name="featureColumn">The features, or independent variables.</param>
/// <param name="weights">The optional example weights.</param>
/// <param name="loss">The custom loss.</param>
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
internal SdcaMulticlassTrainerBase(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
ISupportSdcaClassificationLoss loss = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null)
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights),
l2Const, l1Threshold, maxIterations)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
_loss = loss ?? SdcaTrainerOptions.InternalLoss ?? SdcaTrainerOptions.LossFunctionFactory.CreateComponent(env);
Loss = _loss;
}
internal SdcaMulticlassTrainerBase(IHostEnvironment env, MulticlassOptions options,
string featureColumn, string labelColumn, string weightColumn = null)
: base(env, options, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
Host.CheckValue(labelColumn, nameof(labelColumn));
Host.CheckValue(featureColumn, nameof(featureColumn));
_loss = options.InternalLoss ?? options.LossFunctionFactory.CreateComponent(env);
Loss = _loss;
}
internal SdcaMulticlassTrainerBase(IHostEnvironment env, MulticlassOptions options)
: this(env, options, options.FeatureColumnName, options.LabelColumnName, options.ExampleWeightColumnName)
{
}
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
var metadata = new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues)
.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));
return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol))),
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, metadata)
};
}
/// <inheritdoc/>
private protected override void TrainWithoutLock(IProgressChannelProvider progress, FloatLabelCursor.Factory cursorFactory, Random rand,
IdToIdxLookup idToIdx, int numThreads, DualsTableBase duals, float[] biasReg, float[] invariants, float lambdaNInv,
VBuffer<float>[] weights, float[] biasUnreg, VBuffer<float>[] l1IntermediateWeights, float[] l1IntermediateBias, float[] featureNormSquared)
{
Contracts.AssertValueOrNull(progress);
Contracts.Assert(SdcaTrainerOptions.L1Regularization.HasValue);
Contracts.AssertValueOrNull(idToIdx);
Contracts.AssertValueOrNull(invariants);
Contracts.AssertValueOrNull(featureNormSquared);
int numClasses = Utils.Size(weights);
Contracts.Assert(Utils.Size(biasReg) == numClasses);
Contracts.Assert(Utils.Size(biasUnreg) == numClasses);
int maxUpdateTrials = 2 * numThreads;
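// Caps the number of compare-and-swap retries per class update; contention grows with the thread count.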
var l1Threshold = SdcaTrainerOptions.L1Regularization.Value;
bool l1ThresholdZero = l1Threshold == 0;
var lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Regularization.Value;
var pch = progress != null ? progress.StartProgressChannel("Dual update") : null;
using (pch)
using (var cursor = SdcaTrainerOptions.Shuffle ? cursorFactory.Create(rand) : cursorFactory.Create())
{
long rowCount = 0;
if (pch != null)
pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, rowCount));
Func<DataViewRowId, long> getIndexFromId = GetIndexFromIdGetter(idToIdx, biasReg.Length);
while (cursor.MoveNext())
{
Host.CheckAlive();
long idx = getIndexFromId(cursor.Id);
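// Duals are laid out contiguously per example: the dual variable for class c of this
// example lives at idx * numClasses + c.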
long dualIndexInitPos = idx * numClasses;
var features = cursor.Features;
var label = (int)cursor.Label;
float invariant;
float normSquared;
if (invariants != null)
{
invariant = invariants[idx];
Contracts.AssertValue(featureNormSquared);
normSquared = featureNormSquared[idx];
}
else
{
normSquared = VectorUtils.NormSquared(in features);
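// When the bias is folded into the primal update (BiasLearningRate == 0), account for it
// as an implicit constant feature of value 1 in the squared norm.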
if (SdcaTrainerOptions.BiasLearningRate == 0)
normSquared += 1;
invariant = _loss.ComputeDualUpdateInvariant(2 * normSquared * lambdaNInv * GetInstanceWeight(cursor));
}
// The output for the label class using current weights and bias.
var labelOutput = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]);
var instanceWeight = GetInstanceWeight(cursor);
// This will be the new dual variable corresponding to the label class.
float labelDual = 0;
// This will be used to update the weights and regularized bias corresponding to the label class.
float labelPrimalUpdate = 0;
// This will be used to update the unregularized bias corresponding to the label class.
float labelAdjustment = 0;
// Iterates through all classes.
for (int iClass = 0; iClass < numClasses; iClass++)
{
// Skip the dual/weights/bias update for label class. Will be taken care of at the end.
if (iClass == label)
continue;
var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[iClass]);
var l1IntermediateWeightsEditor =
!l1ThresholdZero ? VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[iClass]) :
default;
// Loop trials for compare-and-swap updates of duals.
// In general, concurrent updates conflicting on the same dual variable are rare
// if the data is shuffled.
for (int numTrials = 0; numTrials < maxUpdateTrials; numTrials++)
{
long dualIndex = iClass + dualIndexInitPos;
var dual = duals[dualIndex];
var output = labelOutput + labelPrimalUpdate * normSquared - WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
var dualUpdate = _loss.DualUpdate(output, 1, dual, invariant, numThreads);
// The successive over-relaxation approach to adjust the sum of dual variables (biasReg) to zero.
// For details, see pp. 16-17 of http://stat.rutgers.edu/home/tzhang/papers/ml02_dual.pdf.
var adjustment = l1ThresholdZero ? lr * biasReg[iClass] : lr * l1IntermediateBias[iClass];
dualUpdate -= adjustment;
bool success = false;
duals.ApplyAt(dualIndex, (long index, ref float value) =>
success = Interlocked.CompareExchange(ref value, dual + dualUpdate, dual) == dual);
if (success)
{
// Note: dualConstraint[iClass] = lambdaNInv * (sum of duals[iClass])
var primalUpdate = dualUpdate * lambdaNInv * instanceWeight;
labelDual -= dual + dualUpdate;
labelPrimalUpdate += primalUpdate;
biasUnreg[iClass] += adjustment * lambdaNInv * instanceWeight;
labelAdjustment -= adjustment;
if (l1ThresholdZero)
{
VectorUtils.AddMult(in features, weightsEditor.Values, -primalUpdate);
biasReg[iClass] -= primalUpdate;
}
else
{
// Iterative shrinkage-thresholding (a.k.a. soft-thresholding):
// update v = denseWeights as if there were no L1 penalty.
// Thresholding: if |v[j]| < threshold, turn off weights[j];
// otherwise, shrink: w[j] = v[j] - sign(v[j]) * threshold.
l1IntermediateBias[iClass] -= primalUpdate;
if (SdcaTrainerOptions.BiasLearningRate == 0)
{
biasReg[iClass] = Math.Abs(l1IntermediateBias[iClass]) - l1Threshold > 0.0
? l1IntermediateBias[iClass] - Math.Sign(l1IntermediateBias[iClass]) * l1Threshold
: 0;
}
var featureValues = features.GetValues();
if (features.IsDense)
CpuMathUtils.SdcaL1UpdateDense(-primalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
else if (featureValues.Length > 0)
CpuMathUtils.SdcaL1UpdateSparse(-primalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
}
break;
}
}
}
// Update the weights and dual variable of the label class.
duals[label + dualIndexInitPos] = labelDual;
biasUnreg[label] += labelAdjustment * lambdaNInv * instanceWeight;
if (l1ThresholdZero)
{
var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
VectorUtils.AddMult(in features, weightsEditor.Values, labelPrimalUpdate);
biasReg[label] += labelPrimalUpdate;
}
else
{
l1IntermediateBias[label] += labelPrimalUpdate;
var intermediateBias = l1IntermediateBias[label];
biasReg[label] = Math.Abs(intermediateBias) - l1Threshold > 0.0
? intermediateBias - Math.Sign(intermediateBias) * l1Threshold
: 0;
var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights[label]);
var l1IntermediateWeightsEditor = VBufferEditor.CreateFromBuffer(ref l1IntermediateWeights[label]);
var featureValues = features.GetValues();
if (features.IsDense)
CpuMathUtils.SdcaL1UpdateDense(labelPrimalUpdate, featureValues.Length, featureValues, l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
else if (featureValues.Length > 0)
CpuMathUtils.SdcaL1UpdateSparse(labelPrimalUpdate, featureValues.Length, featureValues, features.GetIndices(), l1Threshold, l1IntermediateWeightsEditor.Values, weightsEditor.Values);
}
rowCount++;
}
}
}
/// <inheritdoc/>
private protected override bool CheckConvergence(
IProgressChannel pch,
int iter,
FloatLabelCursor.Factory cursorFactory,
DualsTableBase duals,
IdToIdxLookup idToIdx,
VBuffer<float>[] weights,
VBuffer<float>[] bestWeights,
float[] biasUnreg,
float[] bestBiasUnreg,
float[] biasReg,
float[] bestBiasReg,
long count,
Double[] metrics,
ref Double bestPrimalLoss,
ref int bestIter)
{
Contracts.AssertValue(weights);
Contracts.AssertValue(duals);
int numClasses = weights.Length;
Contracts.Assert(duals.Length >= numClasses * count);
Contracts.AssertValueOrNull(idToIdx);
Contracts.Assert(Utils.Size(weights) == numClasses);
Contracts.Assert(Utils.Size(biasReg) == numClasses);
Contracts.Assert(Utils.Size(biasUnreg) == numClasses);
Contracts.Assert(Utils.Size(metrics) == 6);
var reportedValues = new Double?[metrics.Length + 1];
reportedValues[metrics.Length] = iter;
var lossSum = new CompensatedSum();
var dualLossSum = new CompensatedSum();
int numFeatures = weights[0].Length;
using (var cursor = cursorFactory.Create())
{
long row = 0;
Func<DataViewRowId, long, long> getIndexFromIdAndRow = GetIndexFromIdAndRowGetter(idToIdx, biasReg.Length);
// Iterates through data to compute loss function.
while (cursor.MoveNext())
{
Host.CheckAlive();
var instanceWeight = GetInstanceWeight(cursor);
var features = cursor.Features;
var label = (int)cursor.Label;
var labelOutput = WDot(in features, in weights[label], biasReg[label] + biasUnreg[label]);
Double subLoss = 0;
Double subDualLoss = 0;
long idx = getIndexFromIdAndRow(cursor.Id, row);
long dualIndex = idx * numClasses;
for (int iClass = 0; iClass < numClasses; iClass++)
{
if (iClass == label)
{
dualIndex++;
continue;
}
var currentClassOutput = WDot(in features, in weights[iClass], biasReg[iClass] + biasUnreg[iClass]);
subLoss += _loss.Loss(labelOutput - currentClassOutput, 1);
Contracts.Assert(dualIndex == iClass + idx * numClasses);
var dual = duals[dualIndex++];
subDualLoss += _loss.DualLoss(1, dual);
}
lossSum.Add(subLoss * instanceWeight);
dualLossSum.Add(subDualLoss * instanceWeight);
row++;
}
Host.Assert(idToIdx == null || row * numClasses == duals.Length);
}
Contracts.Assert(SdcaTrainerOptions.L2Regularization.HasValue);
Contracts.Assert(SdcaTrainerOptions.L1Regularization.HasValue);
Double l2Const = SdcaTrainerOptions.L2Regularization.Value;
Double l1Threshold = SdcaTrainerOptions.L1Regularization.Value;
Double weightsL1Norm = 0;
Double weightsL2NormSquared = 0;
Double biasRegularizationAdjustment = 0;
for (int iClass = 0; iClass < numClasses; iClass++)
{
weightsL1Norm += VectorUtils.L1Norm(in weights[iClass]) + Math.Abs(biasReg[iClass]);
weightsL2NormSquared += VectorUtils.NormSquared(weights[iClass]) + biasReg[iClass] * biasReg[iClass];
biasRegularizationAdjustment += biasReg[iClass] * biasUnreg[iClass];
}
Double l1Regularizer = SdcaTrainerOptions.L1Regularization.Value * l2Const * weightsL1Norm;
var l2Regularizer = l2Const * weightsL2NormSquared * 0.5;
var newLoss = lossSum.Sum / count + l2Regularizer + l1Regularizer;
var newDualLoss = dualLossSum.Sum / count - l2Regularizer - l2Const * biasRegularizationAdjustment;
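// By weak duality, the primal objective upper-bounds the dual objective; the gap bounds
// how far the current solution is from optimal and drives the convergence test below.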
var dualityGap = newLoss - newDualLoss;
metrics[(int)MetricKind.Loss] = newLoss;
metrics[(int)MetricKind.DualLoss] = newDualLoss;
metrics[(int)MetricKind.DualityGap] = dualityGap;
metrics[(int)MetricKind.BiasUnreg] = biasUnreg[0];
metrics[(int)MetricKind.BiasReg] = biasReg[0];
metrics[(int)MetricKind.L1Sparsity] = SdcaTrainerOptions.L1Regularization == 0 ? 1 : (Double)weights.Sum(
weight => weight.GetValues().Count(w => w != 0)) / (numClasses * numFeatures);
bool converged = dualityGap / newLoss < SdcaTrainerOptions.ConvergenceTolerance;
if (metrics[(int)MetricKind.Loss] < bestPrimalLoss)
{
for (int iClass = 0; iClass < numClasses; iClass++)
{
// Maintain a copy of weights and bias with best primal loss thus far.
// This is some extra work and uses extra memory, but it seems worth doing it.
// REVIEW: Sparsify bestWeights?
weights[iClass].CopyTo(ref bestWeights[iClass]);
bestBiasReg[iClass] = biasReg[iClass];
bestBiasUnreg[iClass] = biasUnreg[iClass];
}
bestPrimalLoss = metrics[(int)MetricKind.Loss];
bestIter = iter;
}
for (int i = 0; i < metrics.Length; i++)
reportedValues[i] = metrics[i];
if (pch != null)
pch.Checkpoint(reportedValues);
return converged;
}
private protected override void CheckLabel(RoleMappedData examples, out int weightSetCount)
{
examples.CheckMulticlassLabel(out weightSetCount);
}
private protected override float[] InitializeFeatureNormSquared(int length)
{
Contracts.Assert(0 < length && length <= Utils.ArrayMaxSize);
return new float[length];
}
private protected override float GetInstanceWeight(FloatLabelCursor cursor)
{
return cursor.Weight;
}
}
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> to predict a target using a maximum entropy multiclass classifier.
/// The trained model <see cref="MaximumEntropyModelParameters"/> produces probabilities of classes.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [SdcaMaximumEntropy](xref:Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,System.Nullable{System.Single},System.Nullable{System.Single},System.Nullable{System.Int32})) or
/// [SdcaMaximumEntropy(Options)](xref:Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,Microsoft.ML.Trainers.SdcaMaximumEntropyMulticlassTrainer.Options)).
///
/// [!include[io](~/../docs/samples/docs/api-reference/io-columns-multiclass-classification.md)]
///
/// ### Trainer Characteristics
/// | | |
/// | -- | -- |
/// | Machine learning task | Multiclass classification |
/// | Is normalization required? | Yes |
/// | Is caching required? | No |
/// | Required NuGet in addition to Microsoft.ML | None |
/// | Exportable to ONNX | Yes |
///
/// ### Scoring Function
/// This trains a linear model to solve multiclass classification problems.
/// Assume that the number of classes is $m$ and the number of features is $n$.
/// It assigns the $c$-th class a coefficient vector $\textbf{w}\_c \in {\mathbb R}^n$ and a bias $b_c \in {\mathbb R}$, for $c=1,\dots,m$.
/// Given a feature vector $\textbf{x} \in {\mathbb R}^n$, the $c$-th class's score would be $\tilde{P}(c | \textbf{x}) = \frac{ e^{\hat{y}^c} }{ \sum\_{c' = 1}^m e^{\hat{y}^{c'}} }$, where $\hat{y}^c = \textbf{w}\_c^T \textbf{x} + b_c$.
/// Note that $\tilde{P}(c | \textbf{x})$ is the probability of observing class $c$ when the feature vector is $\textbf{x}$.
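///
/// The following is a sketch of that softmax post-processing (illustrative only, not the library's internal code; it assumes `using System;` and `using System.Linq;`):
///
/// ```csharp
/// // Turn raw per-class scores into probabilities; subtract the max score for numerical stability.
/// static float[] Softmax(float[] scores)
/// {
///     var max = scores.Max();
///     var exps = scores.Select(s => (float)Math.Exp(s - max)).ToArray();
///     var sum = exps.Sum();
///     return exps.Select(e => e / sum).ToArray();
/// }
/// ```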
///
/// ### Training Algorithm Details
/// See the documentation of [SdcaMulticlassTrainerBase](xref:Microsoft.ML.Trainers.SdcaMulticlassTrainerBase`1).
///
/// Check the See Also section for links to usage examples.
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(MulticlassClassificationCatalog.MulticlassClassificationTrainers, SdcaMaximumEntropyMulticlassTrainer.Options)"/>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, float?, float?, int?)"/>
/// <seealso cref="Microsoft.ML.Trainers.SdcaMaximumEntropyMulticlassTrainer.Options"/>
public sealed class SdcaMaximumEntropyMulticlassTrainer : SdcaMulticlassTrainerBase<MaximumEntropyModelParameters>
{
/// <summary>
/// <see cref="Options"/> for <see cref="SdcaMaximumEntropyMulticlassTrainer"/> as used in
/// <see cref="Microsoft.ML.StandardTrainersCatalog.SdcaMaximumEntropy(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, float?, float?, int?)"/>
/// </summary>
public sealed class Options : MulticlassOptions
{
}
internal SdcaMaximumEntropyMulticlassTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null)
: base(env, labelColumn: labelColumn, featureColumn: featureColumn, weights: weights, loss: new LogLoss(),
l2Const: l2Const, l1Threshold: l1Threshold, maxIterations: maxIterations)
{
}
internal SdcaMaximumEntropyMulticlassTrainer(IHostEnvironment env, Options options,
string featureColumn, string labelColumn, string weightColumn = null)
: base(env, options: options, featureColumn: featureColumn, labelColumn: labelColumn, weightColumn: weightColumn)
{
}
internal SdcaMaximumEntropyMulticlassTrainer(IHostEnvironment env, Options options)
: base(env, options)
{
}
private protected override MaximumEntropyModelParameters CreatePredictor(VBuffer<float>[] weights, float[] bias)
{
Host.CheckValue(weights, nameof(weights));
Host.CheckValue(bias, nameof(bias));
Host.CheckParam(weights.Length > 0, nameof(weights));
Host.CheckParam(weights.Length == bias.Length, nameof(weights));
return new MaximumEntropyModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null);
}
private protected override MulticlassPredictionTransformer<MaximumEntropyModelParameters> MakeTransformer(
MaximumEntropyModelParameters model, DataViewSchema trainSchema) =>
new MulticlassPredictionTransformer<MaximumEntropyModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
}
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> to predict a target using a linear multiclass classifier.
/// The trained model <see cref="LinearMulticlassModelParameters"/> produces per-class scores, not probabilities.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [SdcaNonCalibrated](xref:Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Microsoft.ML.Trainers.ISupportSdcaClassificationLoss,System.Nullable{System.Single},System.Nullable{System.Single},System.Nullable{System.Int32})) or
/// [SdcaNonCalibrated(Options)](xref:Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,Microsoft.ML.Trainers.SdcaNonCalibratedMulticlassTrainer.Options)).
///
/// [!include[io](~/../docs/samples/docs/api-reference/io-columns-multiclass-classification.md)]
///
/// ### Trainer Characteristics
/// | | |
/// | -- | -- |
/// | Machine learning task | Multiclass classification |
/// | Is normalization required? | Yes |
/// | Is caching required? | No |
/// | Required NuGet in addition to Microsoft.ML | None |
/// | Exportable to ONNX | Yes |
///
/// ### Scoring Function
/// This trains a linear model to solve multiclass classification problems.
/// Assume that the number of classes is $m$ and the number of features is $n$.
/// It assigns the $c$-th class a coefficient vector $\textbf{w}_c \in {\mathbb R}^n$ and a bias $b_c \in {\mathbb R}$, for $c=1,\dots,m$.
/// Given a feature vector $\textbf{x} \in {\mathbb R}^n$, the $c$-th class's score would be $\hat{y}^c = \textbf{w}_c^T \textbf{x} + b_c$.
/// Note that the $c$-th value in the output score column is just $\hat{y}^c$.
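///
/// For example, a minimal sketch (assuming an existing <xref:Microsoft.ML.MLContext> named `mlContext`) that trains a multiclass SVM-style model by plugging in [hinge loss](xref:Microsoft.ML.Trainers.HingeLoss):
///
/// ```csharp
/// var trainer = mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated(
///     labelColumnName: "Label",
///     featureColumnName: "Features",
///     lossFunction: new HingeLoss());
/// ```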
///
/// ### Training Algorithm Details
/// See the documentation of [SdcaMulticlassTrainerBase](xref:Microsoft.ML.Trainers.SdcaMulticlassTrainerBase`1).
///
/// Check the See Also section for links to usage examples.
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(MulticlassClassificationCatalog.MulticlassClassificationTrainers, SdcaNonCalibratedMulticlassTrainer.Options)"/>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, ISupportSdcaClassificationLoss, float?, float?, int?)"/>
/// <seealso cref="Microsoft.ML.Trainers.SdcaNonCalibratedMulticlassTrainer.Options"/>
public sealed class SdcaNonCalibratedMulticlassTrainer : SdcaMulticlassTrainerBase<LinearMulticlassModelParameters>
{
/// <summary>
/// <see cref="Options"/> for <see cref="SdcaNonCalibratedMulticlassTrainer"/> as used in
/// <see cref="Microsoft.ML.StandardTrainersCatalog.SdcaNonCalibrated(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, ISupportSdcaClassificationLoss, float?, float?, int?)"/>.
/// </summary>
public sealed class Options : MulticlassOptions
{
/// <summary>
/// Loss function minimized by this trainer.
/// </summary>
/// <value>
/// If unspecified, <see cref="LogLoss"/> will be used.
/// </value>
public ISupportSdcaClassificationLoss Loss
{
get { return InternalLoss; }
set { InternalLoss = value; }
}
}
internal SdcaNonCalibratedMulticlassTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
ISupportSdcaClassificationLoss loss = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null)
: base(env, labelColumn: labelColumn, featureColumn: featureColumn, weights: weights, loss: loss,
l2Const: l2Const, l1Threshold: l1Threshold, maxIterations: maxIterations)
{
}
internal SdcaNonCalibratedMulticlassTrainer(IHostEnvironment env, Options options,
string featureColumn, string labelColumn, string weightColumn = null)
: base(env, options: options, featureColumn: featureColumn, labelColumn: labelColumn, weightColumn: weightColumn)
{
}
internal SdcaNonCalibratedMulticlassTrainer(IHostEnvironment env, Options options)
: base(env, options)
{
}
private protected override LinearMulticlassModelParameters CreatePredictor(VBuffer<float>[] weights, float[] bias)
{
Host.CheckValue(weights, nameof(weights));
Host.CheckValue(bias, nameof(bias));
Host.CheckParam(weights.Length > 0, nameof(weights));
Host.CheckParam(weights.Length == bias.Length, nameof(weights));
return new LinearMulticlassModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null);
}
private protected override MulticlassPredictionTransformer<LinearMulticlassModelParameters> MakeTransformer(
LinearMulticlassModelParameters model, DataViewSchema trainSchema) =>
new MulticlassPredictionTransformer<LinearMulticlassModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
}
/// <summary>
/// The Entry Point for SDCA multiclass.
/// </summary>
internal static partial class Sdca
{
[TlcModule.EntryPoint(Name = "Trainers.StochasticDualCoordinateAscentClassifier",
Desc = SdcaMaximumEntropyMulticlassTrainer.Summary,
UserName = SdcaMaximumEntropyMulticlassTrainer.UserNameValue,
ShortName = SdcaMaximumEntropyMulticlassTrainer.ShortName)]
public static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, SdcaMaximumEntropyMulticlassTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainSDCA");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
return TrainerEntryPointsUtils.Train<SdcaMaximumEntropyMulticlassTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
() => new SdcaMaximumEntropyMulticlassTrainer(host, input),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName));
}
}
}