|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(typeof(ClusteringEvaluator), typeof(ClusteringEvaluator), typeof(ClusteringEvaluator.Arguments), typeof(SignatureEvaluator),
"Clustering Evaluator", ClusteringEvaluator.LoadName, "Clustering")]
[assembly: LoadableClass(typeof(ClusteringMamlEvaluator), typeof(ClusteringMamlEvaluator), typeof(ClusteringMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
"Clustering Evaluator", ClusteringEvaluator.LoadName, "Clustering")]
// This is for deserialization of the per-instance transform.
[assembly: LoadableClass(typeof(ClusteringPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper),
"", ClusteringPerInstanceEvaluator.LoaderSignature)]
namespace Microsoft.ML.Data
{
using Conditional = System.Diagnostics.ConditionalAttribute;
[BestFriend]
internal sealed class ClusteringEvaluator : RowToRowEvaluatorBase<ClusteringEvaluator.Aggregator>
{
/// <summary>
/// Options accepted by <see cref="ClusteringEvaluator"/>.
/// </summary>
public sealed class Arguments
{
// When set, the evaluator also produces the DBI metric, which requires the feature column.
[Argument(ArgumentType.AtMostOnce, HelpText = "Calculate DBI? (time-consuming unsupervised metric)",
ShortName = "dbi")]
public bool CalculateDbi = false;
}
// Name under which this evaluator is registered as a loadable class (see the assembly attributes).
public const string LoadName = "ClusteringEvaluator";
// Names of the overall-metric columns this evaluator writes.
public const string Nmi = "NMI";
public const string AvgMinScore = "AvgMinScore";
public const string Dbi = "DBI";
// Whether the DBI metric is computed; copied from Arguments.CalculateDbi in the constructor.
private readonly bool _calculateDbi;
/// <summary>
/// Creates a clustering evaluator configured by <paramref name="args"/>.
/// </summary>
/// <param name="env">The host environment, forwarded to the base evaluator together with <see cref="LoadName"/>.</param>
/// <param name="args">The evaluator options; only <see cref="Arguments.CalculateDbi"/> is read.</param>
public ClusteringEvaluator(IHostEnvironment env, Arguments args)
    : base(env, LoadName)
{
    Host.AssertValue(args, nameof(args));

    // Remember whether the DBI metric has to be produced.
    _calculateDbi = args.CalculateDbi;
}
/// <summary>
/// Evaluates scored clustering data and returns the overall metrics.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
/// <param name="label">The name of the optional label column in <paramref name="data"/>.</param>
/// <param name="features">The name of the optional feature column in <paramref name="data"/>.</param>
/// <returns>The evaluation results, read from the single row of the overall metrics view.</returns>
public ClusteringMetrics Evaluate(IDataView data, string score, string label = null, string features = null)
{
    Host.CheckValue(data, nameof(data));
    Host.CheckNonEmpty(score, nameof(score));

    // The score role is always bound; label and feature roles only when a column name was supplied.
    var columnRoles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>
    {
        RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score)
    };
    if (label != null)
        columnRoles.Add(RoleMappedSchema.ColumnRole.Label.Bind(label));
    if (features != null)
        columnRoles.Add(RoleMappedSchema.ColumnRole.Feature.Bind(features));

    var mappedData = new RoleMappedData(data, opt: false, columnRoles.ToArray());
    var metricViews = ((IEvaluator)this).Evaluate(mappedData);
    Host.Assert(metricViews.ContainsKey(MetricKinds.OverallMetrics));
    var overallView = metricViews[MetricKinds.OverallMetrics];

    // The overall metrics view is expected to hold exactly one row.
    using (var cursor = overallView.GetRowCursorForAllColumns())
    {
        bool hasRow = cursor.MoveNext();
        Host.Assert(hasRow);
        var metrics = new ClusteringMetrics(Host, cursor, _calculateDbi);
        bool hasExtraRow = cursor.MoveNext();
        Host.Assert(!hasExtraRow);
        return metrics;
    }
}
/// <summary>
/// Validates the label column (if any) and the score column of <paramref name="schema"/>.
/// The label must be Single or a key type with a positive count; the score must be a
/// known-size vector of Single. Throws a schema-mismatch exception otherwise.
/// </summary>
private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
    // The label is optional, so a missing label column is accepted.
    DataViewType labelType = schema.Label?.Type;
    if (labelType != null)
    {
        bool isSingle = labelType == NumberDataViewType.Single;
        bool isKnownKey = labelType is KeyDataViewType keyType && keyType.Count > 0;
        if (!isSingle && !isKnownKey)
        {
            throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name,
                "Single or Key", labelType.ToString());
        }
    }

    var scoreColumn = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
    DataViewType scoreType = scoreColumn.Type;
    var scoreVector = scoreType as VectorDataViewType;
    bool isValidScore = scoreVector != null
        && scoreVector.IsKnownSize
        && scoreVector.ItemType == NumberDataViewType.Single;
    if (!isValidScore)
        throw Host.ExceptSchemaMismatch(nameof(schema), "score", scoreColumn.Name, "known-size vector of Single", scoreType.ToString());
}
/// <summary>
/// Validates the feature column, which is only needed when DBI is computed: it must be a
/// known-size vector of Single. Throws a schema-mismatch exception otherwise.
/// </summary>
private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
{
    // Without DBI there is no custom column to validate.
    if (!_calculateDbi)
        return;

    Host.Assert(schema.Feature.HasValue);
    var featureColumn = schema.Feature.Value;
    DataViewType featureType = featureColumn.Type;
    var featureVector = featureType as VectorDataViewType;
    if (featureVector == null || !featureVector.IsKnownSize || featureVector.ItemType != NumberDataViewType.Single)
    {
        throw Host.ExceptSchemaMismatch(nameof(schema), "features", featureColumn.Name,
            "R4 vector of known size", featureType.ToString());
    }
}
/// <summary>
/// Returns the predicate selecting the input columns the evaluator reads: whatever the base
/// class needs, plus the feature column when DBI is computed.
/// </summary>
private protected override Func<int, bool> GetActiveColsCore(RoleMappedSchema schema)
{
    Func<int, bool> baseActive = base.GetActiveColsCore(schema);

    // DBI is computed from the feature vectors, so the feature column must be present.
    Host.Assert(!_calculateDbi || schema.Feature != null);

    return col =>
    {
        // The feature index is looked up lazily, and only when DBI is requested.
        if (_calculateDbi && col == schema.Feature.Value.Index)
            return true;
        return baseActive(col);
    };
}
/// <summary>
/// Creates the aggregator for one stratum. The number of clusters is taken from the size
/// of the score vector column.
/// </summary>
private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
    Host.AssertValue(schema);
    Host.Assert(!_calculateDbi || schema.Feature?.Type.IsKnownSizeVector() == true);

    var scoreColumn = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
    var scoreVector = scoreColumn.Type as VectorDataViewType;
    Host.Assert(scoreVector != null && scoreVector.Size > 0);

    // One score slot per cluster.
    int clusterCount = scoreVector.Size;
    bool isWeighted = schema.Weight != null;
    return new Aggregator(Host, schema.Feature, clusterCount, _calculateDbi, isWeighted, stratName);
}
/// <summary>
/// Creates the per-instance row mapper, using the score column's name and vector size
/// (the number of clusters).
/// </summary>
private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema)
{
    var scoreColumn = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
    int clusterCount = scoreColumn.Type.GetVectorSize();
    var mapper = new ClusteringPerInstanceEvaluator(Host, schema.Schema, scoreColumn.Name, clusterCount);
    return mapper;
}
/// <summary>
/// Enumerates the overall metric columns: NMI (default objective), and AvgMinScore and DBI,
/// both declared as metrics to minimize.
/// </summary>
public override IEnumerable<MetricColumn> GetOverallMetricColumns()
{
    var nmiColumn = new MetricColumn("NMI", Nmi);
    yield return nmiColumn;

    var avgMinScoreColumn = new MetricColumn("AvgMinScore", AvgMinScore, MetricColumn.Objective.Minimize);
    yield return avgMinScoreColumn;

    var dbiColumn = new MetricColumn("DBI", Dbi, MetricColumn.Objective.Minimize);
    yield return dbiColumn;
}
/// <summary>
/// Produces the two callbacks used to consolidate aggregators into the overall metrics view.
/// <paramref name="addAgg"/> appends one row per aggregator for the unweighted counters, plus
/// one for the weighted counters when the aggregator is weighted. <paramref name="consolidate"/>
/// turns the rows collected so far into a data view keyed by <c>MetricKinds.OverallMetrics</c>.
/// </summary>
private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
    out Action<uint, ReadOnlyMemory<char>, Aggregator> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
{
    // Column-wise storage of the rows collected by addAgg.
    var stratKeys = new List<uint>();
    var stratValues = new List<ReadOnlyMemory<char>>();
    var weightedFlags = new List<bool>();
    var nmiValues = new List<Double>();
    var avgMinScoreValues = new List<Double>();
    var dbiValues = new List<Double>();

    bool hasStrats = Utils.Size(dictionaries) > 0;
    bool hasWeight = aggregator.Weighted;

    // Appends one output row holding the metrics of the given counters.
    void AppendRow(uint key, ReadOnlyMemory<char> value, bool weighted, Aggregator.Counters counters)
    {
        stratKeys.Add(key);
        stratValues.Add(value);
        weightedFlags.Add(weighted);
        nmiValues.Add(counters.Nmi);
        avgMinScoreValues.Add(counters.AvgMinScores);
        if (counters.CalculateDbi)
            dbiValues.Add(counters.Dbi);
    }

    addAgg = (stratColKey, stratColVal, agg) =>
    {
        // Every aggregator must match the reference aggregator so that all columns have equal length.
        Host.Check(agg.Weighted == hasWeight, "All aggregators must either be weighted or unweighted");
        Host.Check(agg.UnweightedCounters.CalculateDbi == aggregator.UnweightedCounters.CalculateDbi,
            "All aggregators must either compute DBI or not compute DBI");

        AppendRow(stratColKey, stratColVal, false, agg.UnweightedCounters);
        if (agg.Weighted)
            AppendRow(stratColKey, stratColVal, true, agg.WeightedCounters);
    };

    consolidate = () =>
    {
        var builder = new ArrayDataViewBuilder(Host);

        // Stratification columns exist only when stratification dictionaries were supplied.
        if (hasStrats)
        {
            builder.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, stratKeys.ToArray());
            builder.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, stratValues.ToArray());
        }
        if (hasWeight)
            builder.AddColumn(MetricKinds.ColumnNames.IsWeighted, BooleanDataViewType.Instance, weightedFlags.ToArray());

        builder.AddColumn(Nmi, NumberDataViewType.Double, nmiValues.ToArray());
        builder.AddColumn(AvgMinScore, NumberDataViewType.Double, avgMinScoreValues.ToArray());
        if (aggregator.UnweightedCounters.CalculateDbi)
            builder.AddColumn(Dbi, NumberDataViewType.Double, dbiValues.ToArray());

        return new Dictionary<string, IDataView>
        {
            { MetricKinds.OverallMetrics, builder.GetDataView() }
        };
    };
}
public sealed class Aggregator : AggregatorBase
{
public sealed class Counters
{
// Total weight of the instances accumulated by UpdateFirstPass.
private Double _numInstances;
// Weighted sum, over instances, of the score of the assigned cluster (the first sorted index).
private Double _sumMinScores;
// Since we know in advance how many clusters we have, this can be an array.
// Indexed by cluster: total weight assigned to each cluster.
private readonly Double[] _numInstancesOfClstr;
// We don't know how many classes there will be, so this is a list, that grows when we see new classes.
// Indexed by integer label: total weight of each class.
private readonly List<Double> _numInstancesOfClass;
// Indexed [label][cluster]: total weight of instances with that label assigned to that cluster.
private readonly List<Double[]> _confusionMatrix;
// These are used for DBI calculation.
// Both are allocated only when CalculateDbi is true; otherwise they stay null.
private readonly VBuffer<Single>[] _clusterCentroids;
// Indexed by cluster: sum of distances from assigned instances to the cluster centroid.
private readonly Double[] _distancesToCentroids;
private readonly int _numClusters;
// Whether the DBI metric is computed by these counters.
public readonly bool CalculateDbi;
/// <summary>
/// Mutual information between labels and assigned clusters, divided by the entropy of the
/// label distribution. Returns NaN when the confusion matrix has fewer than two label rows.
/// </summary>
public Double Nmi
{
    get
    {
        if (_confusionMatrix.Count <= 1)
            return Double.NaN;

        Double mutualInfo = 0;
        Double labelEntropy = 0;
        for (int cls = 0; cls < _confusionMatrix.Count; cls++)
        {
            // Probability of the label; labels with no weight contribute nothing.
            Double pClass = _numInstancesOfClass[cls] / _numInstances;
            if (pClass <= 0)
                continue;

            Double[] row = _confusionMatrix[cls];
            for (int clstr = 0; clstr < row.Length; clstr++)
            {
                Double pJoint = row[clstr] / _numInstances;
                Double pCluster = _numInstancesOfClstr[clstr] / _numInstances;
                if (pJoint <= 0 || pCluster <= 0)
                    continue;
                mutualInfo += pJoint * Math.Log(pJoint / (pClass * pCluster));
            }
            labelEntropy += -pClass * Math.Log(pClass);
        }

        // The entropy is expected to be non-zero here (at least two populated classes).
        // NOTE(review): if only one label carries weight while the matrix has several rows
        // (e.g. only label 1 was seen), the entropy is zero and this yields NaN.
        return mutualInfo / labelEntropy;
    }
}
/// <summary>
/// Accumulated score of the assigned cluster divided by the accumulated instance weight.
/// </summary>
public Double AvgMinScores => _sumMinScores / _numInstances;
/// <summary>
/// The DBI metric: for every non-empty cluster, the largest ratio of
/// (average distance to centroid of cluster i + that of cluster j) to the distance between
/// the two centroids, taken over all other non-empty clusters j; these maxima are summed
/// and divided by the total number of clusters. Returns NaN when
/// <see cref="CalculateDbi"/> is false.
/// </summary>
/// <remarks>
/// Reading this property does not modify the counters, so repeated reads return the same value.
/// </remarks>
public Double Dbi
{
    get
    {
        if (!CalculateDbi)
            return Double.NaN;

        int clusterCount = _distancesToCentroids.Length;

        // Average the accumulated distances into a local array. Dividing the field in place
        // would divide again on every read, so the accumulated sums are left untouched.
        var avgDistances = new Double[clusterCount];
        for (int i = 0; i < clusterCount; i++)
            avgDistances[i] = _distancesToCentroids[i] / _numInstancesOfClstr[i];

        Double dbi = 0;
        for (int i = 0; i < clusterCount; i++)
        {
            // Empty clusters have no meaningful centroid and are skipped.
            if (_numInstancesOfClstr[i] == 0)
                continue;

            Double maxi = 0;
            var centroidI = _clusterCentroids[i];
            for (int j = 0; j < clusterCount; j++)
            {
                if (i == j || _numInstancesOfClstr[j] == 0)
                    continue;

                var centroidJ = _clusterCentroids[j];
                Double num = avgDistances[i] + avgDistances[j];
                Single denom = VectorUtils.Distance(in centroidI, in centroidJ);
                maxi = Math.Max(maxi, num / denom);
            }

            dbi += maxi;
        }

        // The sum is divided by the total cluster count, empty clusters included.
        dbi /= clusterCount;
        return dbi;
    }
}
/// <summary>
/// Creates empty counters for <paramref name="numClusters"/> clusters.
/// </summary>
/// <param name="numClusters">The number of clusters.</param>
/// <param name="calculateDbi">Whether the DBI buffers are allocated.</param>
/// <param name="features">The feature column; read only when <paramref name="calculateDbi"/> is true.</param>
public Counters(int numClusters, bool calculateDbi, DataViewSchema.Column? features)
{
    _numClusters = numClusters;
    CalculateDbi = calculateDbi;

    _numInstancesOfClass = new List<Double>();
    _confusionMatrix = new List<Double[]>();
    _numInstancesOfClstr = new Double[numClusters];

    // The centroid and distance buffers are only needed for DBI.
    if (!calculateDbi)
        return;

    Contracts.Assert(features.HasValue);
    _distancesToCentroids = new Double[numClusters];
    _clusterCentroids = new VBuffer<Single>[numClusters];
    for (int cluster = 0; cluster < numClusters; cluster++)
        _clusterCentroids[cluster] = VBufferUtils.CreateEmpty<Single>(features.Value.Type.GetVectorSize());
}
/// <summary>
/// Accumulates one instance into the first-pass statistics.
/// </summary>
/// <param name="intLabel">The non-negative integer label of the instance.</param>
/// <param name="scores">The per-cluster scores of the instance.</param>
/// <param name="weight">The weight of the instance.</param>
/// <param name="indices">Cluster indices; the first entry is the assigned cluster.</param>
public void UpdateFirstPass(int intLabel, Single[] scores, Single weight, int[] indices)
{
    Contracts.Assert(Utils.Size(scores) == _numClusters);
    Contracts.Assert(Utils.Size(indices) == _numClusters);

    int assignedCluster = indices[0];

    _numInstances += weight;
    _sumMinScores += weight * scores[assignedCluster];

    // Grow the per-class structures until they can hold this label.
    while (_numInstancesOfClass.Count <= intLabel)
        _numInstancesOfClass.Add(0);
    while (_confusionMatrix.Count <= intLabel)
        _confusionMatrix.Add(new Double[scores.Length]);

    _numInstancesOfClass[intLabel] += weight;
    _numInstancesOfClstr[assignedCluster] += weight;
    _confusionMatrix[intLabel][assignedCluster] += weight;
}
/// <summary>
/// Prepares the centroids for the second pass: copies each vector of
/// <paramref name="clusterCentroids"/> and scales it by the reciprocal of the weight
/// accumulated for that cluster in the first pass.
/// </summary>
public void InitializeSecondPass(VBuffer<Single>[] clusterCentroids)
{
    for (int cluster = 0; cluster < clusterCentroids.Length; cluster++)
    {
        clusterCentroids[cluster].CopyTo(ref _clusterCentroids[cluster]);

        Single scale = (Single)(1.0 / _numInstancesOfClstr[cluster]);
        VectorUtils.ScaleBy(ref _clusterCentroids[cluster], scale);
    }
}
/// <summary>
/// Adds the distance between <paramref name="features"/> and the centroid of the assigned
/// cluster (the first entry of <paramref name="indices"/>) to that cluster's distance sum.
/// </summary>
/// <remarks>
/// NOTE(review): the distance is accumulated without an instance weight, while the Dbi getter
/// divides by the (possibly weighted) cluster size -- confirm this is intended for the
/// weighted counters.
/// </remarks>
public void UpdateSecondPass(in VBuffer<Single> features, int[] indices)
{
    int assignedCluster = indices[0];
    _distancesToCentroids[assignedCluster] += VectorUtils.Distance(in _clusterCentroids[assignedCluster], in features);
}
}
// The getters are initialized in InitializeNextPass(), when the new DataViewRowCursor is available.
// The label and weight getters are only (re)assigned in pass 0; the weight getter stays null
// when the schema has no weight column.
private ValueGetter<Single> _labelGetter;
private ValueGetter<VBuffer<Single>> _scoreGetter;
private ValueGetter<Single> _weightGetter;
// Assigned only when DBI is calculated.
private ValueGetter<VBuffer<Single>> _featGetter;
// Buffers that hold the features and the scores of the current row.
private VBuffer<Single> _scores;
// Dense copy of the current row's scores, one entry per cluster.
private readonly Single[] _scoresArr;
// Cluster indices of the current row ordered by ascending score; entry 0 is the assigned cluster.
private readonly int[] _indicesArr;
private VBuffer<Single> _features;
// This is used for DBI calculation.
// Per-cluster sum of the feature vectors of the instances assigned in the first pass; null when DBI is off.
private readonly VBuffer<Single>[] _clusterCentroids;
// Counters updated with weight 1 for every instance.
public readonly Counters UnweightedCounters;
// Counters updated with the instance weight; null when the aggregator is not weighted.
public readonly Counters WeightedCounters;
public readonly bool Weighted;
private readonly bool _calculateDbi;
/// <summary>
/// Creates an aggregator for score vectors of size <paramref name="scoreVectorSize"/>.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="features">The feature column; read only when <paramref name="calculateDbi"/> is true.</param>
/// <param name="scoreVectorSize">The size of the score vector, i.e. the number of clusters.</param>
/// <param name="calculateDbi">Whether DBI is calculated, which needs the centroid accumulators.</param>
/// <param name="weighted">Whether a second, weighted set of counters is maintained.</param>
/// <param name="stratName">The stratification name passed to the base aggregator.</param>
internal Aggregator(IHostEnvironment env, DataViewSchema.Column? features, int scoreVectorSize, bool calculateDbi, bool weighted, string stratName)
    : base(env, stratName)
{
    _calculateDbi = calculateDbi;
    Weighted = weighted;

    _scoresArr = new float[scoreVectorSize];
    _indicesArr = new int[scoreVectorSize];

    UnweightedCounters = new Counters(scoreVectorSize, calculateDbi, features);
    // WeightedCounters stays null for unweighted aggregators.
    if (weighted)
        WeightedCounters = new Counters(scoreVectorSize, calculateDbi, features);

    if (!calculateDbi)
        return;

    // One feature-sum accumulator per cluster, filled during the first pass.
    Host.Assert(features.HasValue);
    _clusterCentroids = new VBuffer<Single>[scoreVectorSize];
    for (int cluster = 0; cluster < scoreVectorSize; cluster++)
        _clusterCentroids[cluster] = VBufferUtils.CreateEmpty<Single>(features.Value.Type.GetVectorSize());
}
/// <summary>
/// Processes the current row in the first pass: reads label, scores and weight, ranks the
/// clusters by ascending score, updates the counters and, when DBI is calculated, adds the
/// row's features to the accumulator of the assigned cluster.
/// </summary>
private void ProcessRowFirstPass()
{
    AssertValid(assertGetters: true);

    // A NaN label is counted as unlabeled and then treated as label 0.
    Single rawLabel = 0;
    _labelGetter(ref rawLabel);
    if (Single.IsNaN(rawLabel))
    {
        NumUnlabeledInstances++;
        rawLabel = 0;
    }

    // Labels must be non-negative integers.
    int intLabel = (int)rawLabel;
    if (intLabel < 0 || intLabel != rawLabel)
        throw Host.Except("Invalid label: {0}", rawLabel);

    _scoreGetter(ref _scores);
    Host.Check(_scores.Length == _scoresArr.Length);

    // Rows with NaN or non-finite scores are counted and skipped.
    bool hasBadScores = VBufferUtils.HasNaNs(in _scores) || VBufferUtils.HasNonFinite(in _scores);
    if (hasBadScores)
    {
        NumBadScores++;
        return;
    }
    _scores.CopyTo(_scoresArr);

    // A non-finite weight is counted and replaced by 1.
    Single weight = 1;
    if (_weightGetter != null)
    {
        _weightGetter(ref weight);
        if (!FloatUtils.IsFinite(weight))
        {
            NumBadWeights++;
            weight = 1;
        }
    }

    // Rank the clusters by ascending score; the first one is the assigned cluster.
    int rank = 0;
    foreach (int cluster in Enumerable.Range(0, _scoresArr.Length).OrderBy(c => _scoresArr[c]))
        _indicesArr[rank++] = cluster;

    UnweightedCounters.UpdateFirstPass(intLabel, _scoresArr, 1, _indicesArr);
    WeightedCounters?.UpdateFirstPass(intLabel, _scoresArr, weight, _indicesArr);

    if (_clusterCentroids != null)
    {
        _featGetter(ref _features);
        VectorUtils.Add(in _features, ref _clusterCentroids[_indicesArr[0]]);
    }
}
/// <summary>
/// Processes the current row in the second pass: re-ranks the clusters by ascending score
/// and adds the row's distance to its assigned centroid to the counters. Rows with NaN or
/// non-finite scores are skipped.
/// </summary>
private void ProcessRowSecondPass()
{
    AssertValid(assertGetters: true);

    _featGetter(ref _features);
    _scoreGetter(ref _scores);
    Host.Check(_scores.Length == _scoresArr.Length);

    bool hasBadScores = VBufferUtils.HasNaNs(in _scores) || VBufferUtils.HasNonFinite(in _scores);
    if (hasBadScores)
        return;
    _scores.CopyTo(_scoresArr);

    // Same ranking as in the first pass, so the assigned cluster is the same.
    int rank = 0;
    foreach (int cluster in Enumerable.Range(0, _scoresArr.Length).OrderBy(c => _scoresArr[c]))
        _indicesArr[rank++] = cluster;

    UnweightedCounters.UpdateSecondPass(in _features, _indicesArr);
    WeightedCounters?.UpdateSecondPass(in _features, _indicesArr);
}
/// <summary>
/// Binds the getters to <paramref name="row"/> for the upcoming pass. In pass 0 the label
/// and weight getters are set up; in the second pass the counters receive the feature sums
/// accumulated during the first pass.
/// </summary>
internal override void InitializeNextPass(DataViewRow row, RoleMappedSchema schema)
{
    AssertValid(assertGetters: false);
    Host.AssertValue(row);
    Host.AssertValue(schema);

    // Features are read in both passes, but only when DBI is calculated.
    if (_calculateDbi)
    {
        Host.Assert(schema.Feature.HasValue);
        _featGetter = row.GetGetter<VBuffer<Single>>(schema.Feature.Value);
    }

    var scoreColumn = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
    Host.Assert(scoreColumn.Type.GetVectorSize() == _scoresArr.Length);
    _scoreGetter = row.GetGetter<VBuffer<Single>>(scoreColumn);

    if (PassNum != 0)
    {
        Host.Assert(PassNum == 1 && _calculateDbi);
        UnweightedCounters.InitializeSecondPass(_clusterCentroids);
        WeightedCounters?.InitializeSecondPass(_clusterCentroids);
    }
    else
    {
        // Without a label column every row reports a NaN label.
        if (!schema.Label.HasValue)
            _labelGetter = (ref Single value) => value = Single.NaN;
        else
            _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);

        if (schema.Weight.HasValue)
            _weightGetter = row.GetGetter<Single>(schema.Weight.Value);
    }

