// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Trainers.Ensemble
{
internal abstract class BaseDiverseSelector<TOutput, TDiversityMetric> : SubModelDataSelector<TOutput>
where TDiversityMetric : class, IDiversityMeasure<TOutput>
{
public abstract class DiverseSelectorArguments : ArgumentsBase
{
}
private readonly IComponentFactory<IDiversityMeasure<TOutput>> _diversityMetricType;
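        // Predictions produced by each model on its subset's test data, keyed by model. Populated by
        // CalculateMetrics and consumed (then cleared) by Prune when computing pairwise diversity.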
private readonly ConcurrentDictionary<FeatureSubsetModel<TOutput>, TOutput[]> _predictions;
private protected BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name,
IComponentFactory<IDiversityMeasure<TOutput>> diversityMetricType)
: base(args, env, name)
{
_diversityMetricType = diversityMetricType;
_predictions = new ConcurrentDictionary<FeatureSubsetModel<TOutput>, TOutput[]>();
}
protected IDiversityMeasure<TOutput> CreateDiversityMetric()
{
return _diversityMetricType.CreateComponent(Host);
}
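        /// <summary>
        /// In addition to the base metrics, scores the subset's test data with the given model and
        /// caches the predictions so that <see cref="Prune"/> can later compute pairwise diversity.
        /// </summary>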
public override void CalculateMetrics(FeatureSubsetModel<TOutput> model,
ISubsetSelector subsetSelector, Subset subset, Batch batch, bool needMetrics)
{
base.CalculateMetrics(model, subsetSelector, subset, batch, needMetrics);
var vm = model.Predictor as IValueMapper;
            Host.Check(vm != null, "Predictor does not implement IValueMapper");
var map = vm.GetMapper<VBuffer<Single>, TOutput>();
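            // Score each example in the subset's test data with this model. The prediction buffer
            // starts at an arbitrary initial size, grows as needed, and is trimmed to the example count.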
TOutput[] preds = new TOutput[100];
int count = 0;
var data = subsetSelector.GetTestData(subset, batch);
using (var cursor = new FeatureFloatVectorCursor(data, CursOpt.AllFeatures))
{
while (cursor.MoveNext())
{
Utils.EnsureSize(ref preds, count + 1);
map(in cursor.Features, ref preds[count]);
count++;
}
}
Array.Resize(ref preds, count);
_predictions[model] = preds;
}
        /// <summary>
        /// Calculates diversity using the disagreement measure, defined as the number of instances on
        /// which one classifier is correct and the other is incorrect, divided by the total number of
        /// instances. All pairwise combinations of classifiers are then sorted so that the most diverse
        /// classifiers can be selected.
        /// </summary>
        /// <param name="models">The candidate models to prune.</param>
        /// <returns>The most diverse subset of the input models.</returns>
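        /// <remarks>
        /// Illustrative example (not taken from this code): with 10 test instances, if the first
        /// classifier is correct where the second is wrong on 2 instances, and wrong where the second
        /// is correct on 3 instances, the disagreement measure for the pair is (2 + 3) / 10 = 0.5;
        /// higher values indicate a more diverse pair.
        /// </remarks>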
public override IList<FeatureSubsetModel<TOutput>> Prune(IList<FeatureSubsetModel<TOutput>> models)
{
if (models.Count <= 1)
return models;
// 1. Find the disagreement number
List<ModelDiversityMetric<TOutput>> diversityValues = CalculateDiversityMeasure(models, _predictions);
_predictions.Clear();
// 2. Sort all the pairwise classifiers
var sortedModels = diversityValues.ToArray();
Array.Sort(sortedModels, new ModelDiversityComparer());
var modelCountToBeSelected = (int)(models.Count * LearnersSelectionProportion);
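            // Always retain at least one model, even when the selection proportion rounds down to zero.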
if (modelCountToBeSelected == 0)
modelCountToBeSelected++;
// 3. Take the most diverse classifiers
var selectedModels = new List<FeatureSubsetModel<TOutput>>();
foreach (var item in sortedModels)
{
if (selectedModels.Count < modelCountToBeSelected)
{
if (!selectedModels.Contains(item.ModelX))
{
selectedModels.Add(item.ModelX);
}
}
if (selectedModels.Count < modelCountToBeSelected)
{
if (!selectedModels.Contains(item.ModelY))
{
selectedModels.Add(item.ModelY);
continue;
}
}
else
{
break;
}
}
return selectedModels;
}
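        /// <summary>
        /// Computes a diversity value for each pair of models from the predictions cached by
        /// CalculateMetrics. Derived classes typically build the measure returned by
        /// <see cref="CreateDiversityMetric"/> and emit one entry per model pair.
        /// </summary>
        /// <param name="models">The models whose pairwise diversity is to be measured.</param>
        /// <param name="predictions">The per-model cached predictions on the subsets' test data.</param>
        /// <returns>The diversity metrics, typically one per model pair.</returns>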
public abstract List<ModelDiversityMetric<TOutput>> CalculateDiversityMeasure(IList<FeatureSubsetModel<TOutput>> models,
ConcurrentDictionary<FeatureSubsetModel<TOutput>, TOutput[]> predictions);
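        /// <summary>
        /// Orders <see cref="ModelDiversityMetric{TOutput}"/> entries by descending DiversityNumber,
        /// so that the most diverse model pairs sort first; null entries compare as equal.
        /// </summary>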
public class ModelDiversityComparer : IComparer<ModelDiversityMetric<TOutput>>
{
public int Compare(ModelDiversityMetric<TOutput> x, ModelDiversityMetric<TOutput> y)
{
if (x == null || y == null)
return 0;
if (x.DiversityNumber > y.DiversityNumber)
return -1;
if (y.DiversityNumber > x.DiversityNumber)
return 1;
return 0;
}
}
}
}