|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
[assembly: LoadableClass(typeof(TreeEnsembleCombiner), null, typeof(SignatureModelCombiner), "Fast Tree Model Combiner", "FastTreeCombiner")]
namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// Combines several tree ensemble models into a single tree ensemble. The trees of all
/// input models are copied into one ensemble and every leaf value is scaled by
/// <c>1 / modelCount</c>.
/// </summary>
internal sealed class TreeEnsembleCombiner : IModelCombiner
{
    private readonly IHost _host;
    private readonly PredictionKind _kind;

    /// <summary>
    /// Creates a combiner that produces a model of the given prediction kind.
    /// </summary>
    /// <param name="env">Host environment used to register this component.</param>
    /// <param name="kind">
    /// Kind of model to produce. Only <see cref="PredictionKind.BinaryClassification"/>,
    /// <see cref="PredictionKind.Regression"/> and <see cref="PredictionKind.Ranking"/> are accepted;
    /// any other value results in a user-argument exception.
    /// </param>
    public TreeEnsembleCombiner(IHostEnvironment env, PredictionKind kind)
    {
        _host = env.Register("TreeEnsembleCombiner");
        switch (kind)
        {
            case PredictionKind.BinaryClassification:
            case PredictionKind.Regression:
            case PredictionKind.Ranking:
                _kind = kind;
                break;
            default:
                throw _host.ExceptUserArg(nameof(kind), $"Tree ensembles can be either of type {nameof(PredictionKind.BinaryClassification)}, " +
                    $"{nameof(PredictionKind.Regression)} or {nameof(PredictionKind.Ranking)}");
        }
    }

    /// <summary>
    /// Merges the trees of all <paramref name="models"/> into one ensemble and wraps it in the
    /// model parameters type matching the prediction kind given at construction.
    /// </summary>
    /// <param name="models">
    /// Models to combine. Each must be a <see cref="TreeEnsembleModelParameters"/>, optionally wrapped
    /// in a calibrated model whose calibrator is a <see cref="PlattCalibrator"/>. All models must be
    /// either calibrated or uncalibrated, and must agree on the number of input features.
    /// </param>
    /// <returns>
    /// The combined model. For binary classification with calibrated inputs the result is wrapped
    /// with a new <see cref="PlattCalibrator"/>; otherwise the bare FastTree model parameters are returned.
    /// </returns>
    IPredictor IModelCombiner.CombineModels(IEnumerable<IPredictor> models)
    {
        _host.CheckValue(models, nameof(models));

        var ensemble = new InternalTreeEnsemble();
        int modelCount = 0;
        int featureCount = -1;
        bool binaryClassifier = false;
        foreach (var model in models)
        {
            modelCount++;

            var predictor = model;
            _host.CheckValue(predictor, nameof(models), "One of the models is null");

            var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
            double paramA = 1;
            if (calibrated != null)
            {
                // All three statements must be guarded together: uncalibrated models have
                // calibrated == null and must keep predictor and paramA untouched.
                _host.Check(calibrated.WeaklyTypedCalibrator is PlattCalibrator,
                    "Combining FastTree models can only be done when the models are calibrated with Platt calibrator");
                predictor = calibrated.WeaklyTypedSubModel;
                // Leaf values are rescaled by the negated Platt slope; only the slope of the
                // original calibrator is read here.
                paramA = -((PlattCalibrator)calibrated.WeaklyTypedCalibrator).Slope;
            }

            var tree = predictor as TreeEnsembleModelParameters;
            if (tree == null)
                throw _host.Except("Model is not a tree ensemble");

            foreach (var t in tree.TrainedEnsemble.Trees)
            {
                // Copy the tree through its byte representation so that rescaling the leaf
                // outputs below operates on tNew rather than on the input model's tree.
                var bytes = new byte[t.SizeInBytes()];
                int position = 0;
                t.ToByteArray(bytes, ref position);
                position = 0;
                var tNew = new InternalRegressionTree(bytes, ref position);
                if (paramA != 1)
                {
                    for (int i = 0; i < tNew.NumLeaves; i++)
                        tNew.SetOutput(i, tNew.LeafValues[i] * paramA);
                }
                ensemble.AddTree(tNew);
            }

            if (modelCount == 1)
            {
                // The first model defines what every subsequent model must match.
                binaryClassifier = calibrated != null;
                featureCount = tree.InputType.GetValueCount();
            }
            else
            {
                _host.Check((calibrated != null) == binaryClassifier, "Ensemble contains both calibrated and uncalibrated models");
                _host.Check(featureCount == tree.InputType.GetValueCount(), "Found models with different number of features");
            }
        }

        // Without at least one model there is no feature count and the scale below would be 1/0.
        _host.Check(modelCount > 0, "At least one model is required to build a combined tree ensemble");

        // Scale every leaf by 1/modelCount.
        // NOTE(review): presumably the ensemble sums its trees' outputs, making this an average - confirm.
        var scale = 1 / (double)modelCount;

        foreach (var t in ensemble.Trees)
        {
            for (int i = 0; i < t.NumLeaves; i++)
                t.SetOutput(i, t.LeafValues[i] * scale);
        }

        switch (_kind)
        {
            case PredictionKind.BinaryClassification:
                if (!binaryClassifier)
                    return new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);

                // Leaf values were already multiplied by each model's negated slope, so the
                // combined model gets a fresh calibrator built with the fixed arguments (-1, 0).
                var cali = new PlattCalibrator(_host, -1, 0);
                var fastTreeModel = new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);
                return new FeatureWeightsCalibratedModelParameters<FastTreeBinaryModelParameters, PlattCalibrator>(_host, fastTreeModel, cali);
            case PredictionKind.Regression:
                return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null);
            case PredictionKind.Ranking:
                return new FastTreeRankingModelParameters(_host, ensemble, featureCount, null);
            default:
                // The constructor only admits the three kinds handled above.
                _host.Assert(false);
                throw _host.ExceptNotSupp();
        }
    }
}
}
|