    AssertValid(assertGetters: true);
}
/// <summary>
/// Processes the current row with the handler of the current pass.
/// </summary>
public override void ProcessRow()
{
    if (PassNum != 0)
    {
        ProcessRowSecondPass();
        return;
    }

    ProcessRowFirstPass();
}
/// <summary>
/// Returns whether another pass is needed: two passes when DBI is calculated, one otherwise.
/// </summary>
public override bool IsActive()
{
    int requiredPasses = _calculateDbi ? 2 : 1;
    return PassNum < requiredPasses;
}
// Nothing to finalize at the end of a pass; only checks (in DEBUG builds) that the aggregator is still active.
protected override void FinishPassCore()
{
AssertValid(assertGetters: false);
}
/// <summary>
/// Debug-only consistency check: the aggregator must be active and, when
/// <paramref name="assertGetters"/> is true, the getters required by the current pass
/// must have been assigned.
/// </summary>
[Conditional("DEBUG")]
private void AssertValid(bool assertGetters)
{
    Host.Assert(IsActive());
    if (!assertGetters)
        return;

    if (PassNum != 0)
    {
        // The second pass only exists for DBI and needs features and scores.
        Host.Assert(PassNum == 1 && _calculateDbi);
        Host.AssertValue(_featGetter);
        Host.AssertValue(_scoreGetter);
        return;
    }

    Host.AssertValue(_labelGetter);
    Host.AssertValue(_scoreGetter);
    Host.AssertValueOrNull(_weightGetter);
    Host.Assert(!_calculateDbi || _featGetter != null);
}
}
}
internal sealed class ClusteringPerInstanceEvaluator : PerInstanceEvaluatorBase
{
// Loader signature used to register this mapper for deserialization and recorded in its version info.
public const string LoaderSignature = "ClusteringPerInstance";
/// <summary>
/// Returns the version information describing the serialized model format of this mapper.
/// </summary>
private static VersionInfo GetVersionInfo()
    => new VersionInfo(
        modelSignature: "CLSTRINS",
        verWrittenCur: 0x00010001, // Initial
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
        loaderSignature: LoaderSignature,
        loaderAssemblyName: typeof(ClusteringPerInstanceEvaluator).Assembly.FullName);
// Positions of the three output columns in the getter, type and output-column arrays.
private const int ClusterIdCol = 0;
private const int SortedClusterCol = 1;
private const int SortedClusterScoreCol = 2;
// Names of the three output columns.
public const string ClusterId = "ClusterId";
public const string SortedClusters = "SortedClusters";
public const string SortedClusterScores = "SortedScores";
// Number of clusters; also the size of both vector output columns.
private readonly int _numClusters;
// Output column types, indexed by the column constants above.
private readonly DataViewType[] _types;
/// <summary>
/// Creates the per-instance mapper for a score column with <paramref name="numClusters"/> slots.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="schema">The input schema; its score column is validated.</param>
/// <param name="scoreCol">The name of the score column.</param>
/// <param name="numClusters">The number of clusters, which sizes the output column types.</param>
public ClusteringPerInstanceEvaluator(IHostEnvironment env, DataViewSchema schema, string scoreCol, int numClusters)
    : base(env, schema, scoreCol, null)
{
    CheckInputColumnTypes(schema);

    _numClusters = numClusters;

    // Cluster ids are keys over the clusters; the two sorted columns are vectors of that size.
    var clusterKeyType = new KeyDataViewType(typeof(uint), _numClusters);
    var outputTypes = new DataViewType[3];
    outputTypes[ClusterIdCol] = clusterKeyType;
    outputTypes[SortedClusterCol] = new VectorDataViewType(clusterKeyType, _numClusters);
    outputTypes[SortedClusterScoreCol] = new VectorDataViewType(NumberDataViewType.Single, _numClusters);
    _types = outputTypes;
}
/// <summary>
/// Deserializing constructor: reads the number of clusters written by <c>SaveModel</c> and
/// rebuilds the output column types.
/// </summary>
private ClusteringPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
    : base(env, ctx, schema)
{
    CheckInputColumnTypes(schema);

    // *** Binary format **
    // base
    // int: number of clusters
    int clusterCount = ctx.Reader.ReadInt32();
    _numClusters = clusterCount;
    Host.CheckDecode(_numClusters > 0);

    var clusterKeyType = new KeyDataViewType(typeof(uint), _numClusters);
    var outputTypes = new DataViewType[3];
    outputTypes[ClusterIdCol] = clusterKeyType;
    outputTypes[SortedClusterCol] = new VectorDataViewType(clusterKeyType, _numClusters);
    outputTypes[SortedClusterScoreCol] = new VectorDataViewType(NumberDataViewType.Single, _numClusters);
    _types = outputTypes;
}
/// <summary>
/// Factory used for deserialization: validates the arguments, checks that
/// <paramref name="ctx"/> is positioned at a model matching this class's version info,
/// and constructs the mapper from the context.
/// </summary>
public static ClusteringPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
return new ClusteringPerInstanceEvaluator(env, ctx, schema);
}
// Serializes the mapper: the base class state first, then the number of clusters.
// The deserializing constructor reads the values back in the same order.
private protected override void SaveModel(ModelSaveContext ctx)
{
// *** Binary format **
// base
// int: number of clusters
base.SaveModel(ctx);
Host.Assert(_numClusters > 0);
ctx.Writer.Write(_numClusters);
}
/// <summary>
/// Returns the predicate over input columns: the score column is required whenever at least
/// one of the three output columns is active; no other input column is needed.
/// </summary>
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
    return col =>
    {
        if (col != ScoreIndex)
            return false;

        return activeOutput(ClusterIdCol)
            || activeOutput(SortedClusterCol)
            || activeOutput(SortedClusterScoreCol);
    };
}
/// <summary>
/// Creates the getters of the active output columns. All getters share one cache of the
/// current row's scores and of the cluster indices ordered by ascending score; the cache is
/// refreshed when the position of the input row changes.
/// </summary>
/// <param name="input">The input row to read the scores from.</param>
/// <param name="activeCols">Predicate telling which output columns are requested.</param>
/// <param name="disposer">Always set to null; nothing needs disposing.</param>
/// <returns>An array of three delegates; entries of inactive columns are null.</returns>
private protected override Delegate[] CreateGettersCore(DataViewRow input, Func<int, bool> activeCols, out Action disposer)
{
    disposer = null;
    var getters = new Delegate[3];

    bool anyActive = activeCols(ClusterIdCol) || activeCols(SortedClusterCol) || activeCols(SortedClusterScoreCol);
    if (!anyActive)
        return getters;

    // State shared by all getters.
    long cachedPosition = -1;
    VBuffer<Single> scores = default(VBuffer<Single>);
    var scoresArr = new Single[_numClusters];
    var sortedIndices = new int[_numClusters];
    var scoreGetter = input.GetGetter<VBuffer<Single>>(input.Schema[ScoreIndex]);

    // Re-reads the scores and re-ranks the clusters when the row has moved.
    void RefreshCache()
    {
        if (cachedPosition == input.Position)
            return;

        scoreGetter(ref scores);
        scores.CopyTo(scoresArr);
        int rank = 0;
        foreach (int cluster in Enumerable.Range(0, scoresArr.Length).OrderBy(c => scoresArr[c]))
            sortedIndices[rank++] = cluster;
        cachedPosition = input.Position;
    }

    if (activeCols(ClusterIdCol))
    {
        // Key values are the cluster index plus one.
        ValueGetter<uint> clusterIdGetter =
            (ref uint dst) =>
            {
                RefreshCache();
                dst = (uint)sortedIndices[0] + 1;
            };
        getters[ClusterIdCol] = clusterIdGetter;
    }

    if (activeCols(SortedClusterScoreCol))
    {
        ValueGetter<VBuffer<Single>> sortedScoresGetter =
            (ref VBuffer<Single> dst) =>
            {
                RefreshCache();
                var editor = VBufferEditor.Create(ref dst, _numClusters);
                for (int slot = 0; slot < _numClusters; slot++)
                    editor.Values[slot] = scores.GetItemOrDefault(sortedIndices[slot]);
                dst = editor.Commit();
            };
        getters[SortedClusterScoreCol] = sortedScoresGetter;
    }

    if (activeCols(SortedClusterCol))
    {
        ValueGetter<VBuffer<uint>> sortedClustersGetter =
            (ref VBuffer<uint> dst) =>
            {
                RefreshCache();
                var editor = VBufferEditor.Create(ref dst, _numClusters);
                for (int slot = 0; slot < _numClusters; slot++)
                    editor.Values[slot] = (uint)sortedIndices[slot] + 1;
                dst = editor.Commit();
            };
        getters[SortedClusterCol] = sortedClustersGetter;
    }

    return getters;
}
/// <summary>
/// Describes the three output columns. The cluster-id column has no annotations; the two
/// sorted vector columns carry slot names ("#1 Cluster", "#1 Score", ...).
/// </summary>
private protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
    var slotNamesType = new VectorDataViewType(TextDataViewType.Instance, _numClusters);
    int slotCount = slotNamesType.GetVectorSize();

    var clusterAnnotations = new DataViewSchema.Annotations.Builder();
    clusterAnnotations.AddSlotNames(slotCount, CreateSlotNamesGetter(_numClusters, "Cluster"));

    var scoreAnnotations = new DataViewSchema.Annotations.Builder();
    scoreAnnotations.AddSlotNames(slotCount, CreateSlotNamesGetter(_numClusters, "Score"));

    var columns = new DataViewSchema.DetachedColumn[3];
    columns[ClusterIdCol] = new DataViewSchema.DetachedColumn(ClusterId, _types[ClusterIdCol], null);
    columns[SortedClusterCol] = new DataViewSchema.DetachedColumn(SortedClusters, _types[SortedClusterCol], clusterAnnotations.ToAnnotations());
    columns[SortedClusterScoreCol] = new DataViewSchema.DetachedColumn(SortedClusterScores, _types[SortedClusterScoreCol], scoreAnnotations.ToAnnotations());
    return columns;
}
// REVIEW: Figure out how to avoid having the column name in each slot name.
/// <summary>
/// Creates a getter producing <paramref name="numTopClusters"/> slot names of the form
/// "#1 suffix", "#2 suffix", and so on.
/// </summary>
private ValueGetter<VBuffer<ReadOnlyMemory<char>>> CreateSlotNamesGetter(int numTopClusters, string suffix)
{
    ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter =
        (ref VBuffer<ReadOnlyMemory<char>> dst) =>
        {
            var editor = VBufferEditor.Create(ref dst, numTopClusters);
            // Slot names are one-based while slots are zero-based.
            for (int slot = 0; slot < numTopClusters; slot++)
                editor.Values[slot] = $"#{slot + 1} {suffix}".AsMemory();
            dst = editor.Commit();
        };
    return slotNamesGetter;
}
/// <summary>
/// Verifies that the score column of <paramref name="schema"/> is a known-size vector of
/// Single; throws a schema-mismatch exception otherwise.
/// </summary>
private void CheckInputColumnTypes(DataViewSchema schema)
{
    Host.AssertNonEmpty(ScoreCol);

    DataViewType scoreType = schema[(int)ScoreIndex].Type;
    var scoreVector = scoreType as VectorDataViewType;
    bool isValidScore = scoreVector != null
        && scoreVector.IsKnownSize
        && scoreVector.ItemType == NumberDataViewType.Single;
    if (!isValidScore)
        throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol, "known-size vector of Single", scoreType.ToString());
}
}
[BestFriend]
internal sealed class ClusteringMamlEvaluator : MamlEvaluatorBase
{
/// <summary>
/// Options accepted by <see cref="ClusteringMamlEvaluator"/>, in addition to those of the base arguments class.
/// </summary>
public class Arguments : ArgumentsBase
{
// REVIEW: Remove the DBI centroid measure, which is sensible to apply in the k-means case only, and remove the features argument.
[Argument(ArgumentType.AtMostOnce, HelpText = "Features column name", ShortName = "feat")]
public string FeatureColumn;
// Forwarded to the wrapped ClusteringEvaluator.
[Argument(ArgumentType.AtMostOnce, HelpText = "Calculate DBI? (time-consuming unsupervised metric)", ShortName = "dbi")]
public bool CalculateDbi = false;
// Must be at least 1 (checked in the evaluator's constructor); limits the slots kept in the sorted per-instance columns.
[Argument(ArgumentType.AtMostOnce, HelpText = "Output top K clusters", ShortName = "topk")]
public int NumTopClustersToOutput = 3;
}
// The wrapped evaluator that does the actual metric computation.
private readonly ClusteringEvaluator _evaluator;
// Number of slots kept in the sorted per-instance vector columns.
private readonly int _numTopClusters;
// Feature column name given by the user; may be null.
private readonly string _featureCol;
// Whether DBI is calculated, in which case the feature role is bound.
private readonly bool _calculateDbi;
// Exposes the wrapped evaluator to the base class.
private protected override IEvaluator Evaluator => _evaluator;
/// <summary>
/// Creates the MAML clustering evaluator and the <see cref="ClusteringEvaluator"/> it wraps.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="args">The options; <c>NumTopClustersToOutput</c> must be at least 1.</param>
public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args)
    : base(args, env, AnnotationUtils.Const.ScoreColumnKind.Clustering, "ClusteringMamlEvaluator")
{
    Host.CheckValue(args, nameof(args));
    Host.CheckUserArg(args.NumTopClustersToOutput >= 1, nameof(args.NumTopClustersToOutput));

    _featureCol = args.FeatureColumn;
    _numTopClusters = args.NumTopClustersToOutput;
    _calculateDbi = args.CalculateDbi;

    // The wrapped evaluator only needs to know whether DBI is requested.
    _evaluator = new ClusteringEvaluator(Host, new ClusteringEvaluator.Arguments { CalculateDbi = _calculateDbi });
}
/// <summary>
/// Enumerates the column roles used for evaluation: the base roles, dropping a label role
/// whose column does not exist in the schema, followed by the feature role when DBI is
/// calculated. Throws a user-argument exception if the feature column cannot be found.
/// </summary>
private protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
{
    foreach (var role in base.GetInputColumnRolesCore(schema))
    {
        // The label is optional for clustering: keep it only when its column is present.
        bool isLabel = role.Key.Equals(RoleMappedSchema.ColumnRole.Label);
        if (!isLabel || schema.Schema.TryGetColumnIndex(role.Value, out int _))
            yield return role;
    }

    if (!_calculateDbi)
        yield break;

    string featureName = EvaluateUtils.GetColName(_featureCol, schema.Feature, DefaultColumnNames.Features);
    if (!schema.Schema.TryGetColumnIndex(featureName, out int _))
        throw Host.ExceptUserArg(nameof(Arguments.FeatureColumn), "Features column '{0}' not found", featureName);
    yield return RoleMappedSchema.ColumnRole.Feature.Bind(featureName);
}
/// <summary>
/// Enumerates the per-instance columns to save: the label column when present, followed by
/// the three columns added by the clustering per-instance evaluator (cluster id, sorted
/// clusters and sorted cluster scores).
/// </summary>
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
    Host.CheckValue(schema, nameof(schema));

    // The label column is optional.
    var labelColumn = schema.Label;
    if (labelColumn.HasValue)
        yield return labelColumn.Value.Name;

    // The per-instance output columns.
    yield return ClusteringPerInstanceEvaluator.ClusterId;
    yield return ClusteringPerInstanceEvaluator.SortedClusters;
    yield return ClusteringPerInstanceEvaluator.SortedClusterScores;
}
/// <summary>
/// Trims the sorted-clusters and sorted-scores columns of <paramref name="perInst"/> with a
/// slots-dropping transform when they have more slots than the requested number of top
/// clusters. Columns that are absent or already small enough are left untouched.
/// </summary>
private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
{
    string[] vectorColumns =
    {
        ClusteringPerInstanceEvaluator.SortedClusters,
        ClusteringPerInstanceEvaluator.SortedClusterScores
    };

    foreach (string columnName in vectorColumns)
    {
        if (!perInst.Schema.TryGetColumnIndex(columnName, out int columnIndex))
            continue;

        // Drop the slots from index _numTopClusters onwards.
        var columnType = perInst.Schema[columnIndex].Type;
        if (_numTopClusters < columnType.GetVectorSize())
            perInst = new SlotsDroppingTransformer(Host, columnName, min: _numTopClusters).Transform(perInst);
    }

    return perInst;
}
}
internal static partial class Evaluate
{
    /// <summary>
    /// Entry point that evaluates a scored clustering dataset and returns warnings, overall
    /// metrics and per-instance metrics.
    /// </summary>
    /// <param name="env">The host environment.</param>
    /// <param name="input">The evaluator arguments, including the data to evaluate.</param>
    [TlcModule.EntryPoint(Name = "Models.ClusterEvaluator", Desc = "Evaluates a clustering scored dataset.")]
    public static CommonOutputs.CommonEvaluateOutput Clustering(IHostEnvironment env, ClusteringMamlEvaluator.Arguments input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("EvaluateClustering");
        host.CheckValue(input, nameof(input));
        EntryPointUtils.CheckInputArgs(host, input);

        // Resolve the column names used for the label, weight, name and feature roles.
        MatchColumns(host, input, out string label, out string weight, out string name);
        string features = TrainUtils.MatchNameOrDefaultOrNull(
            host,
            input.Data.Schema,
            nameof(ClusteringMamlEvaluator.Arguments.FeatureColumn),
            input.FeatureColumn,
            DefaultColumnNames.Features);

        IMamlEvaluator evaluator = new ClusteringMamlEvaluator(host, input);
        var data = new RoleMappedData(input.Data, label, features, null, weight, name);
        var metrics = evaluator.Evaluate(data);

        return new CommonOutputs.CommonEvaluateOutput
        {
            Warnings = ExtractWarnings(host, metrics),
            OverallMetrics = ExtractOverallMetrics(host, metrics, evaluator),
            PerInstanceMetrics = evaluator.GetPerInstanceMetrics(data)
        };
    }
}
}
|