|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.Conversion;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Transforms;
using Microsoft.ML.TreePredictor;
[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(TreeEnsembleFeaturizerTransform), typeof(TreeEnsembleFeaturizerBindableMapper.Arguments),
typeof(SignatureBindableMapper), "Tree Ensemble Featurizer Mapper", TreeEnsembleFeaturizerBindableMapper.LoadNameShort)]
[assembly: LoadableClass(typeof(IDataScorerTransform), typeof(TreeEnsembleFeaturizerTransform), typeof(TreeEnsembleFeaturizerBindableMapper.Arguments),
typeof(SignatureDataScorer), "Tree Ensemble Featurizer Scorer", TreeEnsembleFeaturizerBindableMapper.LoadNameShort)]
[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(TreeEnsembleFeaturizerTransform), null, typeof(SignatureLoadModel),
"Tree Ensemble Featurizer Mapper", TreeEnsembleFeaturizerBindableMapper.LoaderSignature)]
[assembly: LoadableClass(TreeEnsembleFeaturizerTransform.TreeEnsembleSummary, typeof(IDataTransform), typeof(TreeEnsembleFeaturizerTransform),
typeof(TreeEnsembleFeaturizerTransform.Arguments), typeof(SignatureDataTransform),
TreeEnsembleFeaturizerTransform.UserName, TreeEnsembleFeaturizerBindableMapper.LoadNameShort, "TreeFeaturizationTransform")]
[assembly: LoadableClass(typeof(void), typeof(TreeFeaturize), null, typeof(SignatureEntryPointModule), "TreeFeaturize")]
namespace Microsoft.ML.Data
{
/// <summary>
/// A bindable mapper wrapper for tree ensembles, that creates a bound mapper with three outputs:
/// 1. A vector containing the individual tree outputs of the tree ensemble.
/// 2. An indicator vector for the leaves that the feature vector falls on in the tree ensemble.
/// 3. An indicator vector for the internal nodes on the paths that the feature vector falls on in the tree ensemble.
/// </summary>
internal sealed class TreeEnsembleFeaturizerBindableMapper : ISchemaBindableMapper, ICanSaveModel
{
/// <summary>
/// Scorer arguments for the tree ensemble featurizer. In addition to options inherited from
/// <see cref="ScorerArgumentsBase"/>, <see cref="Arguments"/> adds the names of the output columns
/// produced by the tree-based featurizer. The bound mapper only adds an output column when the
/// corresponding name is not <see langword="null"/>.
/// </summary>
public sealed class Arguments : ScorerArgumentsBase
{
/// <summary>
/// Name of the output column holding one prediction value per tree.
/// See <see cref="TreeEnsembleFeaturizationEstimatorBase.OptionsBase.TreesColumnName"/>.
/// </summary>
public string TreesColumnName;
/// <summary>
/// Name of the output column holding the 0-1 encoding of the reached leaves.
/// See <see cref="TreeEnsembleFeaturizationEstimatorBase.OptionsBase.LeavesColumnName"/>.
/// </summary>
public string LeavesColumnName;
/// <summary>
/// Name of the output column holding the 0-1 encoding of the internal nodes on the paths to the reached leaves.
/// See <see cref="TreeEnsembleFeaturizationEstimatorBase.OptionsBase.PathsColumnName"/>.
/// </summary>
public string PathsColumnName;
}
private sealed class BoundMapper : ISchemaBoundRowMapper
{
/// <summary>The role-mapped schema this mapper was bound to.</summary>
public RoleMappedSchema InputRoleMappedSchema { get; }
/// <summary>The underlying input schema, taken from <see cref="InputRoleMappedSchema"/>.</summary>
public DataViewSchema InputSchema => InputRoleMappedSchema.Schema;
/// <summary>Schema of the produced columns; it is built once by the constructor.</summary>
public DataViewSchema OutputSchema { get; }
/// <summary>The bindable mapper that created this bound mapper.</summary>
public ISchemaBindableMapper Bindable => _owner;
// Owning mapper; supplies the tree ensemble, the total leaf count and the slot-name getters.
private readonly TreeEnsembleFeaturizerBindableMapper _owner;
// Exception context used for argument checks and assertions.
private readonly IExceptionContext _ectx;
/// <summary>
/// Feature vector to be mapped to tree-based features.
/// </summary>
private DataViewSchema.Column FeatureColumn => InputRoleMappedSchema.Feature.Value;
/// <summary>
/// The name of the column that stores the prediction values of all trees. Its type is a vector of <see cref="System.Single"/>
/// and the i-th vector element is the prediction value predicted by the i-th tree.
/// If <see cref="_treesColumnName"/> is <see langword="null"/>, this output column is not added to <see cref="OutputSchema"/>.
/// </summary>
private readonly string _treesColumnName;
/// <summary>
/// The 0-1 encoding of all leaf nodes' IDs. Its type is a vector of <see cref="System.Single"/>. If the given feature
/// vector falls into the first leaf of the first tree, the first element in the 0-1 encoding would be 1.
/// If <see cref="_leavesColumnName"/> is <see langword="null"/>, this output column is not added to <see cref="OutputSchema"/>.
/// </summary>
private readonly string _leavesColumnName;
/// <summary>
/// The 0-1 encoding of the paths to the leaves. If the path to the first tree's leaf is node 1 (2nd node in the first tree),
/// node 3 (4th node in the first tree), and node 5 (6th node in the first tree), the 2nd, 4th, and 6th element in that encoding
/// would be 1.
/// If <see cref="_pathsColumnName"/> is <see langword="null"/>, this output column is not added to <see cref="OutputSchema"/>.
/// </summary>
private readonly string _pathsColumnName;
/// <summary>
/// Creates a mapper bound to <paramref name="schema"/> and builds its output schema.
/// Columns are appended in the fixed order trees, leaves, paths; a column whose name is
/// <see langword="null"/> is left out.
/// </summary>
/// <param name="ectx">Exception context used for checks and assertions.</param>
/// <param name="owner">The bindable mapper that owns the tree ensemble.</param>
/// <param name="schema">Role-mapped input schema; it must have a feature column.</param>
/// <param name="treesColumnName">Name of the per-tree output column, or <see langword="null"/>.</param>
/// <param name="leavesColumnName">Name of the leaf indicator column, or <see langword="null"/>.</param>
/// <param name="pathsColumnName">Name of the path indicator column, or <see langword="null"/>.</param>
public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper owner, RoleMappedSchema schema,
    string treesColumnName, string leavesColumnName, string pathsColumnName)
{
    Contracts.AssertValue(ectx);
    ectx.AssertValue(owner);
    ectx.AssertValue(schema);
    ectx.Assert(schema.Feature.HasValue);

    _ectx = ectx;
    _owner = owner;
    InputRoleMappedSchema = schema;

    // One slot per tree: the output of each tree on a given example.
    var perTreeType = new VectorDataViewType(NumberDataViewType.Single, owner._ensemble.TrainedEnsemble.NumTrees);

    // One slot per leaf in the whole ensemble: marks the leaf the example ends up in, for every tree.
    var perLeafType = new VectorDataViewType(NumberDataViewType.Single, owner._totalLeafCount);

    // One slot per internal node in the whole ensemble: marks the nodes on the example's path in every tree.
    // In a binary tree, #internal + #leaf = 2 * #internal + 1 (every node except the root is a child of
    // an internal node), hence #internal = #leaf - 1 per tree, and #leaf - #trees over the ensemble.
    var perNodeType = new VectorDataViewType(NumberDataViewType.Single, owner._totalLeafCount - owner._ensemble.TrainedEnsemble.NumTrees);

    var outputColumns = new DataViewSchema.Builder();

    _treesColumnName = treesColumnName;
    if (treesColumnName != null)
    {
        // Tree values only carry slot names.
        ValueGetter<VBuffer<ReadOnlyMemory<char>>> treeSlotNames = owner.GetTreeSlotNames;
        var treeAnnotations = new DataViewSchema.Annotations.Builder();
        treeAnnotations.Add(AnnotationUtils.Kinds.SlotNames, AnnotationUtils.GetNamesType(perTreeType.Size), treeSlotNames);
        outputColumns.AddColumn(treesColumnName, perTreeType, treeAnnotations.ToAnnotations());
    }

    _leavesColumnName = leavesColumnName;
    if (leavesColumnName != null)
    {
        // Leaf indicators carry slot names and are flagged as normalized.
        ValueGetter<VBuffer<ReadOnlyMemory<char>>> leafSlotNames = owner.GetLeafSlotNames;
        var leafAnnotations = new DataViewSchema.Annotations.Builder();
        leafAnnotations.Add(AnnotationUtils.Kinds.SlotNames, AnnotationUtils.GetNamesType(perLeafType.Size), leafSlotNames);
        leafAnnotations.Add(AnnotationUtils.Kinds.IsNormalized, BooleanDataViewType.Instance, (ref bool value) => value = true);
        outputColumns.AddColumn(leavesColumnName, perLeafType, leafAnnotations.ToAnnotations());
    }

    _pathsColumnName = pathsColumnName;
    if (pathsColumnName != null)
    {
        // Path indicators carry slot names and are flagged as normalized.
        ValueGetter<VBuffer<ReadOnlyMemory<char>>> pathSlotNames = owner.GetPathSlotNames;
        var pathAnnotations = new DataViewSchema.Annotations.Builder();
        pathAnnotations.Add(AnnotationUtils.Kinds.SlotNames, AnnotationUtils.GetNamesType(perNodeType.Size), pathSlotNames);
        pathAnnotations.Add(AnnotationUtils.Kinds.IsNormalized, BooleanDataViewType.Instance, (ref bool value) => value = true);
        outputColumns.AddColumn(pathsColumnName, perNodeType, pathAnnotations.ToAnnotations());
    }

    OutputSchema = outputColumns.ToSchema();
}
/// <summary>
/// Wraps <paramref name="input"/> in a row exposing <see cref="OutputSchema"/>, with getters
/// created for the requested <paramref name="activeColumns"/>.
/// </summary>
DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
{
    _ectx.CheckValue(input, nameof(input));
    _ectx.CheckValue(activeColumns, nameof(activeColumns));

    Delegate[] getters = CreateGetters(input, activeColumns);
    return new SimpleRow(OutputSchema, input, getters);
}
/// <summary>
/// Creates one getter slot per generated output column, in schema order (trees, leaves, paths).
/// A slot holds <see langword="null"/> when its column is not among <paramref name="activeColumns"/>.
/// All getters share a single <see cref="State"/> created over <paramref name="input"/>.
/// </summary>
/// <param name="input">The input row the getters read from.</param>
/// <param name="activeColumns">The output columns the caller wants getters for.</param>
/// <returns>An array with one entry per output column that has a non-null name.</returns>
private Delegate[] CreateGetters(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
{
    _ectx.AssertValue(input);
    _ectx.AssertValue(activeColumns);

    // Materialize the active indices once. The previous deferred Select() was re-enumerated by every
    // Contains() call, i.e. up to three passes over a caller-supplied enumerable.
    var activeIndices = new HashSet<int>(activeColumns.Select(c => c.Index));
    var state = new State(_ectx, input, _owner._ensemble, _owner._totalLeafCount, FeatureColumn.Index);
    var delegates = new List<Delegate>(3);

    // Appends the getter for a generated column, or null when that column is inactive.
    // Columns whose name is null are not part of OutputSchema and get no slot at all.
    void AddGetter(string columnName, ValueGetter<VBuffer<float>> getter)
    {
        if (columnName == null)
            return;
        bool isActive = activeIndices.Contains(OutputSchema[columnName].Index);
        delegates.Add(isActive ? getter : null);
    }

    AddGetter(_treesColumnName, state.GetTreeValues);
    AddGetter(_leavesColumnName, state.GetLeafIds);
    AddGetter(_pathsColumnName, state.GetPathIds);

    return delegates.ToArray();
}
private sealed class State
{
// Exception context used for checks and assertions.
private readonly IExceptionContext _ectx;
// Row whose feature column is read; its Position is the key for all caches below.
private readonly DataViewRow _input;
// The tree ensemble evaluated on each row.
private readonly TreeEnsembleModelParameters _ensemble;
// Number of trees in the ensemble (read from TrainedEnsemble.NumTrees in the constructor).
private readonly int _numTrees;
// Leaf count handed to the constructor; used as the length of the leaf indicator vector.
private readonly int _numLeaves;
// Buffer that receives the feature vector of the current row.
private VBuffer<float> _src;
// Getter for the feature column of _input.
private readonly ValueGetter<VBuffer<float>> _featureGetter;
// Row position for which _leafIds and _pathIds were last computed; -1 before the first computation.
private long _cachedPosition;
// Per tree: the value returned by GetLeaf for the current row.
private readonly int[] _leafIds;
// Per tree: the list passed by ref to GetLeaf; GetPathIds treats its entries as internal node indices.
private readonly List<int>[] _pathIds;
// Lazily created builders for the leaf and path indicator vectors.
private BufferBuilder<float> _leafIdBuilder;
private BufferBuilder<float> _pathIdBuilder;
// Row positions for which the respective builder was last filled; -1 before the first fill.
private long _cachedLeafBuilderPosition;
private long _cachedPathBuilderPosition;
/// <summary>
/// Sets up the per-row evaluation state: the feature getter, the per-tree leaf and path
/// buffers, and the cache markers (all initialized to -1, i.e. nothing cached yet).
/// </summary>
/// <param name="ectx">Exception context used for checks and assertions.</param>
/// <param name="input">Row to read the feature vector from.</param>
/// <param name="ensemble">Tree ensemble to evaluate; asserted to contain at least one tree.</param>
/// <param name="numLeaves">Length of the leaf indicator vector.</param>
/// <param name="featureIndex">Index of the feature column in the schema of <paramref name="input"/>.</param>
public State(IExceptionContext ectx, DataViewRow input, TreeEnsembleModelParameters ensemble, int numLeaves, int featureIndex)
{
    Contracts.AssertValue(ectx);
    _ectx = ectx;
    _ectx.AssertValue(input);
    _ectx.AssertValue(ensemble);
    _ectx.Assert(ensemble.TrainedEnsemble.NumTrees > 0);

    _input = input;
    _ensemble = ensemble;
    _numTrees = ensemble.TrainedEnsemble.NumTrees;
    _numLeaves = numLeaves;

    _src = default(VBuffer<float>);
    _featureGetter = input.GetGetter<VBuffer<float>>(input.Schema[featureIndex]);

    // One leaf id and one path list per tree.
    _leafIds = new int[_numTrees];
    _pathIds = new List<int>[_numTrees];
    for (int tree = 0; tree < _pathIds.Length; tree++)
        _pathIds[tree] = new List<int>();

    // Nothing has been computed for any row yet.
    _cachedPosition = -1;
    _cachedLeafBuilderPosition = -1;
    _cachedPathBuilderPosition = -1;
}
/// <summary>
/// Fills <paramref name="dst"/> with one value per tree: the value of the leaf that the
/// current row reached in that tree.
/// </summary>
public void GetTreeValues(ref VBuffer<float> dst)
{
    EnsureCachedPosition();

    var treeOutputs = VBufferEditor.Create(ref dst, _numTrees);
    int tree = 0;
    while (tree < _numTrees)
    {
        treeOutputs.Values[tree] = _ensemble.GetLeafValue(tree, _leafIds[tree]);
        tree++;
    }
    dst = treeOutputs.Commit();
}
/// <summary>
/// Produces the leaf indicator vector for the current row. Each tree contributes a single 1
/// at (offset of the tree's first leaf + reached leaf id). The builder is refilled only when
/// the row position changed since the last call.
/// </summary>
public void GetLeafIds(ref VBuffer<float> dst)
{
    EnsureCachedPosition();
    _ectx.Assert(_input.Position >= 0);
    _ectx.Assert(_cachedPosition == _input.Position);

    bool builderIsStale = _cachedLeafBuilderPosition != _input.Position;
    if (builderIsStale)
    {
        _leafIdBuilder = _leafIdBuilder ?? BufferBuilder<float>.CreateDefault();
        _leafIdBuilder.Reset(_numLeaves, false);

        var trees = ((ITreeEnsemble)_ensemble).GetTrees();
        int firstSlotOfTree = 0;
        for (int tree = 0; tree < trees.Length; tree++)
        {
            _leafIdBuilder.AddFeature(firstSlotOfTree + _leafIds[tree], 1);
            firstSlotOfTree += trees[tree].NumLeaves;
        }

        _cachedLeafBuilderPosition = _input.Position;
    }

    _ectx.AssertValue(_leafIdBuilder);
    _leafIdBuilder.GetResult(ref dst);
}
/// <summary>
/// Produces the path indicator vector for the current row. For every tree, each node recorded
/// in its path list sets a 1 at (offset of the tree's first internal node + node index).
/// The builder is refilled only when the row position changed since the last call.
/// </summary>
public void GetPathIds(ref VBuffer<float> dst)
{
    EnsureCachedPosition();
    _ectx.Assert(_input.Position >= 0);
    _ectx.Assert(_cachedPosition == _input.Position);

    bool builderIsStale = _cachedPathBuilderPosition != _input.Position;
    if (builderIsStale)
    {
        _pathIdBuilder = _pathIdBuilder ?? BufferBuilder<float>.CreateDefault();

        var trees = ((ITreeEnsemble)_ensemble).GetTrees();
        // The vector has one slot per internal node: #leaves - #trees in total.
        _pathIdBuilder.Reset(_numLeaves - _numTrees, dense: false);

        int firstSlotOfTree = 0;
        for (int tree = 0; tree < _numTrees; tree++)
        {
            int internalNodeCount = trees[tree].NumLeaves - 1;
            List<int> visitedNodes = _pathIds[tree];
            _ectx.AssertValue(visitedNodes);

            foreach (int node in visitedNodes)
            {
                _ectx.Assert(0 <= node && node < internalNodeCount);
                _pathIdBuilder.AddFeature(firstSlotOfTree + node, 1);
            }

            firstSlotOfTree += internalNodeCount;
        }

        _cachedPathBuilderPosition = _input.Position;
    }

    _ectx.AssertValue(_pathIdBuilder);
    _pathIdBuilder.GetResult(ref dst);
}
/// <summary>
/// Makes sure the per-tree leaf ids and path lists describe the row the input is currently
/// positioned on. It re-reads the feature vector and re-evaluates every tree only when the
/// position differs from the cached one.
/// </summary>
private void EnsureCachedPosition()
{
    _ectx.Check(_input.Position >= 0, RowCursorUtils.FetchValueStateError);

    // Already computed for this row: nothing to do.
    if (_cachedPosition == _input.Position)
        return;

    _featureGetter(ref _src);
    _ectx.Assert(Utils.Size(_leafIds) == _numTrees);
    _ectx.Assert(Utils.Size(_pathIds) == _numTrees);

    for (int tree = 0; tree < _numTrees; tree++)
    {
        _leafIds[tree] = _ensemble.GetLeaf(tree, in _src, ref _pathIds[tree]);
    }

    _cachedPosition = _input.Position;
}
}
/// <summary>
/// Yields the single role binding this mapper consumes: the feature role bound to the name
/// of the feature column.
/// </summary>
public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
{
    string featureColumnName = FeatureColumn.Name;
    yield return RoleMappedSchema.ColumnRole.Feature.Bind(featureColumnName);
}
/// <summary>
/// Given a set of output columns, returns the input columns needed to generate them.
/// Every output column is computed from the feature column alone, so the result is either
/// empty (no output column requested) or exactly the feature column.
/// </summary>
/// <param name="dependingColumns">The output columns the caller intends to compute.</param>
/// <returns>An empty sequence, or a one-element sequence holding the feature column.</returns>
public IEnumerable<DataViewSchema.Column> GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns)
{
    // Any() stops at the first element, whereas Count() walks the whole sequence.
    if (!dependingColumns.Any())
        return Enumerable.Empty<DataViewSchema.Column>();

    return Enumerable.Repeat(FeatureColumn, 1);
}
}
/// <summary>Short load name under which the mapper, scorer and transform are registered.</summary>
public const string LoadNameShort = "TreeFeat";
/// <summary>Loader signature used for model loading, in the version info, and as the host registration name.</summary>
public const string LoaderSignature = "TreeEnsembleMapper";
/// <summary>
/// Describes the serialization format of this mapper: model signature, the version written,
/// the versions that can be read, and the loader identity.
/// </summary>
private static VersionInfo GetVersionInfo()
{
    string loaderAssembly = typeof(TreeEnsembleFeaturizerBindableMapper).Assembly.FullName;

    // Version history:
    //   0x00010001 - Initial
    //   0x00010002 - Add _defaultValueForMissing
    //   0x00010003 - Add output column names (_treesColumnName, _leavesColumnName, _pathsColumnName)
    var info = new VersionInfo(
        modelSignature: "TREEMAPR",
        verWrittenCur: 0x00010003,
        verReadableCur: 0x00010003,
        verWeCanReadBack: 0x00010002,
        loaderSignature: LoaderSignature,
        loaderAssemblyName: loaderAssembly);
    return info;
}
// Host registered under LoaderSignature; used for checks and assertions.
private readonly IHost _host;
// The tree ensemble whose per-example outputs are exposed as features.
private readonly TreeEnsembleModelParameters _ensemble;
// Sum of NumLeaves over all trees of _ensemble (see CountLeaves).
private readonly int _totalLeafCount;
// Output column names handed to the bound mapper; they are saved and loaded as nullable strings.
private readonly string _treesColumnName;
private readonly string _leavesColumnName;
private readonly string _pathsColumnName;
/// <summary>
/// Creates the mapper from a trained predictor.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="args">Carries the output column names.</param>
/// <param name="predictor">
/// Either a <see cref="TreeEnsembleModelParameters"/> or a Platt-calibrated
/// <see cref="FastTreeBinaryModelParameters"/>; anything else is rejected.
/// </param>
public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, Arguments args, IPredictor predictor)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(LoaderSignature);
    _host.CheckValue(args, nameof(args));
    _host.CheckValue(predictor, nameof(predictor));

    // Remember the output columns requested by the user.
    _treesColumnName = args.TreesColumnName;
    _leavesColumnName = args.LeavesColumnName;
    _pathsColumnName = args.PathsColumnName;

    // Models trained by the FastTreeTrainer family arrive here in one of four shapes:
    //   1. CalibratedPredictorBase<FastTreeBinaryModelParameters, PlattCalibrator>
    //   2. FastTreeRankingModelParameters
    //   3. FastTreeRegressionModelParameters
    //   4. FastTreeTweedieModelParameters
    // Shapes 2-4 derive from TreeEnsembleModelParameters; only shape 1 has to be unwrapped first.
    IPredictor treeModel = predictor;
    if (predictor is CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator> calibrated)
        treeModel = calibrated.SubModel;

    _ensemble = treeModel as TreeEnsembleModelParameters;
    _host.Check(_ensemble != null, "Predictor in model file does not have compatible type");

    _totalLeafCount = CountLeaves(_ensemble);
}
/// <summary>
/// Creates the mapper from a saved model.
/// </summary>
/// <remarks>
/// The version info allows reading back to 0x00010002, but the three output column names were
/// only added to the format in 0x00010003. They are therefore read only when the model was
/// written by a version that stored them; older models fall back to fixed column names.
/// </remarks>
/// <param name="env">The host environment.</param>
/// <param name="ctx">The load context positioned at this model.</param>
public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, ModelLoadContext ctx)
{
    // First format version that stores the output column names.
    const uint verOutputColumnNamesSaved = 0x00010003;

    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(LoaderSignature);
    _host.AssertValue(ctx);
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // ensemble
    // string: treesColumnName  (version >= 0x00010003 only)
    // string: leavesColumnName (version >= 0x00010003 only)
    // string: pathsColumnName  (version >= 0x00010003 only)
    ctx.LoadModel<TreeEnsembleModelParameters, SignatureLoadModel>(env, out _ensemble, "Ensemble");
    _totalLeafCount = CountLeaves(_ensemble);

    if (ctx.Header.ModelVerWritten >= verOutputColumnNamesSaved)
    {
        _treesColumnName = ctx.LoadStringOrNull();
        _leavesColumnName = ctx.LoadStringOrNull();
        _pathsColumnName = ctx.LoadStringOrNull();
    }
    else
    {
        // Older models carry no column names; reading them would run past the saved data.
        // Use the same fixed names the transform factories in this file pass to the mapper.
        // NOTE(review): confirm these match the names produced by pre-0x00010003 writers.
        _treesColumnName = "Trees";
        _leavesColumnName = "Leaves";
        _pathsColumnName = "Paths";
    }
}
/// <summary>
/// Saves the mapper: the ensemble as a sub-model named "Ensemble", followed by the three
/// output column names as nullable strings.
/// </summary>
void ICanSaveModel.Save(ModelSaveContext ctx)
{
    _host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel();
    ctx.SetVersionInfo(GetVersionInfo());

    // *** Binary format ***
    // ensemble
    // string: treesColumnName
    // string: leavesColumnName
    // string: pathsColumnName
    _host.AssertValue(_ensemble);
    ctx.SaveModel(_ensemble, "Ensemble");

    // The order must match the order in which the loading constructor reads the names.
    foreach (string columnName in new[] { _treesColumnName, _leavesColumnName, _pathsColumnName })
        ctx.SaveStringOrNull(columnName);
}
/// <summary>
/// Adds up the number of leaves of every tree in <paramref name="ensemble"/>.
/// </summary>
private static int CountLeaves(TreeEnsembleModelParameters ensemble)
{
    Contracts.AssertValue(ensemble);

    int leafTotal = 0;
    foreach (var tree in ((ITreeEnsemble)ensemble).GetTrees())
        leafTotal += tree.NumLeaves;
    return leafTotal;
}
/// <summary>
/// Writes the slot names of the tree value column: "Tree000", "Tree001", ... - one per tree.
/// </summary>
private void GetTreeSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
{
    int treeCount = _ensemble.TrainedEnsemble.NumTrees;
    var names = VBufferEditor.Create(ref dst, treeCount);

    int tree = 0;
    while (tree < treeCount)
    {
        names.Values[tree] = string.Format("Tree{0:000}", tree).AsMemory();
        tree++;
    }

    dst = names.Commit();
}
/// <summary>
/// Writes the slot names of the leaf indicator column. Slots are laid out tree by tree and
/// named "Tree{t:000}Leaf{l:000}" from the tree index and the leaf index within that tree.
/// </summary>
private void GetLeafSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
{
    var editor = VBufferEditor.Create(ref dst, _totalLeafCount);

    int slot = 0;
    int treeIndex = 0;
    foreach (var tree in ((ITreeEnsemble)_ensemble).GetTrees())
    {
        int leafCount = tree.NumLeaves;
        for (int leaf = 0; leaf < leafCount; leaf++)
            editor.Values[slot++] = string.Format("Tree{0:000}Leaf{1:000}", treeIndex, leaf).AsMemory();
        treeIndex++;
    }

    // Every slot must have received a name.
    _host.Assert(slot == _totalLeafCount);
    dst = editor.Commit();
}
/// <summary>
/// Writes the slot names of the path indicator column. Each tree contributes (NumLeaves - 1)
/// slots, named "Tree{t:000}Node{n:000}" from the tree index and the node index within that tree.
/// </summary>
private void GetPathSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
{
    var numTrees = _ensemble.TrainedEnsemble.NumTrees;
    // Total slot count: one per internal node, i.e. #leaves - #trees.
    var totalNodeCount = _totalLeafCount - numTrees;
    var editor = VBufferEditor.Create(ref dst, totalNodeCount);

    int slot = 0;
    int treeIndex = 0;
    foreach (var tree in ((ITreeEnsemble)_ensemble).GetTrees())
    {
        int nodeCount = tree.NumLeaves - 1;
        for (int node = 0; node < nodeCount; node++)
            editor.Values[slot++] = string.Format("Tree{0:000}Node{1:000}", treeIndex, node).AsMemory();
        treeIndex++;
    }

    // Every slot must have received a name.
    _host.Assert(slot == totalNodeCount);
    dst = editor.Commit();
}
/// <summary>
/// Binds this mapper to <paramref name="schema"/>, which must provide a feature column, and
/// returns a bound mapper configured with the stored output column names.
/// </summary>
ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
{
    Contracts.AssertValue(env);
    env.AssertValue(schema);

    bool hasFeatureColumn = schema.Feature != null;
    env.CheckParam(hasFeatureColumn, nameof(schema), "Need a feature column");

    var bound = new BoundMapper(env, this, schema, _treesColumnName, _leavesColumnName, _pathsColumnName);
    return bound;
}
}
/// <include file='doc.xml' path='doc/members/member[@name="TreeEnsembleFeaturizerTransform"]'/>
[BestFriend]
internal static class TreeEnsembleFeaturizerTransform
{
#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
/// <summary>
/// Arguments of the tree ensemble featurization transform when created through the
/// SignatureDataTransform factory. At least one of <see cref="Trainer"/> and
/// <see cref="TrainedModelFile"/> has to be given; when both are, the model file is used.
/// </summary>
public sealed class Arguments : TrainAndScoreTransformer.ArgumentsBase
{
/// <summary>Factory of the trainer used to train the ensemble when no model file is given.</summary>
[Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr", NullName = "<None>", SortOrder = 1, SignatureType = typeof(SignatureTreeEnsembleTrainer))]
public IComponentFactory<ITrainer> Trainer;
/// <summary>Path of a model file from which the predictor is loaded.</summary>
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Predictor model file used in scoring",
ShortName = "in", SortOrder = 2)]
public string TrainedModelFile;
/// <summary>Suffix forwarded to the scorer arguments.</summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Output column: The suffix to append to the default column names",
ShortName = "ex", SortOrder = 101)]
public string Suffix;
/// <summary>
/// Seed for permuting key-typed labels before training; 0 maps key values to their
/// zero-based index without permutation.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "If specified, determines the permutation seed for applying this featurizer to a multiclass problem.",
ShortName = "lps", SortOrder = 102)]
public int LabelPermutationSeed;
}
/// <summary>
/// Arguments of the tree ensemble featurization transform when invoked as an entry point.
/// Unlike <see cref="Arguments"/>, the trained model is supplied directly as a
/// <see cref="PredictorModel"/> instead of a trainer or a model file.
/// REVIEW: Ideally we should have only one arguments class by using IComponentFactory for the model.
/// For now it probably warrants a REVIEW comment here in case we'd like to merge these two arguments in the future.
/// Also, it might be worthwhile to extract the common arguments to a base class.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.IFeaturizerInput))]
public sealed class ArgumentsForEntryPoint : TransformInputBase
{
/// <summary>Suffix forwarded to the scorer arguments.</summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Output column: The suffix to append to the default column names",
ShortName = "ex", SortOrder = 101)]
public string Suffix;
// NOTE(review): CreateForEntryPoint in this file never reads this field - confirm whether it is used elsewhere.
[Argument(ArgumentType.AtMostOnce, HelpText = "If specified, determines the permutation seed for applying this featurizer to a multiclass problem.",
ShortName = "lps", SortOrder = 102)]
public int LabelPermutationSeed;
/// <summary>The trained predictor model to featurize with; required.</summary>
[Argument(ArgumentType.Required, HelpText = "Trainer to use", SortOrder = 10, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public PredictorModel PredictorModel;
}
#pragma warning restore CS0649
/// <summary>
/// User-facing description of the transform; used for the loadable-class registration and as
/// the entry point description.
/// </summary>
// NOTE(review): the text says a default FastTree model is trained when neither a model file nor a
// trainer is specified, but the SignatureDataTransform factory in this file rejects that case with
// a user-argument check - confirm which behavior is intended. The text also contains the typo
// "If a both a model file"; it is left untouched here because the string is emitted at runtime.
internal const string TreeEnsembleSummary =
"Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector " +
"to three outputs: 1. A vector containing the individual tree outputs of the tree ensemble. " +
"2. A vector indicating the leaves that the feature vector falls on in the tree ensemble. " +
"3. A vector indicating the paths that the feature vector falls on in the tree ensemble. " +
"If a both a model file and a trainer are specified - will use the model file. If neither are specified, " +
"will train a default FastTree model. " +
"This can handle key labels by training a regression model towards their optionally permuted indices.";
/// <summary>Display name of the transform.</summary>
internal const string UserName = "Tree Ensemble Featurization Transform";
// Factory method for SignatureDataScorer: wraps the bound mapper in a GenericScorer over the data.
private static IDataScorerTransform Create(IHostEnvironment env,
    TreeEnsembleFeaturizerBindableMapper.Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
    => new GenericScorer(env, args, data, mapper, trainSchema);
// Factory method for SignatureBindableMapper: validates the arguments and builds the
// bindable mapper around the given predictor.
private static ISchemaBindableMapper Create(IHostEnvironment env,
    TreeEnsembleFeaturizerBindableMapper.Arguments args, IPredictor predictor)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(args, nameof(args));
    env.CheckValue(predictor, nameof(predictor));

    var bindable = new TreeEnsembleFeaturizerBindableMapper(env, args, predictor);
    return bindable;
}
// Factory method for SignatureLoadModel: restores the bindable mapper from a saved model.
private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx)
    => new TreeEnsembleFeaturizerBindableMapper(env, ctx);
