|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(MutualInformationFeatureSelectingEstimator.Summary, typeof(IDataTransform), typeof(MutualInformationFeatureSelectingEstimator), typeof(MutualInformationFeatureSelectingEstimator.Options), typeof(SignatureDataTransform),
MutualInformationFeatureSelectingEstimator.UserName, "MutualInformationFeatureSelection", "MutualInformationFeatureSelectionTransform", MutualInformationFeatureSelectingEstimator.ShortName)]
namespace Microsoft.ML.Transforms
{
/// <summary>
/// Selects the top k slots across all specified columns ordered by their mutual information with the label column
/// (what you can learn about the label by observing the value of the specified column).
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
///
/// ### Estimator Characteristics
/// | | |
/// | -- | -- |
/// | Does this estimator need to look at the data to train its parameters? | Yes |
/// | Input column data type | Vector or scalar of numeric, [text](xref:Microsoft.ML.Data.TextDataViewType) or [key](xref:Microsoft.ML.Data.KeyDataViewType) data types|
/// | Output column data type | Same as the input column|
/// | Exportable to ONNX | Yes |
///
/// Formally, the mutual information can be written as:
///
/// $\text{MI}(X,Y) = E_{x,y}[\log(P(x,y)) - \log(P(x)) - \log(P(y))]$ where $x$ and $y$ are observations of random variables $X$ and $Y$.
///
/// where the expectation E is taken over the joint distribution of X and Y.
/// Here P(x, y) is the joint probability density function of X and Y, P(x) and P(y) are the marginal probability density functions of X and Y respectively.
/// In general, a higher mutual information between the dependent variable(or label) and an independent variable(or feature) means
/// that the label has higher mutual dependence over that feature.
/// It keeps the top slots in output features with the largest mutual information with the label.
///
/// For example, for the following Features and Label column, if we specify that we want the top 2 slots(vector elements) that have the higher correlation
/// with the label column, the output of applying this Estimator would keep the first and the third slots only, because their values
/// are more correlated with the values in the Label column.
///
/// | Label | Features |
/// | -- | -- |
/// |True |4,6,0 |
/// |False|0,7,5 |
/// |True |4,7,0 |
/// |False|0,7,0 |
///
/// This is how the dataset above would look, after fitting the estimator, and transforming the data with the resulting transformer:
///
/// | Label | Features |
/// | -- | -- |
/// |True |4,0 |
/// |False|0,5 |
/// |True |4,0 |
/// |False|0,0 |
///
/// Check the See Also section for links to usage examples.
/// ]]></format>
/// </remarks>
/// <seealso cref="FeatureSelectionCatalog.SelectFeaturesBasedOnMutualInformation(TransformsCatalog.FeatureSelectionTransforms, InputOutputColumnPair[], string, int, int)"/>
/// <seealso cref="FeatureSelectionCatalog.SelectFeaturesBasedOnMutualInformation(TransformsCatalog.FeatureSelectionTransforms, string, string, string, int, int)"/>
public sealed class MutualInformationFeatureSelectingEstimator : IEstimator<ITransformer>
{
internal const string Summary =
"Selects the top k slots across all specified columns ordered by their mutual information with the label column.";
internal const string UserName = "Mutual Information Feature Selection Transform";
internal const string ShortName = "MIFeatureSelection";
internal static string RegistrationName = "MutualInformationFeatureSelectionTransform";
[BestFriend]
internal static class Defaults
{
public const string LabelColumnName = DefaultColumnNames.Label;
public const int SlotsInOutput = 1000;
public const int NumBins = 256;
}
internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to use for feature selection", Name = "Column", ShortName = "col", SortOrder = 1)]
public string[] Columns;
[Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for labels", ShortName = "lab",
SortOrder = 4, Purpose = SpecialPurpose.ColumnName)]
public string LabelColumnName = Defaults.LabelColumnName;
[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of slots to preserve in output", ShortName = "topk,numSlotsToKeep",
SortOrder = 1)]
public int SlotsInOutput = Defaults.SlotsInOutput;
[Argument(ArgumentType.AtMostOnce, HelpText = "Max number of bins for R4/R8 columns, power of 2 recommended",
ShortName = "bins")]
public int NumBins = Defaults.NumBins;
}
private readonly IHost _host;
private readonly (string outputColumnName, string inputColumnName)[] _columns;
private readonly string _labelColumnName;
private readonly int _slotsInOutput;
private readonly int _numBins;
/// <include file='doc.xml' path='doc/members/member[@name="MutualInformationFeatureSelection"]/*' />
/// <param name="env">The environment to use.</param>
/// <param name="labelColumnName">Name of the column to use for labels.</param>
/// <param name="slotsInOutput">The maximum number of slots to preserve in the output. The number of slots to preserve is taken across all input columns.</param>
/// <param name="numberOfBins">Max number of bins used to approximate mutual information between each input column and the label column. Power of 2 recommended.</param>
/// <param name="columns">Specifies the names of the input columns for the transformation, and their respective output column names.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[MutualInformationFeatureSelectingEstimator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs?range=1-4,10-121)]
/// ]]>
/// </format>
/// </example>
internal MutualInformationFeatureSelectingEstimator(IHostEnvironment env,
string labelColumnName = Defaults.LabelColumnName,
int slotsInOutput = Defaults.SlotsInOutput,
int numberOfBins = Defaults.NumBins,
params (string outputColumnName, string inputColumnName)[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(RegistrationName);
_host.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));
_host.CheckUserArg(slotsInOutput > 0, nameof(slotsInOutput));
_host.CheckNonWhiteSpace(labelColumnName, nameof(labelColumnName));
_host.Check(numberOfBins > 1, "numBins must be greater than 1.");
_columns = columns;
_labelColumnName = labelColumnName;
_slotsInOutput = slotsInOutput;
_numBins = numberOfBins;
}
/// <include file='doc.xml' path='doc/members/member[@name="MutualInformationFeatureSelection"]/*' />
/// <param name="env">The environment to use.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="labelColumnName">Name of the column to use for labels.</param>
/// <param name="slotsInOutput">The maximum number of slots to preserve in the output. The number of slots to preserve is taken across all input columns.</param>
/// <param name="numBins">Max number of bins used to approximate mutual information between each input column and the label column. Power of 2 recommended.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[MutualInformationFeatureSelectingEstimator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs?range=1-4,10-121)]
/// ]]>
/// </format>
/// </example>
internal MutualInformationFeatureSelectingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
string labelColumnName = Defaults.LabelColumnName, int slotsInOutput = Defaults.SlotsInOutput, int numBins = Defaults.NumBins)
: this(env, labelColumnName, slotsInOutput, numBins, (outputColumnName, inputColumnName ?? outputColumnName))
{
}
/// <summary>
/// Trains and returns a <see cref="ITransformer"/>.
/// </summary>
public ITransformer Fit(IDataView input)
{
_host.CheckValue(input, nameof(input));
using (var ch = _host.Start("Selecting Slots"))
{
ch.Info("Computing mutual information");
var sw = new Stopwatch();
sw.Start();
var colSet = new HashSet<string>();
foreach (var col in _columns)
{
if (!colSet.Add(col.inputColumnName))
ch.Warning("Column '{0}' specified multiple time.", col);
}
var colArr = colSet.ToArray();
var colSizes = new int[colArr.Length];
var scores = MutualInformationFeatureSelectionUtils.TrainCore(_host, input, _labelColumnName, colArr, _numBins, colSizes);
sw.Stop();
ch.Info("Finished mutual information computation in {0}", sw.Elapsed);
ch.Info("Selecting features to drop");
var threshold = ComputeThreshold(scores, _slotsInOutput, out int tiedScoresToKeep);
// If no slots should be dropped in a column, use CopyColumn to generate the corresponding output column.
SlotsDroppingTransformer.ColumnOptions[] dropSlotsColumns;
(string outputColumnName, string inputColumnName)[] copyColumnPairs;
CreateDropAndCopyColumns(colArr.Length, scores, threshold, tiedScoresToKeep, _columns.Where(col => colSet.Contains(col.inputColumnName)).ToArray(), out int[] selectedCount, out dropSlotsColumns, out copyColumnPairs);
for (int i = 0; i < selectedCount.Length; i++)
ch.Info("Selected {0} slots out of {1} in column '{2}'", selectedCount[i], colSizes[i], colArr[i]);
ch.Info("Total number of slots selected: {0}", selectedCount.Sum());
if (dropSlotsColumns.Length <= 0)
return new ColumnCopyingTransformer(_host, copyColumnPairs);
else if (copyColumnPairs.Length <= 0)
return new SlotsDroppingTransformer(_host, dropSlotsColumns);
var transformerChain = new TransformerChain<SlotsDroppingTransformer>(
new ITransformer[] {
new ColumnCopyingTransformer(_host, copyColumnPairs),
new SlotsDroppingTransformer(_host, dropSlotsColumns)
});
return transformerChain;
}
}
/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
if (!inputSchema.TryFindColumn(_labelColumnName, out var label))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "label", $"Label column '{_labelColumnName}' not found in input schema");
if (!(label.IsKey || MutualInformationFeatureSelectionUtils.IsValidColumnType(label.ItemType)))
{
throw _host.ExceptUserArg(nameof(inputSchema),
$"Label column '{_labelColumnName}' does not have compatible type. Expected types are float, double, int, bool and key.");
}
var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colPair in _columns)
{
if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName);
if (!MutualInformationFeatureSelectionUtils.IsValidColumnType(col.ItemType))
{
throw _host.ExceptUserArg(nameof(inputSchema),
"Column '{0}' does not have compatible type. Expected types are float, double, int, bool and key.", colPair.inputColumnName);
}
if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
throw _host.ExceptUserArg(nameof(inputSchema), $"Variable length column '{col.Name}' is not allowed");
var metadata = new List<SchemaShape.Column>();
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
metadata.Add(slotMeta);
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.CategoricalSlotRanges, out var categoricalSlotMeta))
metadata.Add(categoricalSlotMeta);
if (col.IsNormalized() && col.Annotations.TryFindColumn(AnnotationUtils.Kinds.IsNormalized, out var isNormalizedAnnotation))
metadata.Add(isNormalizedAnnotation);
result[colPair.outputColumnName] = new SchemaShape.Column(colPair.outputColumnName, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray()));
}
return new SchemaShape(result.Values);
}
/// <summary>
/// Create method corresponding to SignatureDataTransform.
/// </summary>
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(RegistrationName);
host.CheckValue(options, nameof(options));
host.CheckValue(input, nameof(input));
host.CheckNonEmpty(options.Columns, nameof(options.Columns));
host.CheckUserArg(options.SlotsInOutput > 0, nameof(options.SlotsInOutput));
host.CheckNonWhiteSpace(options.LabelColumnName, nameof(options.LabelColumnName));
host.Check(options.NumBins > 1, "numBins must be greater than 1.");
(string outputColumnName, string inputColumnName)[] cols = options.Columns.Select(col => (col, col)).ToArray();
return new MutualInformationFeatureSelectingEstimator(env, options.LabelColumnName, options.SlotsInOutput, options.NumBins, cols).Fit(input).Transform(input) as IDataTransform;
}
/// <summary>
/// Computes the threshold for the scores such that the top k slots are preserved.
/// If there are less than k scores greater than zero, the threshold is set to zero and
/// the tiedScoresToKeep is set to zero, so that we only keep scores strictly greater than zero.
/// </summary>
/// <param name="scores">The score for each column and each slot.</param>
/// <param name="topk">How many slots to preserve.</param>
/// <param name="tiedScoresToKeep">If there are ties, how many of them to keep.</param>
/// <returns>The threshold.</returns>
private static float ComputeThreshold(float[][] scores, int topk, out int tiedScoresToKeep)
{
// Use a min-heap for the topk elements.
var heap = new Heap<float>((f1, f2) => f1 > f2, topk);
for (int i = 0; i < scores.Length; i++)
{
for (int j = 0; j < scores[i].Length; j++)
{
var score = scores[i][j];
Contracts.Assert(score >= 0);
if (heap.Count < topk)
{
if (score > 0)
heap.Add(score);
}
else if (heap.Top < score)
{
Contracts.Assert(heap.Count == topk);
heap.Pop();
heap.Add(score);
}
}
}
var threshold = heap.Count < topk ? 0 : heap.Top;
tiedScoresToKeep = 0;
if (threshold == 0)
return threshold;
while (heap.Count > 0)
{
var top = heap.Pop();
Contracts.Assert(top >= threshold);
if (top > threshold)
break;
tiedScoresToKeep++;
}
return threshold;
}
private static void CreateDropAndCopyColumns(int size, float[][] scores, float threshold, int tiedScoresToKeep, (string outputColumnName, string inputColumnName)[] cols,
out int[] selectedCount, out SlotsDroppingTransformer.ColumnOptions[] dropSlotsColumns, out (string outputColumnName, string inputColumnName)[] copyColumnsPairs)
{
Contracts.Assert(size > 0);
Contracts.Assert(Utils.Size(scores) == size);
Contracts.Assert(Utils.Size(cols) == size);
Contracts.Assert(threshold > 0 || (threshold == 0 && tiedScoresToKeep == 0));
var dropCols = new List<SlotsDroppingTransformer.ColumnOptions>();
var copyCols = new List<(string outputColumnName, string inputColumnName)>();
selectedCount = new int[scores.Length];
for (int i = 0; i < size; i++)
{
var slots = new List<(int min, int? max)>();
var score = scores[i];
selectedCount[i] = 0;
for (int j = 0; j < score.Length; j++)
{
var sc = score[j];
if (sc > threshold)
{
selectedCount[i]++;
continue;
}
if (sc == threshold && tiedScoresToKeep > 0)
{
tiedScoresToKeep--;
selectedCount[i]++;
continue;
}
// Adjacent slots are combined into a float range.
int min = j;
while (++j < score.Length)
{
sc = score[j];
if (sc > threshold)
{
selectedCount[i]++;
break;
}
if (sc == threshold && tiedScoresToKeep > 0)
{
tiedScoresToKeep--;
selectedCount[i]++;
break;
}
}
int max = j - 1;
slots.Add((min, max));
}
if (slots.Count <= 0)
copyCols.Add(cols[i]);
else
dropCols.Add(new SlotsDroppingTransformer.ColumnOptions(cols[i].outputColumnName, cols[i].inputColumnName, slots.ToArray()));
}
dropSlotsColumns = dropCols.ToArray();
copyColumnsPairs = copyCols.ToArray();
}
}
internal static class MutualInformationFeatureSelectionUtils
{
/// <summary>
/// Returns the feature selection scores for each slot of each column.
/// </summary>
/// <param name="host">The host.</param>
/// <param name="input">The input dataview.</param>
/// <param name="labelColumnName">The label column.</param>
/// <param name="columns">The columns for which to compute the feature selection scores.</param>
/// <param name="numBins">The number of bins to use for numeric features.</param>
/// <param name="colSizes">The columns' sizes before dropping any slots.</param>
/// <returns>A list of scores for each column and each slot.</returns>
internal static float[][] TrainCore(IHost host, IDataView input, string labelColumnName, string[] columns, int numBins, int[] colSizes)
{
var impl = new Impl(host);
return impl.GetScores(input, labelColumnName, columns, numBins, colSizes);
}
internal static bool IsValidColumnType(DataViewType type)
{
// REVIEW: Consider supporting all integer and unsigned types.
ulong keyCount = type.GetKeyCount();
return
(0 < keyCount && keyCount < Utils.ArrayMaxSize) || type is BooleanDataViewType ||
type == NumberDataViewType.Single || type == NumberDataViewType.Double || type == NumberDataViewType.Int32;
}
private sealed class Impl
{
private static readonly FuncStaticMethodInfo1<DataViewType, Delegate> _makeKeyMapperMethodInfo
= new FuncStaticMethodInfo1<DataViewType, Delegate>(MakeKeyMapper<int>);
private readonly IHost _host;
private readonly BinFinderBase _binFinder;
private int _numBins;
private VBuffer<int> _labels; // always dense
private int _numLabels;
private int[][] _contingencyTable;
private int[] _labelSums;
private int[] _featureSums;
private readonly List<float> _singles;
private readonly List<double> _doubles;
private ValueMapper<VBuffer<bool>, VBuffer<int>> _boolMapper;
public Impl(IHost host)
{
Contracts.AssertValue(host);
_host = host;
_binFinder = new GreedyBinFinder();
_singles = new List<float>();
_doubles = new List<double>();
}
public float[][] GetScores(IDataView input, string labelColumnName, string[] columns, int numBins, int[] colSizes)
{
_numBins = numBins;
var schema = input.Schema;
var size = columns.Length;
if (!schema.TryGetColumnIndex(labelColumnName, out int labelCol))
{
throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.LabelColumnName),
"Label column '{0}' not found", labelColumnName);
}
var labelType = schema[labelCol].Type;
if (!IsValidColumnType(labelType))
{
throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.LabelColumnName),
"Label column '{0}' does not have compatible type", labelColumnName);
}
var colSrcs = new int[size + 1];
colSrcs[size] = labelCol;
for (int i = 0; i < size; i++)
{
var colName = columns[i];
if (!schema.TryGetColumnIndex(colName, out int colSrc))
{
throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.Columns),
"Source column '{0}' not found", colName);
}
var colType = schema[colSrc].Type;
if (colType is VectorDataViewType vectorType && !vectorType.IsKnownSize)
{
throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.Columns),
"Variable length column '{0}' is not allowed", colName);
}
if (!IsValidColumnType(colType.GetItemType()))
{
throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.Columns),
"Column '{0}' of type '{1}' does not have compatible type.", colName, colType);
}
colSrcs[i] = colSrc;
colSizes[i] = colType.GetValueCount();
}
var scores = new float[size][];
using (var ch = _host.Start("Computing mutual information scores"))
using (var pch = _host.StartProgressChannel("Computing mutual information scores"))
{
using (var trans = Transposer.Create(_host, input, false, colSrcs))
{
int i = 0;
var header = new ProgressHeader(new[] { "columns" });
var b = trans.Schema.TryGetColumnIndex(labelColumnName, out labelCol);
Contracts.Assert(b);
GetLabels(trans, labelType, labelCol);
_contingencyTable = new int[_numLabels][];
_labelSums = new int[_numLabels];
pch.SetHeader(header, e => e.SetProgress(0, i, size));
for (i = 0; i < size; i++)
{
b = trans.Schema.TryGetColumnIndex(columns[i], out int col);
Contracts.Assert(b);
ch.Trace("Computing scores for column '{0}'", columns[i]);
scores[i] = ComputeMutualInformation(trans, col);
#if DEBUG
ch.Trace("Scores for column '{0}': {1}", columns[i], string.Join(", ", scores[i]));
#endif
pch.Checkpoint(i + 1);
}
}
}
return scores;
}
private void GetLabels(Transposer trans, DataViewType labelType, int labelCol)
{
int min;
int lim;
var labels = default(VBuffer<int>);
// Note: NAs have their own separate bin.
if (labelType == NumberDataViewType.Int32)
{
var tmp = default(VBuffer<int>);
trans.GetSingleSlotValue(labelCol, ref tmp);
BinInts(in tmp, ref labels, _numBins, out min, out lim);
_numLabels = lim - min;
}
else if (labelType == NumberDataViewType.Single)
{
var tmp = default(VBuffer<Single>);
trans.GetSingleSlotValue(labelCol, ref tmp);
BinSingles(in tmp, ref labels, _numBins, out min, out lim);
_numLabels = lim - min;
}
else if (labelType == NumberDataViewType.Double)
{
var tmp = default(VBuffer<Double>);
trans.GetSingleSlotValue(labelCol, ref tmp);
BinDoubles(in tmp, ref labels, _numBins, out min, out lim);
_numLabels = lim - min;
}
else if (labelType is BooleanDataViewType)
{
var tmp = default(VBuffer<bool>);
trans.GetSingleSlotValue(labelCol, ref tmp);
BinBools(in tmp, ref labels);
_numLabels = 3;
min = -1;
lim = 2;
}
else
{
ulong labelKeyCount = labelType.GetKeyCount();
Contracts.Assert(labelKeyCount < Utils.ArrayMaxSize);
KeyLabelGetter<int> del = GetKeyLabels<int>;
var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(labelType.RawType);
var parameters = new object[] { trans, labelCol, labelType };
_labels = (VBuffer<int>)methodInfo.Invoke(this, parameters);
_numLabels = labelType.GetKeyCountAsInt32(_host) + 1;
// No need to densify or shift in this case.
return;
}
// Densify and shift labels.
VBufferUtils.Densify(ref labels);
Contracts.Assert(labels.IsDense);
var labelsEditor = VBufferEditor.CreateFromBuffer(ref labels);
for (int i = 0; i < labels.Length; i++)
{
labelsEditor.Values[i] -= min;
Contracts.Assert(labelsEditor.Values[i] < _numLabels);
}
_labels = labelsEditor.Commit();
}
private delegate VBuffer<int> KeyLabelGetter<T>(Transposer trans, int labelCol, DataViewType labeColumnType);
private VBuffer<int> GetKeyLabels<T>(Transposer trans, int labelCol, DataViewType labelColumnType)
{
var tmp = default(VBuffer<T>);
var labels = default(VBuffer<int>);
trans.GetSingleSlotValue(labelCol, ref tmp);
BinKeys<T>(labelColumnType)(in tmp, ref labels);
VBufferUtils.Densify(ref labels);
return labels;
}
/// <summary>
/// Computes the mutual information for one column.
/// </summary>
private Single[] ComputeMutualInformation(Transposer trans, int col)
{
// Note: NAs have their own separate bin.
var type = trans.Schema[col].Type;
var itemType = type.GetItemType();
if (itemType == NumberDataViewType.Int32)
{
return ComputeMutualInformation(trans, col,
(ref VBuffer<int> src, ref VBuffer<int> dst, out int min, out int lim) =>
{
BinInts(in src, ref dst, _numBins, out min, out lim);
});
}
if (itemType == NumberDataViewType.Single)
{
return ComputeMutualInformation(trans, col,
(ref VBuffer<Single> src, ref VBuffer<int> dst, out int min, out int lim) =>
{
BinSingles(in src, ref dst, _numBins, out min, out lim);
});
}
if (itemType == NumberDataViewType.Double)
{
return ComputeMutualInformation(trans, col,
(ref VBuffer<Double> src, ref VBuffer<int> dst, out int min, out int lim) =>
{
BinDoubles(in src, ref dst, _numBins, out min, out lim);
});
}
if (itemType is BooleanDataViewType)
{
return ComputeMutualInformation(trans, col,
(ref VBuffer<bool> src, ref VBuffer<int> dst, out int min, out int lim) =>
{
min = -1;
lim = 2;
BinBools(in src, ref dst);
});
}
ulong keyCount = itemType.GetKeyCount();
Contracts.Assert(keyCount < Utils.ArrayMaxSize);
var mapper = Utils.MarshalInvoke(_makeKeyMapperMethodInfo, itemType.RawType, itemType);
ComputeMutualInformationDelegate<int> cmiDel = ComputeMutualInformation;
var cmiMethodInfo = cmiDel.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(itemType.RawType);
return (Single[])cmiMethodInfo.Invoke(this, new object[] { trans, col, mapper });
}
private delegate float[] ComputeMutualInformationDelegate<T>(Transposer trans, int col, Mapper<T> mapper);
private delegate void Mapper<T>(ref VBuffer<T> src, ref VBuffer<int> dst, out int min, out int lim);
private static Mapper<T> MakeKeyMapper<T>(DataViewType type)
{
ulong keyCount = type.GetKeyCount();
Contracts.Assert(0 < keyCount && keyCount < Utils.ArrayMaxSize);
var mapper = BinKeys<T>(type);
return
(ref VBuffer<T> src, ref VBuffer<int> dst, out int min, out int lim) =>
{
min = 0;
lim = (int)type.GetKeyCount() + 1;
mapper(in src, ref dst);
};
}
/// <summary>
/// Computes the mutual information for one column.
/// </summary>
private float[] ComputeMutualInformation<T>(Transposer trans, int col, Mapper<T> mapper)
{
var slotCount = trans.Schema[col].Type.GetValueCount();
var scores = new float[slotCount];
int iScore = 0;
VBuffer<int> slotValues = default(VBuffer<int>);
using (var cursor = trans.GetSlotCursor(col))
{
var getter = cursor.GetGetter<T>();
while (cursor.MoveNext())
{
VBuffer<T> tmp = default(VBuffer<T>);
getter(ref tmp);
mapper(ref tmp, ref slotValues, out int min, out int lim);
Contracts.Assert(iScore < slotCount);
scores[iScore++] = ComputeMutualInformation(in slotValues, lim - min, min);
}
}
return scores;
}
/// <summary>
/// Computes the mutual information for one slot.
/// </summary>
private float ComputeMutualInformation(in VBuffer<int> features, int numFeatures, int offset)
{
Contracts.Assert(_labels.Length == features.Length);
if (Utils.Size(_contingencyTable[0]) < numFeatures)
{
for (int i = 0; i < _numLabels; i++)
Array.Resize(ref _contingencyTable[i], numFeatures);
Array.Resize(ref _featureSums, numFeatures);
}
for (int i = 0; i < _numLabels; i++)
Array.Clear(_contingencyTable[i], 0, numFeatures);
Array.Clear(_labelSums, 0, _numLabels);
Array.Clear(_featureSums, 0, numFeatures);
FillTable(in features, offset, numFeatures);
for (int i = 0; i < _numLabels; i++)
{
for (int j = 0; j < numFeatures; j++)
{
_labelSums[i] += _contingencyTable[i][j];
_featureSums[j] += _contingencyTable[i][j];
}
}
double score = 0;
for (int i = 0; i < _numLabels; i++)
{
for (int j = 0; j < numFeatures; j++)
{
if (_contingencyTable[i][j] > 0)
score += _contingencyTable[i][j] / (double)_labels.Length * Math.Log(_contingencyTable[i][j] * (double)_labels.Length / ((double)_labelSums[i] * _featureSums[j]), 2);
}
}
Contracts.Assert(score >= 0);
return (float)score;
}
/// <summary>
/// Fills the contingency table.
/// </summary>
private void FillTable(in VBuffer<int> features, int offset, int numFeatures)
{
Contracts.Assert(_labels.IsDense);
Contracts.Assert(_labels.Length == features.Length);
var featureValues = features.GetValues();
var labelsValues = _labels.GetValues();
if (features.IsDense)
{
for (int i = 0; i < labelsValues.Length; i++)
{
var label = labelsValues[i];
var feature = featureValues[i] - offset;
Contracts.Assert(0 <= label && label < _numLabels);
Contracts.Assert(0 <= feature && feature < numFeatures);
_contingencyTable[label][feature]++;
}
return;
}
var featureIndices = features.GetIndices();
int ii = 0;
for (int i = 0; i < labelsValues.Length; i++)
{
var label = labelsValues[i];
int feature;
if (ii == featureIndices.Length || i < featureIndices[ii])
feature = -offset;
else
{
feature = featureValues[ii] - offset;
ii++;
}
Contracts.Assert(0 <= label && label < _numLabels);
Contracts.Assert(0 <= feature && feature < numFeatures);
_contingencyTable[label][feature]++;
}
Contracts.Assert(ii == featureIndices.Length);
}
/// <summary>
/// Maps from keys to ints.
/// </summary>
private static ValueMapper<VBuffer<T>, VBuffer<int>> BinKeys<T>(DataViewType colType)
{
var conv = Data.Conversion.Conversions.DefaultInstance.GetStandardConversion<T, uint>(colType, NumberDataViewType.UInt32, out bool identity);
ValueMapper<T, int> mapper;
if (identity)
{
mapper = (ValueMapper<T, int>)(Delegate)(ValueMapper<uint, int>)(
(in uint src, ref int dst) =>
{
dst = (int)src;
});
}
else
{
mapper =
(in T src, ref int dst) =>
{
uint t = 0;
conv(in src, ref t);
dst = (int)t;
};
}
return CreateVectorMapper(mapper);
}
/// <summary>
/// Maps Ints.
/// </summary>
private void BinInts(in VBuffer<int> input, ref VBuffer<int> output,
int numBins, out int min, out int lim)
{
Contracts.Assert(_singles.Count == 0);
var bounds = _binFinder.FindBins(numBins, _singles, input.Length - input.GetValues().Length);
min = -1 - bounds.FindIndexSorted(0);
lim = min + bounds.Length + 1;
int offset = min;
ValueMapper<int, int> mapper =
(in int src, ref int dst) =>
dst = offset + 1 + bounds.FindIndexSorted((Single)src);
mapper.MapVector(in input, ref output);
_singles.Clear();
}
/// <summary>
/// Maps from Singles to ints. NaNs (and only NaNs) are mapped to the first bin.
/// </summary>
private void BinSingles(in VBuffer<Single> input, ref VBuffer<int> output,
int numBins, out int min, out int lim)
{
Contracts.Assert(_singles.Count == 0);
var inputValues = input.GetValues();
for (int i = 0; i < inputValues.Length; i++)
{
var val = inputValues[i];
if (!Single.IsNaN(val))
_singles.Add(val);
}
var bounds = _binFinder.FindBins(numBins, _singles, input.Length - inputValues.Length);
min = -1 - bounds.FindIndexSorted(0);
lim = min + bounds.Length + 1;
int offset = min;
ValueMapper<Single, int> mapper =
(in Single src, ref int dst) =>
dst = Single.IsNaN(src) ? offset : offset + 1 + bounds.FindIndexSorted(src);
mapper.MapVector(in input, ref output);
_singles.Clear();
}
/// <summary>
/// Maps from Doubles to ints. NaNs (and only NaNs) are mapped to the first bin.
/// </summary>
private void BinDoubles(in VBuffer<Double> input, ref VBuffer<int> output,
int numBins, out int min, out int lim)
{
Contracts.Assert(_doubles.Count == 0);
var inputValues = input.GetValues();
for (int i = 0; i < inputValues.Length; i++)
{
var val = inputValues[i];
if (!Double.IsNaN(val))
_doubles.Add(val);
}
var bounds = _binFinder.FindBins(numBins, _doubles, input.Length - inputValues.Length);
var offset = min = -1 - bounds.FindIndexSorted(0);
lim = min + bounds.Length + 1;
ValueMapper<Double, int> mapper =
(in Double src, ref int dst) =>
dst = Double.IsNaN(src) ? offset : offset + 1 + bounds.FindIndexSorted(src);
mapper.MapVector(in input, ref output);
_doubles.Clear();
}
private void BinBools(in VBuffer<bool> input, ref VBuffer<int> output)
{
if (_boolMapper == null)
_boolMapper = CreateVectorMapper<bool, int>(BinOneBool);
_boolMapper(in input, ref output);
}
private void BinOneBool(in bool src, ref int dst)
{
dst = Convert.ToInt32(src);
}
}
/// <summary>
/// Given a mapper from T to int, creates a mapper from VBuffer{T} to VBuffer<int>.
/// Assumes that the mapper maps default(TSrc) to default(TDst) so that the returned mapper preserves sparsity.
/// </summary>
private static ValueMapper<VBuffer<TSrc>, VBuffer<TDst>> CreateVectorMapper<TSrc, TDst>(ValueMapper<TSrc, TDst> map)
where TDst : IEquatable<TDst>
{
#if DEBUG
TSrc tmpSrc = default(TSrc);
TDst tmpDst = default(TDst);
map(in tmpSrc, ref tmpDst);
Contracts.Assert(tmpDst.Equals(default(TDst)));
#endif
return map.MapVector;
}
private static void MapVector<TSrc, TDst>(this ValueMapper<TSrc, TDst> map, in VBuffer<TSrc> input, ref VBuffer<TDst> output)
{
var inputValues = input.GetValues();
var editor = VBufferEditor.Create(ref output, input.Length, inputValues.Length);
for (int i = 0; i < inputValues.Length; i++)
{
TSrc val = inputValues[i];
map(in val, ref editor.Values[i]);
}
if (!input.IsDense && inputValues.Length > 0)
{
input.GetIndices().CopyTo(editor.Indices);
}
output = editor.Commit();
}
}
}
|