|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
namespace Microsoft.ML.Trainers.FastTree
{
internal class BaggingProvider
{
protected Dataset CompleteTrainingSet;
protected DocumentPartitioning CurrentTrainPartition;
protected DocumentPartitioning CurrentOutOfBagPartition;
protected Random RndGenerator;
protected int MaxLeaves;
protected double TrainFraction;
public BaggingProvider(Dataset completeTrainingSet, int maxLeaves, int randomSeed, double trainFraction)
{
CompleteTrainingSet = completeTrainingSet;
MaxLeaves = maxLeaves;
RndGenerator = new Random(randomSeed);
TrainFraction = trainFraction;
GenerateNewBag();
}
public virtual void GenerateNewBag()
{
int[] trainDocs = new int[CompleteTrainingSet.NumDocs];
int[] outOfBagDocs = new int[CompleteTrainingSet.NumDocs];
int trainSize = 0;
int outOfBagSize = 0;
for (int i = 0; i < CompleteTrainingSet.NumQueries; i++)
{
int begin = CompleteTrainingSet.Boundaries[i];
int numDocuments = CompleteTrainingSet.Boundaries[i + 1] - begin;
for (int d = 0; d < numDocuments; d++)
{
if (RndGenerator.NextDouble() < TrainFraction)
{
trainDocs[trainSize] = begin + d;
trainSize++;
}
else
{
outOfBagDocs[outOfBagSize] = begin + d;
outOfBagSize++;
}
}
}
CurrentTrainPartition = new DocumentPartitioning(trainDocs, trainSize, MaxLeaves);
CurrentOutOfBagPartition = new DocumentPartitioning(outOfBagDocs, outOfBagSize, MaxLeaves);
CurrentTrainPartition.Initialize();
CurrentOutOfBagPartition.Initialize();
}
public DocumentPartitioning GetCurrentTrainingPartition()
{
return CurrentTrainPartition;
}
public DocumentPartitioning GetCurrentOutOfBagPartition()
{
return CurrentOutOfBagPartition;
}
public int GetBagCount(int numTrees, int bagSize)
{
return numTrees / bagSize;
}
// Divides output values of leaves to bag count.
// This brings back the final scores generated by model on a same
// range as when we didn't use bagging
internal void ScaleEnsembleLeaves(int numTrees, int bagSize, InternalTreeEnsemble ensemble)
{
int bagCount = GetBagCount(numTrees, bagSize);
for (int t = 0; t < ensemble.NumTrees; t++)
{
InternalRegressionTree tree = ensemble.GetTreeAt(t);
tree.ScaleOutputsBy(1.0 / bagCount);
}
}
}
//REVIEW: Should FastTree binary application have instances bagging or query bagging?
internal class RankingBaggingProvider : BaggingProvider
{
public RankingBaggingProvider(Dataset completeTrainingSet, int maxLeaves, int randomSeed, double trainFraction) :
base(completeTrainingSet, maxLeaves, randomSeed, trainFraction)
{
}
public override void GenerateNewBag()
{
int[] trainDocs = new int[CompleteTrainingSet.NumDocs];
int[] outOfBagDocs = new int[CompleteTrainingSet.NumDocs];
int trainSize = 0;
int outOfBagSize = 0;
int[] tmpTrainQueryIndices = new int[CompleteTrainingSet.NumQueries];
bool[] selectedTrainQueries = new bool[CompleteTrainingSet.NumQueries];
int qIdx = 0;
for (int i = 0; i < CompleteTrainingSet.NumQueries; i++)
{
int begin = CompleteTrainingSet.Boundaries[i];
int numDocuments = CompleteTrainingSet.Boundaries[i + 1] - begin;
if (RndGenerator.NextDouble() < TrainFraction)
{
for (int d = 0; d < numDocuments; d++)
{
trainDocs[trainSize] = begin + d;
trainSize++;
}
tmpTrainQueryIndices[qIdx] = i;
qIdx++;
selectedTrainQueries[i] = true;
}
}
int outOfBagQueriesCount = CompleteTrainingSet.NumQueries - qIdx;
var currentTrainQueryIndices = new int[CompleteTrainingSet.NumQueries - outOfBagQueriesCount];
Array.Copy(tmpTrainQueryIndices, currentTrainQueryIndices, currentTrainQueryIndices.Length);
var currentOutOfBagQueryIndices = new int[outOfBagQueriesCount];
int outOfBagQIdx = 0;
for (int q = 0; q < CompleteTrainingSet.NumQueries; q++)
{
if (!selectedTrainQueries[q])
{
int begin = CompleteTrainingSet.Boundaries[q];
int numDocuments = CompleteTrainingSet.Boundaries[q + 1] - begin;
for (int d = 0; d < numDocuments; d++)
{
outOfBagDocs[outOfBagSize] = begin + d;
outOfBagSize++;
}
currentOutOfBagQueryIndices[outOfBagQIdx] = q;
outOfBagQIdx++;
}
}
CurrentTrainPartition = new DocumentPartitioning(trainDocs, trainSize, MaxLeaves);
CurrentOutOfBagPartition = new DocumentPartitioning(outOfBagDocs, outOfBagSize, MaxLeaves);
CurrentTrainPartition.Initialize();
CurrentOutOfBagPartition.Initialize();
}
}
}
|