// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(typeof(MulticlassClassificationEvaluator), typeof(MulticlassClassificationEvaluator), typeof(MulticlassClassificationEvaluator.Arguments), typeof(SignatureEvaluator),
"Multi-Class Classifier Evaluator", MulticlassClassificationEvaluator.LoadName, "MultiClassClassifier", "MultiClass")]
[assembly: LoadableClass(typeof(MulticlassClassificationMamlEvaluator), typeof(MulticlassClassificationMamlEvaluator), typeof(MulticlassClassificationMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
"Multi-Class Classifier Evaluator", MulticlassClassificationEvaluator.LoadName, "MultiClassClassifier", "MultiClass")]
// This is for deserialization of the per-instance transform.
[assembly: LoadableClass(typeof(MulticlassPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper),
"", MulticlassPerInstanceEvaluator.LoaderSignature)]
namespace Microsoft.ML.Data
{
[BestFriend]
internal sealed class MulticlassClassificationEvaluator : RowToRowEvaluatorBase<MulticlassClassificationEvaluator.Aggregator>
{
public sealed class Arguments
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Output top K accuracy", ShortName = "topkacc")]
public int? OutputTopKAcc;
[Argument(ArgumentType.AtMostOnce, HelpText = "Use the textual class label names in the report, if available", ShortName = "n")]
public bool Names = true;
}
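// Names of the metric columns reported in the overall-metrics data view.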
public const string AccuracyMicro = "Accuracy(micro-avg)";
public const string AccuracyMacro = "Accuracy(macro-avg)";
public const string TopKAccuracy = "Top K accuracy";
public const string AllTopKAccuracy = "Top K accuracies";
public const string PerClassLogLoss = "Per class log-loss";
public const string LogLoss = "Log-loss";
public const string LogLossReduction = "Log-loss reduction";
public enum Metrics
{
[EnumValueDisplay(MulticlassClassificationEvaluator.AccuracyMicro)]
AccuracyMicro,
[EnumValueDisplay(MulticlassClassificationEvaluator.AccuracyMacro)]
AccuracyMacro,
[EnumValueDisplay(MulticlassClassificationEvaluator.LogLoss)]
LogLoss,
[EnumValueDisplay(MulticlassClassificationEvaluator.LogLossReduction)]
LogLossReduction,
}
internal const string LoadName = "MultiClassClassifierEvaluator";
private readonly int? _outputTopKAcc;
private readonly bool _names;
public MulticlassClassificationEvaluator(IHostEnvironment env, Arguments args)
: base(env, LoadName)
{
Host.AssertValue(args, "args");
Host.CheckUserArg(args.OutputTopKAcc == null || args.OutputTopKAcc > 0, nameof(args.OutputTopKAcc));
_outputTopKAcc = args.OutputTopKAcc;
_names = args.Names;
}
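// The score column must be a known-size vector of Single with at least two slots; the label column
// must be of type Single or a key type.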
private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
var scoreType = score.Type as VectorDataViewType;
if (scoreType == null || scoreType.Size < 2 || scoreType.ItemType != NumberDataViewType.Single)
throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "vector of two or more items of type Single", score.Type.ToString());
Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column");
var labelType = schema.Label.Value.Type;
if (labelType != NumberDataViewType.Single && labelType.GetKeyCount() <= 0)
throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "Single or Key", labelType.ToString());
}
private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
var score = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
int numClasses = score.Type.GetVectorSize();
Host.Assert(numClasses > 0);
var classNames = GetClassNames(schema);
return new Aggregator(Host, classNames, numClasses, schema.Weight != null, _outputTopKAcc, stratName);
}
private ReadOnlyMemory<char>[] GetClassNames(RoleMappedSchema schema)
{
ReadOnlyMemory<char>[] names;
// Get the label names from the score column if they exist, or use the default names.
var scoreInfo = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
var mdType = schema.Schema[scoreInfo.Index].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
var labelNames = default(VBuffer<ReadOnlyMemory<char>>);
if (mdType != null && mdType.IsKnownSize && mdType.ItemType is TextDataViewType)
{
schema.Schema[scoreInfo.Index].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames);
names = new ReadOnlyMemory<char>[labelNames.Length];
labelNames.CopyTo(names);
}
else
{
var score = schema.GetColumns(AnnotationUtils.Const.ScoreValueKind.Score);
Host.Assert(Utils.Size(score) == 1);
int numClasses = score[0].Type.GetVectorSize();
Host.Assert(numClasses > 0);
names = Enumerable.Range(0, numClasses).Select(i => i.ToString().AsMemory()).ToArray();
}
return names;
}
private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema)
{
Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column");
var scoreInfo = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
int numClasses = scoreInfo.Type.GetVectorSize();
return new MulticlassPerInstanceEvaluator(Host, schema.Schema, scoreInfo, schema.Label.Value.Name);
}
public override IEnumerable<MetricColumn> GetOverallMetricColumns()
{
yield return new MetricColumn("AccuracyMicro", AccuracyMicro);
yield return new MetricColumn("AccuracyMacro", AccuracyMacro);
yield return new MetricColumn("TopKAccuracy", TopKAccuracy);
yield return new MetricColumn("LogLoss<class name>", PerClassLogLoss, MetricColumn.Objective.Minimize,
isVector: true, namePattern: new Regex(string.Format(@"^{0}(?<class>.+)", LogLoss), RegexOptions.IgnoreCase),
groupName: "class", nameFormat: string.Format("{0} (class {{0}})", PerClassLogLoss));
yield return new MetricColumn("LogLoss", LogLoss, MetricColumn.Objective.Minimize);
yield return new MetricColumn("LogLossReduction", LogLossReduction);
}
private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action<uint, ReadOnlyMemory<char>, Aggregator> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
{
var stratCol = new List<uint>();
var stratVal = new List<ReadOnlyMemory<char>>();
var isWeighted = new List<bool>();
var microAcc = new List<double>();
var macroAcc = new List<double>();
var logLoss = new List<double>();
var logLossRed = new List<double>();
var topKAcc = new List<double>();
var allTopK = new List<double[]>();
var perClassLogLoss = new List<double[]>();
var counts = new List<double[]>();
var weights = new List<double[]>();
var confStratCol = new List<uint>();
var confStratVal = new List<ReadOnlyMemory<char>>();
bool hasStrats = Utils.Size(dictionaries) > 0;
bool hasWeight = aggregator.Weighted;
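// addAgg appends one row of unweighted metrics per stratum (plus one weighted row when a weight
// column is present) along with the matching confusion-table rows; consolidate then assembles the
// accumulated lists into the overall-metrics and confusion-matrix data views.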
addAgg =
(stratColKey, stratColVal, agg) =>
{
Host.Check(agg.Weighted == hasWeight, "All aggregators must either be weighted or unweighted");
Host.Check((agg.UnweightedCounters.OutputTopKAcc > 0) == (aggregator.UnweightedCounters.OutputTopKAcc > 0),
"All aggregators must either compute top-k accuracy or not compute top-k accuracy");
stratCol.Add(stratColKey);
stratVal.Add(stratColVal);
isWeighted.Add(false);
microAcc.Add(agg.UnweightedCounters.MicroAvgAccuracy);
macroAcc.Add(agg.UnweightedCounters.MacroAvgAccuracy);
logLoss.Add(agg.UnweightedCounters.LogLoss);
logLossRed.Add(agg.UnweightedCounters.Reduction);
if (agg.UnweightedCounters.OutputTopKAcc > 0)
{
topKAcc.Add(agg.UnweightedCounters.TopKAccuracy);
allTopK.Add(agg.UnweightedCounters.AllTopKAccuracy);
}
perClassLogLoss.Add(agg.UnweightedCounters.PerClassLogLoss);
confStratCol.AddRange(agg.UnweightedCounters.ConfusionTable.Select(x => stratColKey));
confStratVal.AddRange(agg.UnweightedCounters.ConfusionTable.Select(x => stratColVal));
counts.AddRange(agg.UnweightedCounters.ConfusionTable);
if (agg.Weighted)
{
stratCol.Add(stratColKey);
stratVal.Add(stratColVal);
isWeighted.Add(true);
microAcc.Add(agg.WeightedCounters.MicroAvgAccuracy);
macroAcc.Add(agg.WeightedCounters.MacroAvgAccuracy);
logLoss.Add(agg.WeightedCounters.LogLoss);
logLossRed.Add(agg.WeightedCounters.Reduction);
if (agg.WeightedCounters.OutputTopKAcc > 0)
{
topKAcc.Add(agg.WeightedCounters.TopKAccuracy);
allTopK.Add(agg.WeightedCounters.AllTopKAccuracy);
}
perClassLogLoss.Add(agg.WeightedCounters.PerClassLogLoss);
weights.AddRange(agg.WeightedCounters.ConfusionTable);
}
};
consolidate =
() =>
{
var overallDvBldr = new ArrayDataViewBuilder(Host);
if (hasStrats)
{
overallDvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, stratCol.ToArray());
overallDvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, stratVal.ToArray());
}
if (hasWeight)
overallDvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BooleanDataViewType.Instance, isWeighted.ToArray());
overallDvBldr.AddColumn(AccuracyMicro, NumberDataViewType.Double, microAcc.ToArray());
overallDvBldr.AddColumn(AccuracyMacro, NumberDataViewType.Double, macroAcc.ToArray());
overallDvBldr.AddColumn(LogLoss, NumberDataViewType.Double, logLoss.ToArray());
overallDvBldr.AddColumn(LogLossReduction, NumberDataViewType.Double, logLossRed.ToArray());
if (aggregator.UnweightedCounters.OutputTopKAcc > 0)
{
overallDvBldr.AddColumn(TopKAccuracy, NumberDataViewType.Double, topKAcc.ToArray());
ValueGetter<VBuffer<ReadOnlyMemory<char>>> getKSlotNames =
(ref VBuffer<ReadOnlyMemory<char>> dst) =>
dst = new VBuffer<ReadOnlyMemory<char>>(allTopK.First().Length, Enumerable.Range(1, allTopK.First().Length).Select(i => new ReadOnlyMemory<char>(i.ToString().ToCharArray())).ToArray());
overallDvBldr.AddColumn(AllTopKAccuracy, getKSlotNames, NumberDataViewType.Double, allTopK.ToArray());
}
overallDvBldr.AddColumn(PerClassLogLoss, aggregator.GetSlotNames, NumberDataViewType.Double, perClassLogLoss.ToArray());
var confDvBldr = new ArrayDataViewBuilder(Host);
if (hasStrats)
{
confDvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, confStratCol.ToArray());
confDvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, confStratVal.ToArray());
}
ValueGetter<VBuffer<ReadOnlyMemory<char>>> getSlotNames =
(ref VBuffer<ReadOnlyMemory<char>> dst) =>
dst = new VBuffer<ReadOnlyMemory<char>>(aggregator.ClassNames.Length, aggregator.ClassNames);
confDvBldr.AddColumn(MetricKinds.ColumnNames.Count, getSlotNames, NumberDataViewType.Double, counts.ToArray());
if (hasWeight)
confDvBldr.AddColumn(MetricKinds.ColumnNames.Weight, getSlotNames, NumberDataViewType.Double, weights.ToArray());
var result = new Dictionary<string, IDataView>
{
{ MetricKinds.OverallMetrics, overallDvBldr.GetDataView() },
{ MetricKinds.ConfusionMatrix, confDvBldr.GetDataView() }
};
return result;
};
}
public sealed class Aggregator : AggregatorBase
{
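// Accumulates the statistics needed to compute the multiclass metrics for one set of instances:
// total and correctly-predicted weight, per-class weights and log-loss sums, the confusion table,
// and the distribution of ranks at which the true label was found.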
public sealed class Counters
{
private readonly int _numClasses;
public readonly int? OutputTopKAcc;
private double _totalLogLoss;
private double _numInstances;
private double _numCorrect;
private readonly double[] _sumWeightsOfClass;
private readonly double[] _totalPerClassLogLoss;
private readonly double[] _seenRanks;
public readonly double[][] ConfusionTable;
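// Micro-average accuracy: the total weight of correctly predicted instances divided by the total
// weight of all instances.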
public double MicroAvgAccuracy { get { return _numInstances > 0 ? _numCorrect / _numInstances : 0; } }
public double MacroAvgAccuracy
{
get
{
if (_numInstances == 0)
return 0;
double macroAvgAccuracy = 0;
int countOfNonEmptyClasses = 0;
for (int i = 0; i < _numClasses; ++i)
{
if (_sumWeightsOfClass[i] > 0)
{
countOfNonEmptyClasses++;
macroAvgAccuracy += ConfusionTable[i][i] / _sumWeightsOfClass[i];
}
}
return countOfNonEmptyClasses > 0 ? macroAvgAccuracy / countOfNonEmptyClasses : 0;
}
}
public double LogLoss { get { return _numInstances > 0 ? _totalLogLoss / _numInstances : 0; } }
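// Log-loss reduction: (prior entropy - log-loss) / prior entropy, where the prior entropy is the
// log-loss of always predicting the empirical class distribution.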
public double Reduction
{
get
{
// reduction -- prior log loss is entropy
double entropy = 0;
for (int i = 0; i < _numClasses; ++i)
{
if (_sumWeightsOfClass[i] != 0)
entropy += _sumWeightsOfClass[i] * Math.Log(_sumWeightsOfClass[i] / _numInstances);
}
entropy /= -_numInstances;
return (entropy - LogLoss) / entropy;
}
}
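// Top-K accuracy: the weighted fraction of instances whose true label is ranked among the K
// highest-scoring classes. AllTopKAccuracy holds the cumulative values for K = 1 through OutputTopKAcc.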
public double TopKAccuracy => !(OutputTopKAcc is null) ? AllTopKAccuracy[OutputTopKAcc.Value - 1] : 0d;
public double[] AllTopKAccuracy => CumulativeSum(_seenRanks.Take(OutputTopKAcc ?? 0).Select(l => l / _numInstances)).ToArray();
// The per class average log loss is calculated by dividing the weighted sum of the log loss of examples
// in each class by the total weight of examples in that class.
public double[] PerClassLogLoss
{
get
{
var res = new double[_totalPerClassLogLoss.Length];
for (int i = 0; i < _totalPerClassLogLoss.Length; i++)
res[i] = _sumWeightsOfClass[i] > 0 ? _totalPerClassLogLoss[i] / _sumWeightsOfClass[i] : 0;
return res;
}
}
public Counters(int numClasses, int? outputTopKAcc)
{
_numClasses = numClasses;
OutputTopKAcc = outputTopKAcc;
_sumWeightsOfClass = new double[numClasses];
_totalPerClassLogLoss = new double[numClasses];
ConfusionTable = new double[numClasses][];
for (int i = 0; i < ConfusionTable.Length; i++)
ConfusionTable[i] = new double[numClasses];
_seenRanks = new double[numClasses + 1];
}
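// Records a single example. seenRank is the zero-based rank of the true label among the scores
// (0 means the prediction was correct, _numClasses is used for labels not seen during training),
// assigned is the predicted class index, and weight scales every accumulated statistic.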
public void Update(int seenRank, int assigned, double loglossCurr, int label, float weight)
{
_numInstances += weight;
if (label < _numClasses)
_sumWeightsOfClass[label] += weight;
_totalLogLoss += loglossCurr * weight;
if (label < _numClasses)
_totalPerClassLogLoss[label] += loglossCurr * weight;
_seenRanks[seenRank] += weight;
if (seenRank == 0) // Prediction matched label
{
_numCorrect += weight;
ConfusionTable[label][label] += weight;
}
else if (label < _numClasses)
{
ConfusionTable[label][assigned] += weight;
}
}
private static IEnumerable<double> CumulativeSum(IEnumerable<double> s)
{
double sum = 0;
foreach (var x in s)
{
sum += x;
yield return sum;
}
}
}
private ValueGetter<float> _labelGetter;
private ValueGetter<VBuffer<float>> _scoreGetter;
private ValueGetter<float> _weightGetter;
private VBuffer<float> _scores;
private readonly float[] _scoresArr;
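// Lower bound applied to the predicted probability of the true class so the log-loss stays finite.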
private const float Epsilon = (float)1e-15;
public readonly Counters UnweightedCounters;
public readonly Counters WeightedCounters;
public readonly bool Weighted;
private long _numUnknownClassInstances;
private long _numNegOrNonIntegerLabels;
public readonly ReadOnlyMemory<char>[] ClassNames;
public Aggregator(IHostEnvironment env, ReadOnlyMemory<char>[] classNames, int scoreVectorSize, bool weighted, int? outputTopKAcc, string stratName)
: base(env, stratName)
{
Host.Assert(outputTopKAcc == null || outputTopKAcc > 0);
Host.Assert(scoreVectorSize > 0);
Host.Assert(Utils.Size(classNames) == scoreVectorSize);
_scoresArr = new float[scoreVectorSize];
UnweightedCounters = new Counters(scoreVectorSize, outputTopKAcc);
Weighted = weighted;
WeightedCounters = Weighted ? new Counters(scoreVectorSize, outputTopKAcc) : null;
ClassNames = classNames;
}
internal override void InitializeNextPass(DataViewRow row, RoleMappedSchema schema)
{
Host.Assert(PassNum < 1);
Host.Assert(schema.Label.HasValue);
var score = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
Host.Assert(score.Type.GetVectorSize() == _scoresArr.Length);
_labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
_scoreGetter = row.GetGetter<VBuffer<float>>(score);
Host.AssertValue(_labelGetter);
Host.AssertValue(_scoreGetter);
if (schema.Weight.HasValue)
_weightGetter = row.GetGetter<float>(schema.Weight.Value);
}
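// Reads the label, score vector and optional weight of the current row, computes the log-loss and
// the rank of the true label, and updates the unweighted (and, when present, weighted) counters.
// Rows with missing labels, negative or non-integer labels, or non-finite scores are counted and skipped.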
public override void ProcessRow()
{
float label = 0;
_labelGetter(ref label);
if (float.IsNaN(label))
{
NumUnlabeledInstances++;
return;
}
if (label < 0 || label != (int)label)
{
_numNegOrNonIntegerLabels++;
return;
}
_scoreGetter(ref _scores);
Host.Check(_scores.Length == _scoresArr.Length);
if (VBufferUtils.HasNaNs(in _scores) || VBufferUtils.HasNonFinite(in _scores))
{
NumBadScores++;
return;
}
_scores.CopyTo(_scoresArr);
float weight = 1;
if (_weightGetter != null)
{
_weightGetter(ref weight);
if (!FloatUtils.IsFinite(weight))
{
NumBadWeights++;
weight = 1;
}
}
var intLabel = (int)label;
var wasKnownLabel = true;
// log-loss
double logloss;
if (intLabel < _scoresArr.Length)
{
// REVIEW: This assumes that the predictions are probabilities, not just relative scores
// for the classes. Is this a correct assumption?
float p = Math.Min(1, Math.Max(Epsilon, _scoresArr[intLabel]));
logloss = -Math.Log(p);
}
else
{
// Penalize logloss if the label was not seen during training
logloss = -Math.Log(Epsilon);
_numUnknownClassInstances++;
wasKnownLabel = false;
}
// Get the probability that the model assigned to the CORRECT label (the best case is that it is the highest probability):
var correctProba = !wasKnownLabel ? 0 : _scoresArr[intLabel];
// Find the rank of the *correct* label (in _scoresArr[]). If the correct (ground truth) label gets rank 0,
// it means the model assigned it the highest probability (that's ideal). Rank 1 would mean our model
// gives the real label the 2nd highest probability, etc.
// The rank will be from 0 to N. (Not N-1). Rank N is used for unrecognized values.
//
// Tie breaking: What if we have probabilities that are equal to the correct prediction (eg, a:0.1, b:0.1,
// c:0.1, d:0.6, e:0.1 where c is the correct label).
// This actually happens a lot with some models. We handle ties by assigning rank in order of first
// appearance. In this example, we assign c the rank of 3, because d has a higher probability and a and b
// appear before c.
int rankOfCorrectLabel = 0;
int assigned = 0;
for (int i = 0; i < _scoresArr.Length; i++)
{
if (_scoresArr[i] > correctProba || (_scoresArr[i] == correctProba && i < intLabel))
rankOfCorrectLabel++;
// Track the model's prediction: the class with the highest probability (the earliest such index wins ties).
if (_scoresArr[assigned] < _scoresArr[i])
assigned = i;
}
UnweightedCounters.Update(rankOfCorrectLabel, assigned, logloss, intLabel, 1);
if (WeightedCounters != null)
WeightedCounters.Update(rankOfCorrectLabel, assigned, logloss, intLabel, weight);
}
protected override List<string> GetWarningsCore()
{
var warnings = base.GetWarningsCore();
if (_numUnknownClassInstances > 0)
{
warnings.Add(string.Format(
"Found {0} test instances with class values not seen in the training set. LogLoss is reported higher than usual because of these instances.",
_numUnknownClassInstances));
}
if (_numNegOrNonIntegerLabels > 0)
{
warnings.Add(string.Format(
"Found {0} test instances with labels that are either negative or non integers. These instances were ignored",
_numNegOrNonIntegerLabels));
}
return warnings;
}
public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> slotNames)
{
var editor = VBufferEditor.Create(ref slotNames, ClassNames.Length);
for (int i = 0; i < ClassNames.Length; i++)
editor.Values[i] = string.Format("(class {0})", ClassNames[i]).AsMemory();
slotNames = editor.Commit();
}
}
/// <summary>
/// Evaluates scored multiclass classification data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The evaluation results for these outputs.</returns>
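/// <example>
/// A minimal usage sketch (assuming <c>env</c> is an <see cref="IHostEnvironment"/> and <c>scored</c> is an
/// <see cref="IDataView"/> containing "Label", "Score" and "PredictedLabel" columns):
/// <code>
/// var evaluator = new MulticlassClassificationEvaluator(env, new MulticlassClassificationEvaluator.Arguments { OutputTopKAcc = 3 });
/// var metrics = evaluator.Evaluate(scored, "Label", "Score", "PredictedLabel");
/// </code>
/// </example>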
public MulticlassClassificationMetrics Evaluate(IDataView data, string label, string score, string predictedLabel)
{
Host.CheckValue(data, nameof(data));
Host.CheckNonEmpty(label, nameof(label));
Host.CheckNonEmpty(score, nameof(score));
Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));
var roles = new RoleMappedData(data, opt: false,
RoleMappedSchema.ColumnRole.Label.Bind(label),
RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score),
RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));
var resultDict = ((IEvaluator)this).Evaluate(roles);
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];
MulticlassClassificationMetrics result;
using (var cursor = overall.GetRowCursorForAllColumns())
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new MulticlassClassificationMetrics(Host, cursor, _outputTopKAcc ?? 0, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}
return result;
}
}
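// Produces the per-instance output columns: the predicted class ("Assigned"), the per-instance
// log-loss ("Log-loss"), and the scores and class indices sorted by decreasing score
// ("SortedScores" and "SortedClasses").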
internal sealed class MulticlassPerInstanceEvaluator : PerInstanceEvaluatorBase
{
public const string LoaderSignature = "MulticlassPerInstance";
private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "MLTIINST",
//verWrittenCur: 0x00010001, // Initial
verWrittenCur: 0x00010002, // Serialize the class names
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(MulticlassPerInstanceEvaluator).Assembly.FullName);
}
private const int AssignedCol = 0;
private const int LogLossCol = 1;
private const int SortedScoresCol = 2;
private const int SortedClassesCol = 3;
private const uint VerInitial = 0x00010001;
public const string Assigned = "Assigned";
public const string LogLoss = "Log-loss";
public const string SortedScores = "SortedScores";
public const string SortedClasses = "SortedClasses";
private const float Epsilon = (float)1e-15;
private readonly int _numClasses;
private readonly ReadOnlyMemory<char>[] _classNames;
private readonly DataViewType[] _types;
public MulticlassPerInstanceEvaluator(IHostEnvironment env, DataViewSchema schema, DataViewSchema.Column scoreColumn, string labelCol)
: base(env, schema, scoreColumn.Name, labelCol)
{
CheckInputColumnTypes(schema);
_numClasses = scoreColumn.Type.GetVectorSize();
_types = new DataViewType[4];
if (schema[ScoreIndex].HasSlotNames(_numClasses))
{
var classNames = default(VBuffer<ReadOnlyMemory<char>>);
schema[(int)ScoreIndex].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref classNames);
_classNames = new ReadOnlyMemory<char>[_numClasses];
classNames.CopyTo(_classNames);
}
else
_classNames = Utils.BuildArray(_numClasses, i => i.ToString().AsMemory());
var key = new KeyDataViewType(typeof(uint), _numClasses);
_types[AssignedCol] = key;
_types[LogLossCol] = NumberDataViewType.Double;
_types[SortedScoresCol] = new VectorDataViewType(NumberDataViewType.Single, _numClasses);
_types[SortedClassesCol] = new VectorDataViewType(key, _numClasses);
}
private MulticlassPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
: base(env, ctx, schema)
{
CheckInputColumnTypes(schema);
// *** Binary format ***
// base
// int: number of classes
// int[]: Ids of the class names
_numClasses = ctx.Reader.ReadInt32();
Host.CheckDecode(_numClasses > 0);
if (ctx.Header.ModelVerWritten > VerInitial)
{
_classNames = new ReadOnlyMemory<char>[_numClasses];
for (int i = 0; i < _numClasses; i++)
_classNames[i] = ctx.LoadNonEmptyString().AsMemory();
}
else
_classNames = Utils.BuildArray(_numClasses, i => i.ToString().AsMemory());
_types = new DataViewType[4];
var key = new KeyDataViewType(typeof(uint), _numClasses);
_types[AssignedCol] = key;
_types[LogLossCol] = NumberDataViewType.Double;
_types[SortedScoresCol] = new VectorDataViewType(NumberDataViewType.Single, _numClasses);
_types[SortedClassesCol] = new VectorDataViewType(key, _numClasses);
}
public static MulticlassPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
return new MulticlassPerInstanceEvaluator(env, ctx, schema);
}
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
// base
// int: number of classes
// int[]: Ids of the class names
base.SaveModel(ctx);
Host.Assert(_numClasses > 0);
ctx.Writer.Write(_numClasses);
for (int i = 0; i < _numClasses; i++)
ctx.SaveNonEmptyString(_classNames[i].ToString());
}
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
Host.Assert(ScoreIndex >= 0);
Host.Assert(LabelIndex >= 0);
// The score column is needed if any of the outputs are active. The label column is needed only
// if the log-loss output is active.
return
col =>
col == LabelIndex && activeOutput(LogLossCol) ||
col == ScoreIndex && (activeOutput(AssignedCol) || activeOutput(SortedScoresCol) ||
activeOutput(SortedClassesCol) || activeOutput(LogLossCol));
}
private protected override Delegate[] CreateGettersCore(DataViewRow input, Func<int, bool> activeCols, out Action disposer)
{
disposer = null;
var getters = new Delegate[4];
if (!activeCols(AssignedCol) && !activeCols(SortedClassesCol) && !activeCols(SortedScoresCol) && !activeCols(LogLossCol))
return getters;
long cachedPosition = -1;
VBuffer<float> scores = default(VBuffer<float>);
float label = 0;
var scoresArr = new float[_numClasses];
int[] sortedIndices = new int[_numClasses];
var labelGetter = activeCols(LogLossCol) ? RowCursorUtils.GetLabelGetter(input, LabelIndex) :
(ref float dst) => dst = float.NaN;
var scoreGetter = input.GetGetter<VBuffer<float>>(input.Schema[ScoreIndex]);
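// All getters share the cached label, scores and sort order; refresh them only when the cursor
// has moved to a new row.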
Action updateCacheIfNeeded =
() =>
{
if (cachedPosition != input.Position)
{
labelGetter(ref label);
scoreGetter(ref scores);
scores.CopyTo(scoresArr);
int j = 0;
foreach (var index in Enumerable.Range(0, scoresArr.Length).OrderByDescending(i => scoresArr[i]))
sortedIndices[j++] = index;
cachedPosition = input.Position;
}
};
if (activeCols(AssignedCol))
{
ValueGetter<uint> assignedFn =
(ref uint dst) =>
{
updateCacheIfNeeded();
dst = (uint)sortedIndices[0] + 1;
};
getters[AssignedCol] = assignedFn;
}
if (activeCols(SortedScoresCol))
{
ValueGetter<VBuffer<float>> topKScoresFn =
(ref VBuffer<float> dst) =>
{
updateCacheIfNeeded();
var editor = VBufferEditor.Create(ref dst, _numClasses);
for (int i = 0; i < _numClasses; i++)
editor.Values[i] = scores.GetItemOrDefault(sortedIndices[i]);
dst = editor.Commit();
};
getters[SortedScoresCol] = topKScoresFn;
}
if (activeCols(SortedClassesCol))
{
ValueGetter<VBuffer<uint>> topKClassesFn =
(ref VBuffer<uint> dst) =>
{
updateCacheIfNeeded();
var editor = VBufferEditor.Create(ref dst, _numClasses);
for (int i = 0; i < _numClasses; i++)
editor.Values[i] = (uint)sortedIndices[i] + 1;
dst = editor.Commit();
};
getters[SortedClassesCol] = topKClassesFn;
}
if (activeCols(LogLossCol))
{
ValueGetter<double> logLossFn =
(ref double dst) =>
{
updateCacheIfNeeded();
if (float.IsNaN(label))
{
dst = double.NaN;
return;
}
int intLabel = (int)label;
if (intLabel < _numClasses)
{
float p = Math.Min(1, Math.Max(Epsilon, scoresArr[intLabel]));
dst = -Math.Log(p);
return;
}
// Penalize logloss if the label was not seen during training
dst = -Math.Log(Epsilon);
};
getters[LogLossCol] = logLossFn;
}
return getters;
}
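// Builds the schema of the four output columns, attaching the class names as key values and
// adding descriptive slot names to the sorted-score and sorted-class vectors.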
private protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var infos = new DataViewSchema.DetachedColumn[4];
var assignedColKeyValues = new DataViewSchema.Annotations.Builder();
assignedColKeyValues.AddKeyValues(_numClasses, TextDataViewType.Instance, CreateKeyValueGetter());
infos[AssignedCol] = new DataViewSchema.DetachedColumn(Assigned, _types[AssignedCol], assignedColKeyValues.ToAnnotations());
infos[LogLossCol] = new DataViewSchema.DetachedColumn(LogLoss, _types[LogLossCol], null);
var sortedScores = new DataViewSchema.Annotations.Builder();
sortedScores.AddSlotNames(_numClasses, CreateSlotNamesGetter(_numClasses, "Score"));
var sortedClasses = new DataViewSchema.Annotations.Builder();
sortedClasses.AddSlotNames(_numClasses, CreateSlotNamesGetter(_numClasses, "Class"));
sortedClasses.AddKeyValues(_numClasses, TextDataViewType.Instance, CreateKeyValueGetter());
infos[SortedScoresCol] = new DataViewSchema.DetachedColumn(SortedScores, _types[SortedScoresCol], sortedScores.ToAnnotations());
infos[SortedClassesCol] = new DataViewSchema.DetachedColumn(SortedClasses, _types[SortedClassesCol], sortedClasses.ToAnnotations());
return infos;
}
// REVIEW: Figure out how to avoid having the column name in each slot name.
private ValueGetter<VBuffer<ReadOnlyMemory<char>>> CreateSlotNamesGetter(int numTopClasses, string suffix)
{
return
(ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
var editor = VBufferEditor.Create(ref dst, numTopClasses);
for (int i = 1; i <= numTopClasses; i++)
editor.Values[i - 1] = string.Format("#{0} {1}", i, suffix).AsMemory();
dst = editor.Commit();
};
}
private ValueGetter<VBuffer<ReadOnlyMemory<char>>> CreateKeyValueGetter()
{
return
(ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
var editor = VBufferEditor.Create(ref dst, _numClasses);
for (int i = 0; i < _numClasses; i++)
editor.Values[i] = _classNames[i];
dst = editor.Commit();
};
}
private void CheckInputColumnTypes(DataViewSchema schema)
{
Host.AssertNonEmpty(ScoreCol);
Host.AssertNonEmpty(LabelCol);
var scoreType = schema[ScoreIndex].Type as VectorDataViewType;
if (scoreType == null || scoreType.Size < 2 || scoreType.ItemType != NumberDataViewType.Single)
throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol, "Vector of two or more items of type Single", schema[ScoreIndex].Type.ToString());
var labelType = schema[LabelIndex].Type;
if (labelType != NumberDataViewType.Single && labelType.GetKeyCount() <= 0)
throw Host.ExceptSchemaMismatch(nameof(schema), "label", LabelCol, "Single or Key", labelType.ToString());
}
}
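// Command-line (MAML) wrapper around MulticlassClassificationEvaluator. Adds reporting options such
// as the number of top classes to print and the maximum number of classes shown in the confusion matrix.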
[BestFriend]
internal sealed class MulticlassClassificationMamlEvaluator : MamlEvaluatorBase
{
public class Arguments : ArgumentsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Output top-K accuracy.", ShortName = "topkacc")]
public int? OutputTopKAcc;
[Argument(ArgumentType.AtMostOnce, HelpText = "Output top-K classes.", ShortName = "topk")]
public int NumTopClassesToOutput = 3;
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of classes in confusion matrix.", ShortName = "nccf")]
public int NumClassesConfusionMatrix = 10;
[Argument(ArgumentType.AtMostOnce, HelpText = "Output per class statistics and confusion matrix.", ShortName = "opcs")]
public bool OutputPerClassStatistics = false;
}
private const string TopKAccuracyFormat = "Top-{0}-accuracy";
private readonly bool _outputPerClass;
private readonly int _numTopClasses;
private readonly int _numConfusionTableClasses;
private readonly int? _outputTopKAcc;
private readonly MulticlassClassificationEvaluator _evaluator;
private protected override IEvaluator Evaluator => _evaluator;
public MulticlassClassificationMamlEvaluator(IHostEnvironment env, Arguments args)
: base(args, env, AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification, "MultiClassMamlEvaluator")
{
Host.CheckValue(args, nameof(args));
// REVIEW: why do we need to insist on at least 2?
Host.CheckUserArg(2 <= args.NumTopClassesToOutput, nameof(args.NumTopClassesToOutput));
Host.CheckUserArg(2 <= args.NumClassesConfusionMatrix, nameof(args.NumClassesConfusionMatrix));
Host.CheckUserArg(args.OutputTopKAcc == null || args.OutputTopKAcc > 0, nameof(args.OutputTopKAcc));
_numTopClasses = args.NumTopClassesToOutput;
_outputPerClass = args.OutputPerClassStatistics;
_numConfusionTableClasses = args.NumClassesConfusionMatrix;
_outputTopKAcc = args.OutputTopKAcc;
var evalArgs = new MulticlassClassificationEvaluator.Arguments
{
OutputTopKAcc = _outputTopKAcc
};
_evaluator = new MulticlassClassificationEvaluator(Host, evalArgs);
}
private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
{
Host.AssertValue(metrics);
if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out IDataView fold))
throw ch.Except("No overall metrics found");
if (!metrics.TryGetValue(MetricKinds.ConfusionMatrix, out IDataView conf))
throw ch.Except("No confusion matrix found");
// Change the name of the Top-k-accuracies collection column & remove redundant old TopK output
if (_outputTopKAcc != null)
{
fold = ChangeAllTopKAccColumnName(fold);
fold = DropColumn(fold, MulticlassClassificationEvaluator.TopKAccuracy);
}
// Drop the per-class information.
if (!_outputPerClass)
fold = DropColumn(fold, MulticlassClassificationEvaluator.PerClassLogLoss);
var unweightedConf = MetricWriter.GetConfusionTableAsFormattedString(Host, conf, out string weightedConf, false, _numConfusionTableClasses);
var unweightedFold = MetricWriter.GetPerFoldResults(Host, fold, out string weightedFold);
ch.Assert(string.IsNullOrEmpty(weightedConf) == string.IsNullOrEmpty(weightedFold));
if (!string.IsNullOrEmpty(weightedConf))
{
ch.Info(weightedConf);
ch.Info(weightedFold);
}
ch.Info(unweightedConf);
ch.Info(unweightedFold);
}
private protected override IDataView CombineOverallMetricsCore(IDataView[] metrics)
{
var overallList = new List<IDataView>();
for (int i = 0; i < metrics.Length; i++)
{
var idv = metrics[i];
// Change the name of the Top-k-accuracies collection column & remove redundant old TopK output
if (_outputTopKAcc != null)
{
idv = ChangeAllTopKAccColumnName(idv);
idv = DropColumn(idv, MulticlassClassificationEvaluator.TopKAccuracy);
}
if (!_outputPerClass)
idv = DropColumn(idv, MulticlassClassificationEvaluator.PerClassLogLoss);
overallList.Add(idv);
}
var views = overallList.ToArray();
if (_outputPerClass)
{
EvaluateUtils.ReconcileSlotNames<double>(Host, views, MulticlassClassificationEvaluator.PerClassLogLoss, NumberDataViewType.Double,
def: double.NaN);
for (int i = 0; i < overallList.Count; i++)
{
var idv = views[i];
// Find the old per-class log-loss column and drop it.
for (int col = 0; col < idv.Schema.Count; col++)
{
if (idv.Schema[col].IsHidden &&
idv.Schema[col].Name.Equals(MulticlassClassificationEvaluator.PerClassLogLoss))
{
idv = new ChooseColumnsByIndexTransform(Host,
new ChooseColumnsByIndexTransform.Options() { Drop = true, Indices = new[] { col } }, idv);
break;
}
}
views[i] = idv;
}
}
return base.CombineOverallMetricsCore(views);
}
private IDataView ChangeTopKAccColumnName(IDataView input)
{
input = new ColumnCopyingTransformer(Host, (string.Format(TopKAccuracyFormat, _outputTopKAcc), MulticlassClassificationEvaluator.TopKAccuracy)).Transform(input);
return ColumnSelectingTransformer.CreateDrop(Host, input, MulticlassClassificationEvaluator.TopKAccuracy);
}
private IDataView ChangeAllTopKAccColumnName(IDataView input)
{
input = new ColumnCopyingTransformer(Host, (TopKAccuracyFormat, MulticlassClassificationEvaluator.AllTopKAccuracy)).Transform(input);
return ColumnSelectingTransformer.CreateDrop(Host, input, MulticlassClassificationEvaluator.AllTopKAccuracy);
}
private IDataView DropColumn(IDataView input, string columnToDrop)
{
if (input.Schema.TryGetColumnIndex(columnToDrop, out _))
{
input = ColumnSelectingTransformer.CreateDrop(Host, input, columnToDrop);
}
return input;
}
public override IEnumerable<MetricColumn> GetOverallMetricColumns()
{
yield return new MetricColumn("AccuracyMicro", MulticlassClassificationEvaluator.AccuracyMicro);
yield return new MetricColumn("AccuracyMacro", MulticlassClassificationEvaluator.AccuracyMacro);
yield return new MetricColumn("TopKAccuracy", string.Format(TopKAccuracyFormat, _outputTopKAcc));
if (_outputPerClass)
{
yield return new MetricColumn("LogLoss<class name>",
MulticlassClassificationEvaluator.PerClassLogLoss, MetricColumn.Objective.Minimize, isVector: true,
namePattern: new Regex(string.Format(@"^{0}(?<class>.+)", MulticlassClassificationEvaluator.LogLoss), RegexOptions.IgnoreCase));
}
yield return new MetricColumn("LogLoss", MulticlassClassificationEvaluator.LogLoss, MetricColumn.Objective.Minimize);
yield return new MetricColumn("LogLossReduction", MulticlassClassificationEvaluator.LogLossReduction);
yield return new MetricColumn("TopKAccuracyForAllK", MulticlassClassificationEvaluator.AllTopKAccuracy, isVector: true);
}
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column");
// Output the label column.
yield return schema.Label.Value.Name;
// Return the output columns.
yield return MulticlassPerInstanceEvaluator.Assigned;
yield return MulticlassPerInstanceEvaluator.LogLoss;
yield return MulticlassPerInstanceEvaluator.SortedScores;
yield return MulticlassPerInstanceEvaluator.SortedClasses;
}
// The multiclass evaluator adds four per-instance columns: "Assigned", "Log-loss", "SortedScores" and "SortedClasses".
private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
{
// If the label column is a key without text key values, convert it to double, just for saving the per-instance
// text file, since if there are different key counts the columns cannot be appended.
string labelName = schema.Label.Value.Name;
if (!perInst.Schema.TryGetColumnIndex(labelName, out int labelColIndex))
throw Host.ExceptSchemaMismatch(nameof(schema), "label", labelName);
var labelCol = perInst.Schema[labelColIndex];
var labelType = labelCol.Type;
if (labelType is KeyDataViewType && (!labelCol.HasKeyValues() || labelType.RawType != typeof(uint)))
{
perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, labelName,
labelName, labelCol.Type, NumberDataViewType.Double,
(in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1);
}
var perInstSchema = perInst.Schema;
if (perInstSchema.TryGetColumnIndex(MulticlassPerInstanceEvaluator.SortedClasses, out int sortedClassesIndex))
{
var type = perInstSchema[sortedClassesIndex].Type;
// Wrap with a DropSlots transform to pick only the first _numTopClasses slots.
if (_numTopClasses < type.GetVectorSize())
perInst = new SlotsDroppingTransformer(Host, MulticlassPerInstanceEvaluator.SortedClasses, min: _numTopClasses).Transform(perInst);
}
// Wrap with a DropSlots transform to pick only the first _numTopClasses slots.
if (perInst.Schema.TryGetColumnIndex(MulticlassPerInstanceEvaluator.SortedScores, out int sortedScoresIndex))
{
var type = perInst.Schema[sortedScoresIndex].Type;
if (_numTopClasses < type.GetVectorSize())
perInst = new SlotsDroppingTransformer(Host, MulticlassPerInstanceEvaluator.SortedScores, min: _numTopClasses).Transform(perInst);
}
return perInst;
}
}
internal static partial class Evaluate
{
[TlcModule.EntryPoint(Name = "Models.ClassificationEvaluator", Desc = "Evaluates a multi class classification scored dataset.")]
public static CommonOutputs.ClassificationEvaluateOutput Multiclass(IHostEnvironment env, MulticlassClassificationMamlEvaluator.Arguments input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("EvaluateMultiClass");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
MatchColumns(host, input, out string label, out string weight, out string name);
IMamlEvaluator evaluator = new MulticlassClassificationMamlEvaluator(host, input);
var data = new RoleMappedData(input.Data, label, null, null, weight, name);
var metrics = evaluator.Evaluate(data);
var warnings = ExtractWarnings(host, metrics);
var overallMetrics = ExtractOverallMetrics(host, metrics, evaluator);
var perInstanceMetrics = evaluator.GetPerInstanceMetrics(data);
var confusionMatrix = ExtractConfusionMatrix(host, metrics);
return new CommonOutputs.ClassificationEvaluateOutput()
{
Warnings = warnings,
OverallMetrics = overallMetrics,
PerInstanceMetrics = perInstanceMetrics,
ConfusionMatrix = confusionMatrix
};
}
}
}