|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Transforms
{
internal static class PermutationFeatureImportance<TModel, TMetric, TResult> where TResult : IMetricsStatistics<TMetric>
where TModel : class
{
public static ImmutableArray<TResult>
GetImportanceMetricsMatrix(
IHostEnvironment env,
IPredictionTransformer<TModel> model,
IDataView data,
Func<TResult> resultInitializer,
Func<IDataView, TMetric> evaluationFunc,
Func<TMetric, TMetric, TMetric> deltaFunc,
string features,
int permutationCount,
bool useFeatureWeightFilter = false,
int? topExamples = null)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(nameof(PermutationFeatureImportance<TModel, TMetric, TResult>));
host.CheckValue(model, nameof(model));
host.CheckValue(data, nameof(data));
host.CheckNonEmpty(features, nameof(features));
topExamples = topExamples ?? Utils.ArrayMaxSize;
host.Check(topExamples > 0, "Provide how many examples to use (positive number) or set to null to use whole dataset.");
VBuffer<ReadOnlyMemory<char>> slotNames = default;
var metricsDelta = new List<TResult>();
using (var ch = host.Start("GetImportanceMetrics"))
{
ch.Trace("Scoring and evaluating baseline.");
var baselineMetrics = evaluationFunc(model.Transform(data));
// Get slot names.
var featuresColumn = data.Schema[features];
int numSlots = featuresColumn.Type.GetVectorSize();
data.Schema.TryGetColumnIndex(features, out int featuresColumnIndex);
ch.Info("Number of slots: " + numSlots);
if (data.Schema[featuresColumnIndex].HasSlotNames(numSlots))
data.Schema[featuresColumnIndex].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref slotNames);
if (slotNames.Length != numSlots)
slotNames = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(numSlots);
VBuffer<float> weights = default;
var workingFeatureIndices = Enumerable.Range(0, numSlots).ToList();
int zeroWeightsCount = 0;
// By default set to the number of all features available.
var evaluatedFeaturesCount = numSlots;
if (useFeatureWeightFilter)
{
var predictorWithWeights = model.Model as IPredictorWithFeatureWeights<Single>;
if (predictorWithWeights != null)
{
predictorWithWeights.GetFeatureWeights(ref weights);
const int maxReportedZeroFeatures = 10;
StringBuilder msgFilteredOutFeatures = new StringBuilder("The following features have zero weight and will not be evaluated: \n \t");
var prefix = "";
foreach (var k in weights.Items(all: true))
{
if (k.Value == 0)
{
zeroWeightsCount++;
// Print info about first few features we're not going to evaluate.
if (zeroWeightsCount <= maxReportedZeroFeatures)
{
msgFilteredOutFeatures.Append(prefix);
msgFilteredOutFeatures.Append(GetSlotName(slotNames, k.Key));
prefix = ", ";
}
}
else
workingFeatureIndices.Add(k.Key);
}
// Old FastTree models has less weights than slots.
if (weights.Length < numSlots)
{
ch.Warning(
"Predictor had fewer features than slots. All unknown features will get default 0 weight.");
zeroWeightsCount += numSlots - weights.Length;
var indexes = weights.GetIndices().ToArray();
var values = weights.GetValues().ToArray();
var count = values.Length;
weights = new VBuffer<float>(numSlots, count, values, indexes);
}
evaluatedFeaturesCount = workingFeatureIndices.Count;
ch.Info("Number of zero weights: {0} out of {1}.", zeroWeightsCount, weights.Length);
// Print what features have 0 weight
if (zeroWeightsCount > 0)
{
if (zeroWeightsCount > maxReportedZeroFeatures)
{
msgFilteredOutFeatures.Append(string.Format("... (printing out {0} features here).\n Use 'Index' column in the report for info on what features are not evaluated.", maxReportedZeroFeatures));
}
ch.Info(msgFilteredOutFeatures.ToString());
}
}
}
if (workingFeatureIndices.Count == 0 && zeroWeightsCount == 0)
{
// Use all features otherwise.
workingFeatureIndices.AddRange(Enumerable.Range(0, numSlots));
}
if (zeroWeightsCount == numSlots)
{
ch.Warning("All features have 0 weight thus can not do thorough evaluation");
return metricsDelta.ToImmutableArray();
}
// Note: this will not work on the huge dataset.
var maxSize = topExamples;
List<float> initialfeatureValuesList = new List<float>();
// Cursor through the data to cache slot 0 values for the upcoming permutation.
var valuesRowCount = 0;
// REVIEW: Seems like if the labels are NaN, so that all metrics are NaN, this command will be useless.
// In which case probably erroring out is probably the most useful thing.
using (var cursor = data.GetRowCursor(featuresColumn))
{
var featuresGetter = cursor.GetGetter<VBuffer<float>>(featuresColumn);
var featuresBuffer = default(VBuffer<float>);
while (initialfeatureValuesList.Count < maxSize && cursor.MoveNext())
{
featuresGetter(ref featuresBuffer);
initialfeatureValuesList.Add(featuresBuffer.GetItemOrDefault(workingFeatureIndices[0]));
}
valuesRowCount = initialfeatureValuesList.Count;
}
if (valuesRowCount > 0)
{
ch.Info("Detected {0} examples for evaluation.", valuesRowCount);
}
else
{
ch.Warning("Detected no examples for evaluation.");
return metricsDelta.ToImmutableArray();
}
float[] featureValuesBuffer = initialfeatureValuesList.ToArray();
float[] nextValues = new float[valuesRowCount];
// Now iterate through all the working slots, do permutation and calc the delta of metrics.
int processedCnt = 0;
int nextFeatureIndex = 0;
var shuffleRand = RandomUtils.Create(host.Rand.Next());
using (var pch = host.StartProgressChannel("Calculating Permutation Feature Importance"))
{
pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt));
foreach (var workingIndx in workingFeatureIndices)
{
// Index for the feature we will permute next. Needed to build in advance a buffer for the permutation.
if (processedCnt < workingFeatureIndices.Count - 1)
nextFeatureIndex = workingFeatureIndices[processedCnt + 1];
// Used for pre-caching the next feature
int nextValuesIndex = 0;
SchemaDefinition input = SchemaDefinition.Create(typeof(FeaturesBuffer));
Contracts.Assert(input.Count == 1);
input[0].ColumnName = features;
SchemaDefinition output = SchemaDefinition.Create(typeof(FeaturesBuffer));
Contracts.Assert(output.Count == 1);
output[0].ColumnName = features;
output[0].ColumnType = featuresColumn.Type;
// Perform multiple permutations for one feature to build a confidence interval
var metricsDeltaForFeature = resultInitializer();
for (int permutationIteration = 0; permutationIteration < permutationCount; permutationIteration++)
{
Utils.Shuffle<float>(shuffleRand, featureValuesBuffer);
Action<FeaturesBuffer, FeaturesBuffer, PermuterState> permuter =
(src, dst, state) =>
{
src.Features.CopyTo(ref dst.Features);
VBufferUtils.ApplyAt(ref dst.Features, workingIndx,
(int ii, ref float d) =>
d = featureValuesBuffer[state.SampleIndex++]);
// Is it time to pre-cache the next feature?
if (permutationIteration == permutationCount - 1 &&
processedCnt < workingFeatureIndices.Count - 1)
{
// Fill out the featureValueBuffer for the next feature while updating the current feature
// This is the reason I need PermuterState in LambdaTransform.CreateMap.
nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex);
if (nextValuesIndex < valuesRowCount - 1)
nextValuesIndex++;
}
};
IDataView viewPermuted = LambdaTransform.CreateMap(
host, data, permuter, null, input, output);
if (valuesRowCount == topExamples)
viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeOptions() { Count = valuesRowCount }, viewPermuted);
var metrics = evaluationFunc(model.Transform(viewPermuted));
var delta = deltaFunc(metrics, baselineMetrics);
metricsDeltaForFeature.Add(delta);
}
// Add the metrics delta to the list
metricsDelta.Add(metricsDeltaForFeature);
// Swap values for next iteration of permutation.
if (processedCnt < workingFeatureIndices.Count - 1)
{
Array.Clear(featureValuesBuffer, 0, featureValuesBuffer.Length);
nextValues.CopyTo(featureValuesBuffer, 0);
Array.Clear(nextValues, 0, nextValues.Length);
}
processedCnt++;
}
pch.Checkpoint(processedCnt, processedCnt);
}
}
return metricsDelta.ToImmutableArray();
}
private static ReadOnlyMemory<char> GetSlotName(VBuffer<ReadOnlyMemory<char>> slotNames, int index)
{
var slotName = slotNames.GetItemOrDefault(index);
return slotName.IsEmpty
? slotName
: string.Format("f{0}", index).AsMemory();
}
/// <summary>
/// This is used as a hack to force Lambda Transform behave sequentially.
/// </summary>
private sealed class PermuterState
{
public int SampleIndex;
}
/// <summary>
/// Helper structure used for features permutation in Lambda Transform.
/// </summary>
private sealed class FeaturesBuffer
{
public VBuffer<float> Features;
}
/// <summary>
/// Helper class for report's Lambda transform.
/// </summary>
private sealed class FeatureIndex
{
#pragma warning disable 0649
public int Index;
#pragma warning restore 0649
}
/// <summary>
/// One more helper class for report's Lambda transform.
/// </summary>
private sealed class FeatureName
{
#pragma warning disable 0649
public ReadOnlyMemory<char> Name;
#pragma warning restore 0649
}
}
}
|