// Factory method for SignatureDataTransform.
// Builds the featurizing transform either from a predictor loaded from args.TrainedModelFile
// (preferred when both are given) or by training with args.Trainer. Output columns are always
// named "Trees", "Leaves" and "Paths".
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("Tree Featurizer Transform");
host.CheckValue(args, nameof(args));
host.CheckValue(input, nameof(input));
// At least one source for the ensemble is required.
host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null, nameof(args.TrainedModelFile),
"Please specify either a trainer or an input model file.");
host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");
IDataTransform xf;
using (var ch = host.Start("Create Tree Ensemble Scorer"))
{
// Shared by both branches; the lambdas in the training branch capture it.
var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
{
Suffix = args.Suffix,
TreesColumnName = "Trees",
LeavesColumnName = "Leaves",
PathsColumnName = "Paths"
};
if (!string.IsNullOrWhiteSpace(args.TrainedModelFile))
{
// Branch 1: load the predictor from the model file and score with it directly.
if (args.Trainer != null)
ch.Warning("Both an input model and a trainer were specified. Using the model file.");
ch.Trace("Loading model");
IPredictor predictor;
using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read, FileShare.Read))
using (var rep = RepositoryReader.Open(strm, ch))
ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out predictor, rep, ModelFileUtils.DirPredictor);
ch.Trace("Creating scorer");
var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
Contracts.Assert(data.Schema.Feature.HasValue);
// Make sure that the given predictor has the correct number of input features.
// A calibrated predictor is first unwrapped to its sub-model.
if (predictor is IWeaklyTypedCalibratedModelParameters calibrated)
predictor = calibrated.WeaklyTypedSubModel;
// Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
// be non-null.
var vm = predictor as IValueMapper;
ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");
if (vm.InputType.GetVectorSize() != data.Schema.Feature.Value.Type.GetVectorSize())
{
throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
"Predictor in model file expects {0} features, but data has {1} features",
vm.InputType.GetVectorSize(), data.Schema.Feature.Value.Type.GetVectorSize());
}
ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
var bound = bindable.Bind(env, data.Schema);
xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema);
}
else
{
// Branch 2: no model file, so train with the given trainer and score with the result.
ch.AssertValue(args.Trainer);
ch.Trace("Creating TrainAndScoreTransform");
var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
args.CopyTo(trainScoreArgs);
trainScoreArgs.Trainer = args.Trainer;
// Scorer and mapper factories route back to the featurizer's own scorer and bindable mapper.
trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
(e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema));
var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
(e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor));
// Key-typed labels are converted to floats for training; other label types pass through unchanged.
var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
var scoreXf = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);
// No label conversion was added: the trained transform already sits on the original input.
if (input == labelInput)
return scoreXf;
// Otherwise hand the trained transform, the original input and the label-converted view to
// ApplyAllTransformsToData and return its result.
return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
}
}
return xf;
}
/// <summary>
/// Creates the featurizing transform for the entry point from an already trained
/// <see cref="PredictorModel"/>. Output columns are named "Trees", "Leaves" and "Paths".
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="args">Entry point arguments; <c>PredictorModel</c> is required.</param>
/// <param name="input">The data to featurize.</param>
/// <returns>A scorer transform producing the tree-based feature columns.</returns>
public static IDataTransform CreateForEntryPoint(IHostEnvironment env, ArgumentsForEntryPoint args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Tree Featurizer Transform");
    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(args.PredictorModel != null, nameof(args.PredictorModel), "Please specify a predictor model.");

    using (var ch = host.Start("Create Tree Ensemble Scorer"))
    {
        var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
        {
            Suffix = args.Suffix,
            TreesColumnName = "Trees",
            LeavesColumnName = "Leaves",
            PathsColumnName = "Paths"
        };

        var predictor = args.PredictorModel.Predictor;
        ch.Trace("Prepare data");
        args.PredictorModel.PrepareData(env, input, out RoleMappedData data, out var predictor2);
        ch.AssertValue(data);
        ch.Assert(data.Schema.Feature.HasValue);
        ch.Assert(predictor == predictor2);

        // Make sure that the given predictor has the correct number of input features.
        // Unwrap calibrated predictors through the weakly typed interface, as the
        // SignatureDataTransform factory does. A test against the closed generic
        // CalibratedModelParametersBase<IPredictorProducing<float>, ICalibrator> does not match
        // instances closed over more specific type arguments, because generic classes are invariant.
        if (predictor is IWeaklyTypedCalibratedModelParameters calibrated)
            predictor = calibrated.WeaklyTypedSubModel;

        // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
        // be non-null.
        var vm = predictor as IValueMapper;
        ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type");

        // data is dereferenced unconditionally below, so no null guard is needed here.
        int expectedFeatureCount = vm.InputType.GetVectorSize();
        int actualFeatureCount = data.Schema.Feature.Value.Type.GetVectorSize();
        if (expectedFeatureCount != actualFeatureCount)
        {
            throw ch.ExceptUserArg(nameof(args.PredictorModel),
                "Predictor expects {0} features, but data has {1} features",
                expectedFeatureCount, actualFeatureCount);
        }

        ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
        var bound = bindable.Bind(env, data.Schema);
        return new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema);
    }
}
/// <summary>
/// Replaces the key-typed column <paramref name="col"/> with a <see cref="Single"/> column.
/// Missing keys map to NaN. With <paramref name="seed"/> equal to 0 a key maps to its value
/// minus one; otherwise it maps to the entry at that position of a seeded random permutation.
/// </summary>
/// <typeparam name="TInput">The raw type of the key column.</typeparam>
private static IDataView AppendFloatMapper<TInput>(IHostEnvironment env, IChannel ch, IDataView input,
    string col, KeyDataViewType type, int seed)
{
    // Any key is convertible to ulong, so instead of special-casing every possible key type
    // the value is first converted to the most general one (ulong) and handled from there.
    var ulongKeyType = new KeyDataViewType(typeof(ulong), type.Count);
    var converter = Conversions.DefaultInstance.GetStandardConversion<TInput, ulong>(type, ulongKeyType, out bool _);
    var isNa = Conversions.DefaultInstance.GetIsNAPredicate<TInput>(type);

    ValueMapper<TInput, Single> mapper;
    if (seed != 0)
    {
        ch.Check(type.Count > 0, "Label must be of known cardinality.");
        int[] permutation = Utils.GetRandomPermutation(RandomUtils.Create(seed), type.GetCountAsInt32(env));
        mapper =
            (in TInput src, ref Single dst) =>
            {
                if (isNa(in src))
                {
                    dst = Single.NaN;
                    return;
                }

                // Attention: this delegate is called from multiple threads. The scratch value
                // must stay local to the delegate body; hoisting it out would share it between
                // threads and introduce a race condition.
                ulong key = 0;
                converter(in src, ref key);
                dst = (Single)permutation[(int)(key - 1)];
            };
    }
    else
    {
        mapper =
            (in TInput src, ref Single dst) =>
            {
                if (isNa(in src))
                {
                    dst = Single.NaN;
                    return;
                }

                // Attention: this delegate is called from multiple threads. The scratch value
                // must stay local to the delegate body; hoisting it out would share it between
                // threads and introduce a race condition.
                ulong key = 0;
                converter(in src, ref key);
                dst = (Single)key - 1;
            };
    }

    return LambdaColumnMapper.Create(env, "Key to Float Mapper", input, col, col, type, NumberDataViewType.Single, mapper);
}
/// <summary>
/// Prepares the label column for training. A key-typed label is converted to a float column
/// (optionally permuted by <paramref name="labelPermutationSeed"/>); any other label type
/// leaves <paramref name="input"/> unchanged, with a warning if a non-zero seed was given.
/// </summary>
/// <returns>Either <paramref name="input"/> itself or a view with the converted label.</returns>
private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch, IDataView input, string labelName, int labelPermutationSeed)
{
    Contracts.AssertValue(env);
    env.AssertValue(ch);
    ch.AssertValue(input);
    ch.AssertNonWhiteSpace(labelName);

    var labelColumn = input.Schema.GetColumnOrNull(labelName);
    if (!labelColumn.HasValue)
        throw ch.ExceptSchemaMismatch(nameof(input), "label", labelName);

    DataViewType labelType = labelColumn.Value.Type;
    if (labelType is KeyDataViewType keyType)
    {
        // Dispatch to the generic mapper instantiated for the key's raw type.
        return Utils.MarshalInvoke(AppendFloatMapper<int>, labelType.RawType, env, ch, input, labelName, keyType,
            labelPermutationSeed);
    }

    if (labelPermutationSeed != 0)
    {
        ch.Warning(
            "labelPermutationSeed != 0 only applies on a multi-class learning problem when the label type is a key.");
    }
    return input;
}
}
internal static partial class TreeFeaturize
{
#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
    /// <summary>
    /// Entry point "Transforms.TreeLeafFeaturizer": applies the tree ensemble featurizer to
    /// <c>input.Data</c> and returns both the transform model and the transformed data.
    /// </summary>
    [TlcModule.EntryPoint(Name = "Transforms.TreeLeafFeaturizer",
        Desc = TreeEnsembleFeaturizerTransform.TreeEnsembleSummary,
        UserName = TreeEnsembleFeaturizerTransform.UserName,
        ShortName = TreeEnsembleFeaturizerBindableMapper.LoadNameShort)]
    public static CommonOutputs.TransformOutput Featurizer(IHostEnvironment env, TreeEnsembleFeaturizerTransform.ArgumentsForEntryPoint input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register("TreeFeaturizerTransform");
        host.CheckValue(input, nameof(input));
        EntryPointUtils.CheckInputArgs(host, input);

        var transform = TreeEnsembleFeaturizerTransform.CreateForEntryPoint(env, input, input.Data);
        var model = new TransformModelImpl(env, transform, input.Data);
        return new CommonOutputs.TransformOutput { Model = model, OutputData = transform };
    }
#pragma warning restore CS0649
}
}
|