// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(typeof(GamModelParametersBase.VisualizationCommand), typeof(GamModelParametersBase.VisualizationCommand.Arguments), typeof(SignatureCommand),
"GAM Visualization Command", GamModelParametersBase.VisualizationCommand.LoadName, "gamviz", DocName = "command/GamViz.md")]
namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// The base class for GAM Model Parameters.
/// </summary>
public abstract class GamModelParametersBase : ModelParametersBase<float>, IValueMapper, ICalculateFeatureContribution,
IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat
{
/// <summary>
/// The model intercept. Also known as bias or mean effect.
/// </summary>
public readonly double Bias;
/// <summary>
/// The number of shape functions used in the model.
/// </summary>
public readonly int NumberOfShapeFunctions;
private readonly double[][] _binUpperBounds;
private readonly double[][] _binEffects;
private readonly VectorDataViewType _inputType;
private readonly DataViewType _outputType;
// These would be the bins for a totally sparse input.
private readonly int[] _binsAtAllZero;
// The output value for all zeros
private readonly double _valueAtAllZero;
private readonly int[] _shapeToInputMap;
private readonly int _numInputFeatures;
private readonly Dictionary<int, int> _inputFeatureToShapeFunctionMap;
DataViewType IValueMapper.InputType => _inputType;
DataViewType IValueMapper.OutputType => _outputType;
/// <summary>
/// Used to determine the contribution of each feature to the score of an example by <see cref="FeatureContributionCalculatingTransformer"/>.
/// For Generalized Additive Models (GAM), the contribution of a feature is equal to the shape function for the given feature evaluated at
/// the feature value.
/// </summary>
FeatureContributionCalculator ICalculateFeatureContribution.FeatureContributionCalculator => new FeatureContributionCalculator(this);
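// A GAM score is additive: Map below starts from Bias and, for each input feature that has a shape function,
// adds the effect of the bin that the feature value falls into, i.e. score = Bias + sum over j of _binEffects[j][bin(x_j)].
// As an illustration only (hypothetical values), a shape function with _binUpperBounds[j] = { 1.5, 4.0, double.PositiveInfinity }
// and _binEffects[j] = { -0.2, 0.1, 0.3 } contributes -0.2 for x <= 1.5, 0.1 for 1.5 < x <= 4.0, and 0.3 above that.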
private protected GamModelParametersBase(IHostEnvironment env, string name,
double[][] binUpperBounds, double[][] binEffects, double intercept, int numInputFeatures = -1, int[] shapeToInputMap = null)
: base(env, name)
{
Host.CheckValue(binEffects, nameof(binEffects), "May not be null.");
Host.CheckValue(binUpperBounds, nameof(binUpperBounds), "May not be null.");
Host.CheckParam(binUpperBounds.Length == binEffects.Length, nameof(binUpperBounds), "Must have same number of features as binEffects");
Host.CheckParam(binEffects.Length > 0, nameof(binEffects), "Must have at least one entry");
Host.CheckParam(numInputFeatures == -1 || numInputFeatures > 0, nameof(numInputFeatures), "Must be greater than zero");
Host.CheckParam(shapeToInputMap == null || shapeToInputMap.Length == binEffects.Length, nameof(shapeToInputMap), "Must have same number of features as binEffects");
// Define the model basics
Bias = intercept;
_binUpperBounds = binUpperBounds;
_binEffects = binEffects;
NumberOfShapeFunctions = binEffects.Length;
// For sparse inputs we have a fast lookup
_binsAtAllZero = new int[NumberOfShapeFunctions];
_valueAtAllZero = 0;
// Walk through each feature and perform checks / updates
for (int i = 0; i < NumberOfShapeFunctions; i++)
{
// Check data validity
Host.CheckValue(binEffects[i], nameof(binEffects), "Array contained null entries");
Host.CheckParam(binUpperBounds[i].Length == binEffects[i].Length, nameof(binEffects), "Array contained wrong number of effect values");
Host.CheckParam(Utils.IsMonotonicallyIncreasing(binUpperBounds[i]), nameof(binUpperBounds), "Array must be monotonically increasing");
// Update the value at zero
_valueAtAllZero += GetBinEffect(i, 0, out _binsAtAllZero[i]);
}
// Define the sparse mappings from/to input to/from shape functions
_shapeToInputMap = shapeToInputMap;
if (_shapeToInputMap == null)
_shapeToInputMap = Utils.GetIdentityPermutation(NumberOfShapeFunctions);
_numInputFeatures = numInputFeatures;
if (_numInputFeatures == -1)
_numInputFeatures = NumberOfShapeFunctions;
_inputFeatureToShapeFunctionMap = new Dictionary<int, int>(_shapeToInputMap.Length);
for (int i = 0; i < _shapeToInputMap.Length; i++)
{
Host.CheckParam(0 <= _shapeToInputMap[i] && _shapeToInputMap[i] < _numInputFeatures, nameof(shapeToInputMap), "Contains out of range feature value");
Host.CheckParam(!_inputFeatureToShapeFunctionMap.ContainsKey(_shapeToInputMap[i]), nameof(shapeToInputMap), "Contains duplicate mappings");
_inputFeatureToShapeFunctionMap[_shapeToInputMap[i]] = i;
}
_inputType = new VectorDataViewType(NumberDataViewType.Single, _numInputFeatures);
_outputType = NumberDataViewType.Single;
}
private protected GamModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
: base(env, name)
{
Host.CheckValue(ctx, nameof(ctx));
BinaryReader reader = ctx.Reader;
NumberOfShapeFunctions = reader.ReadInt32();
Host.CheckDecode(NumberOfShapeFunctions >= 0);
_numInputFeatures = reader.ReadInt32();
Host.CheckDecode(_numInputFeatures >= 0);
Bias = reader.ReadDouble();
if (ctx.Header.ModelVerWritten == 0x00010001)
using (var ch = env.Start("GamWarningChannel"))
ch.Warning("GAMs models written prior to ML.NET 0.6 are loaded with an incorrect Intercept. For these models, subtract the value of the intercept from the prediction.");
_binEffects = new double[NumberOfShapeFunctions][];
_binUpperBounds = new double[NumberOfShapeFunctions][];
_binsAtAllZero = new int[NumberOfShapeFunctions];
for (int i = 0; i < NumberOfShapeFunctions; i++)
{
_binEffects[i] = reader.ReadDoubleArray();
Host.CheckDecode(Utils.Size(_binEffects[i]) >= 1);
}
for (int i = 0; i < NumberOfShapeFunctions; i++)
{
_binUpperBounds[i] = reader.ReadDoubleArray(_binEffects[i].Length);
_valueAtAllZero += GetBinEffect(i, 0, out _binsAtAllZero[i]);
}
int len = reader.ReadInt32();
Host.CheckDecode(len >= 0);
_inputFeatureToShapeFunctionMap = new Dictionary<int, int>(len);
_shapeToInputMap = Utils.CreateArray(NumberOfShapeFunctions, -1);
for (int i = 0; i < len; i++)
{
int key = reader.ReadInt32();
Host.CheckDecode(0 <= key && key < _numInputFeatures);
int val = reader.ReadInt32();
Host.CheckDecode(0 <= val && val < NumberOfShapeFunctions);
Host.CheckDecode(!_inputFeatureToShapeFunctionMap.ContainsKey(key));
Host.CheckDecode(_shapeToInputMap[val] == -1);
_inputFeatureToShapeFunctionMap[key] = val;
_shapeToInputMap[val] = key;
}
_inputType = new VectorDataViewType(NumberDataViewType.Single, _numInputFeatures);
_outputType = NumberDataViewType.Single;
}
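// *** Binary format (as read by the constructor above and written by SaveCore below) ***
// int: number of shape functions
// int: number of input features
// double: bias
// for each shape function: double array of bin effects (length-prefixed)
// for each shape function: double array of bin upper bounds (no length prefix; same length as the effects array)
// int: number of entries in the input-feature-to-shape-function map
// for each map entry: int input feature index, int shape function index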
private protected override void SaveCore(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.Writer.Write(NumberOfShapeFunctions);
Host.Assert(NumberOfShapeFunctions >= 0);
ctx.Writer.Write(_numInputFeatures);
Host.Assert(_numInputFeatures >= 0);
ctx.Writer.Write(Bias);
for (int i = 0; i < NumberOfShapeFunctions; i++)
ctx.Writer.WriteDoubleArray(_binEffects[i]);
int diff = _binEffects.Sum(e => e.Take(e.Length - 1).Select((ef, i) => ef != e[i + 1] ? 1 : 0).Sum());
int bound = _binEffects.Sum(e => e.Length - 1);
for (int i = 0; i < NumberOfShapeFunctions; i++)
{
ctx.Writer.WriteDoublesNoCount(_binUpperBounds[i]);
Host.Assert(_binUpperBounds[i].Length == _binEffects[i].Length);
}
ctx.Writer.Write(_inputFeatureToShapeFunctionMap.Count);
foreach (KeyValuePair<int, int> kvp in _inputFeatureToShapeFunctionMap)
{
ctx.Writer.Write(kvp.Key);
ctx.Writer.Write(kvp.Value);
}
}
ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
{
Host.Check(typeof(TIn) == typeof(VBuffer<float>), "Input type does not match.");
Host.Check(typeof(TOut) == typeof(float), "Output type does not match.");
ValueMapper<VBuffer<float>, float> del = Map;
return (ValueMapper<TIn, TOut>)(Delegate)del;
}
private void Map(in VBuffer<float> features, ref float response)
{
Host.CheckParam(features.Length == _numInputFeatures, nameof(features), "Bad length of input");
double value = Bias;
var featuresValues = features.GetValues();
if (features.IsDense)
{
for (int i = 0; i < featuresValues.Length; ++i)
{
if (_inputFeatureToShapeFunctionMap.TryGetValue(i, out int j))
value += GetBinEffect(j, featuresValues[i]);
}
}
else
{
var featuresIndices = features.GetIndices();
// Add in the precomputed score for an all-zero input
value += _valueAtAllZero;
for (int i = 0; i < featuresValues.Length; ++i)
{
if (_inputFeatureToShapeFunctionMap.TryGetValue(featuresIndices[i], out int j))
// Add the value and subtract the value at zero that was previously accounted for
value += GetBinEffect(j, featuresValues[i]) - GetBinEffect(j, 0);
}
}
response = (float)value;
}
internal double GetFeatureBinsAndScore(in VBuffer<float> features, int[] bins)
{
Host.CheckParam(features.Length == _numInputFeatures, nameof(features));
Host.CheckParam(Utils.Size(bins) == NumberOfShapeFunctions, nameof(bins));
double value = Bias;
var featuresValues = features.GetValues();
if (features.IsDense)
{
for (int i = 0; i < featuresValues.Length; ++i)
{
if (_inputFeatureToShapeFunctionMap.TryGetValue(i, out int j))
value += GetBinEffect(j, featuresValues[i], out bins[j]);
}
}
else
{
var featuresIndices = features.GetIndices();
// Add in the precomputed score for an all-zero input
value += _valueAtAllZero;
Array.Copy(_binsAtAllZero, bins, NumberOfShapeFunctions);
// Update the results for features we have
for (int i = 0; i < featuresValues.Length; ++i)
{
if (_inputFeatureToShapeFunctionMap.TryGetValue(featuresIndices[i], out int j))
// Add the value and subtract the value at zero that was previously accounted for
value += GetBinEffect(j, featuresValues[i], out bins[j]) - GetBinEffect(j, 0);
}
}
return value;
}
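// The two GetBinEffect overloads below locate the bin for a feature value as the first bin whose upper bound is
// greater than or equal to the value (via Algorithms.FindFirstGE over the monotonically increasing upper bounds)
// and return that bin's effect. As a sketch with hypothetical values: for _binUpperBounds[j] = { 1.0, 3.0, 5.0 },
// a feature value of 2.2 maps to bin 1, and a value of exactly 1.0 maps to bin 0, since upper bounds are inclusive.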
private double GetBinEffect(int featureIndex, double featureValue)
{
Host.Assert(0 <= featureIndex && featureIndex < NumberOfShapeFunctions, "Index out of range.");
int index = Algorithms.FindFirstGE(_binUpperBounds[featureIndex], featureValue);
return _binEffects[featureIndex][index];
}
private double GetBinEffect(int featureIndex, double featureValue, out int binIndex)
{
Host.Check(0 <= featureIndex && featureIndex < NumberOfShapeFunctions, "Index out of range.");
binIndex = Algorithms.FindFirstGE(_binUpperBounds[featureIndex], featureValue);
return _binEffects[featureIndex][binIndex];
}
/// <summary>
/// Get the bin upper bounds for the specified feature.
/// </summary>
/// <param name="featureIndex">The index of the feature (in the training vector) to get.</param>
/// <returns>The bin upper bounds. May be zero length if this feature has no bins.</returns>
public IReadOnlyList<double> GetBinUpperBounds(int featureIndex)
{
Host.Check(0 <= featureIndex && featureIndex < NumberOfShapeFunctions, "Index out of range.");
if (!_inputFeatureToShapeFunctionMap.TryGetValue(featureIndex, out int j))
return new double[0];
var binUpperBounds = new double[_binUpperBounds[j].Length];
_binUpperBounds[j].CopyTo(binUpperBounds, 0);
return binUpperBounds;
}
/// <summary>
/// Get all the bin upper bounds.
/// </summary>
[BestFriend]
internal double[][] GetBinUpperBounds()
{
double[][] binUpperBounds = new double[NumberOfShapeFunctions][];
for (int i = 0; i < NumberOfShapeFunctions; i++)
{
if (_inputFeatureToShapeFunctionMap.TryGetValue(i, out int j))
{
binUpperBounds[i] = new double[_binUpperBounds[j].Length];
_binUpperBounds[j].CopyTo(binUpperBounds[i], 0);
}
else
{
binUpperBounds[i] = new double[0];
}
}
return binUpperBounds;
}
/// <summary>
/// Get the binned effects for the specified feature.
/// </summary>
/// <param name="featureIndex">The index of the feature (in the training vector) to get.</param>
/// <returns>The binned effects for the feature. May be zero length if this feature has no bins.</returns>
public IReadOnlyList<double> GetBinEffects(int featureIndex)
{
Host.Check(0 <= featureIndex && featureIndex < NumberOfShapeFunctions, "Index out of range.");
if (!_inputFeatureToShapeFunctionMap.TryGetValue(featureIndex, out int j))
return new double[0];
var binEffects = new double[_binEffects[j].Length];
_binEffects[j].CopyTo(binEffects, 0);
return binEffects;
}
/// <summary>
/// Get all the binned effects.
/// </summary>
[BestFriend]
internal double[][] GetBinEffects()
{
double[][] binEffects = new double[NumberOfShapeFunctions][];
for (int i = 0; i < NumberOfShapeFunctions; i++)
{
if (_inputFeatureToShapeFunctionMap.TryGetValue(i, out int j))
{
binEffects[i] = new double[_binEffects[j].Length];
_binEffects[j].CopyTo(binEffects[i], 0);
}
else
{
binEffects[i] = new double[0];
}
}
return binEffects;
}
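// The text export below writes two tab-separated tables: a feature index table mapping feature indices to names
// (index -1 is reserved for the intercept), followed by the per-feature binned effects, one line per bin of the
// form "<feature index>\t<bin upper bound>\t<effect>". The intercept is emitted as a single pseudo-bin whose upper
// bound is float.MaxValue and whose effect is the model bias.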
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
{
Host.CheckValue(writer, nameof(writer), "writer must not be null.");
Host.CheckValueOrNull(schema);
writer.WriteLine("\xfeffFeature index table"); // add BOM to tell excel this is UTF-8
writer.WriteLine($"Number of features:\t{NumberOfShapeFunctions + 1:D}");
writer.WriteLine("Feature Index\tFeature Name");
// REVIEW: We really need some unit tests around text exporting (for this, and other learners).
// A useful test in this case would be a model trained with:
// maml.exe train data=Samples\breast-cancer-withheader.txt loader=text{header+ col=Label:0 col=F1:1-4 col=F2:4 col=F3:5-*}
// xf=expr{col=F2 expr=x:0.0} xf=concat{col=Features:F1,F2,F3} tr=gam out=bubba2.zip
// Write out the intercept
writer.WriteLine("-1\tIntercept");
var names = default(VBuffer<ReadOnlyMemory<char>>);
AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, _numInputFeatures, ref names);
for (int internalIndex = 0; internalIndex < NumberOfShapeFunctions; internalIndex++)
{
int featureIndex = _shapeToInputMap[internalIndex];
var name = names.GetItemOrDefault(featureIndex);
writer.WriteLine(!name.IsEmpty ? "{0}\t{1}" : "{0}\tFeature {0}", featureIndex, name);
}
writer.WriteLine();
writer.WriteLine("Per feature binned effects:");
writer.WriteLine("Feature Index\tFeature Value Bin Upper Bound\tOutput (effect on label)");
writer.WriteLine($"{-1:D}\t{float.MaxValue:R}\t{Bias:R}");
for (int internalIndex = 0; internalIndex < NumberOfShapeFunctions; internalIndex++)
{
int featureIndex = _shapeToInputMap[internalIndex];
double[] effects = _binEffects[internalIndex];
double[] boundaries = _binUpperBounds[internalIndex];
for (int i = 0; i < effects.Length; ++i)
writer.WriteLine($"{featureIndex:D}\t{boundaries[i]:R}\t{effects[i]:R}");
}
}
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
{
((ICanSaveInTextFormat)this).SaveAsText(writer, schema);
}
ValueMapper<TSrc, VBuffer<float>> IFeatureContributionMapper.GetFeatureContributionMapper<TSrc, TDstContributions>
(int top, int bottom, bool normalize)
{
Host.Check(typeof(TSrc) == typeof(VBuffer<float>), "Source type does not match.");
Host.Check(typeof(TDstContributions) == typeof(VBuffer<float>), "Destination type does not match.");
ValueMapper<VBuffer<float>, VBuffer<float>> del =
(in VBuffer<float> srcFeatures, ref VBuffer<float> dstContributions) =>
{
GetFeatureContributions(in srcFeatures, ref dstContributions, top, bottom, normalize);
};
return (ValueMapper<TSrc, VBuffer<float>>)(Delegate)del;
}
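// Contributions are computed densely (one value per input slot, equal to that feature's shape-function output)
// and then passed through VectorUtils.SparsifyNormalize, which is assumed here to retain only the requested
// top/bottom contributions and optionally normalize them before returning the result.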
private void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<float> contributions,
int top, int bottom, bool normalize)
{
var editor = VBufferEditor.Create(ref contributions, features.Length);
// We need to use the dense values of the features, because the feature contributions can be significant
// even for features with value 0.
var featureIndex = 0;
foreach (var featureValue in features.DenseValues())
{
float contribution = 0;
if (_inputFeatureToShapeFunctionMap.TryGetValue(featureIndex, out int j))
contribution = (float)GetBinEffect(j, featureValue);
editor.Values[featureIndex] = contribution;
featureIndex++;
}
contributions = editor.Commit();
Numeric.VectorUtils.SparsifyNormalize(ref contributions, top, bottom, normalize);
}
void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
{
Host.CheckValue(writer, nameof(writer), "writer must not be null");
var ensemble = new InternalTreeEnsemble();
for (int featureIndex = 0; featureIndex < NumberOfShapeFunctions; featureIndex++)
{
var effects = _binEffects[featureIndex];
var binThresholds = _binUpperBounds[featureIndex];
Host.Check(effects.Length == binThresholds.Length, "Effects array must be same length as binUpperBounds array.");
var numLeaves = effects.Length;
var numInternalNodes = numLeaves - 1;
var splitFeatures = Enumerable.Repeat(featureIndex, numInternalNodes).ToArray();
var (treeThresholds, lteChild, gtChild) = CreateBalancedTree(numInternalNodes, binThresholds);
var tree = CreateRegressionTree(numLeaves, splitFeatures, treeThresholds, lteChild, gtChild, effects);
ensemble.AddTree(tree);
}
// Adding the intercept as a dummy tree whose leaf outputs are all the model intercept
// works for reaching parity with the GAM score.
var interceptTree = CreateRegressionTree(
numLeaves: 2,
splitFeatures: new[] { 0 },
rawThresholds: new[] { 0f },
lteChild: new[] { ~0 },
gtChild: new[] { ~1 },
leafValues: new[] { Bias, Bias });
ensemble.AddTree(interceptTree);
var ini = FastTreeIniFileUtils.TreeEnsembleToIni(
Host, ensemble, schema, calibrator, string.Empty, false, false);
// Remove the SplitGain values, which are all 0.
// It's easier to remove them here than to modify the FastTree code.
var goodLines = ini.Split(new[] { '\n' }).Where(line => !line.StartsWith("SplitGain="));
ini = string.Join("\n", goodLines);
writer.WriteLine(ini);
}
// GAM bins should be converted to balanced trees / binary search trees
// so that scoring takes O(log(n)) instead of O(n). The following utility
// creates a balanced tree.
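// As an illustration (hypothetical thresholds): a shape function with four bins and binThresholds = { t0, t1, t2 },
// i.e. numInternalNodes = 3, yields a root that splits on t1, a left (<=) child that splits on t0, and a right (>)
// child that splits on t2. Leaves are encoded as the bitwise complement of the bin index (~0 .. ~3), the same
// encoding used for the intercept tree's leaves above.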
private (float[], int[], int[]) CreateBalancedTree(int numInternalNodes, double[] binThresholds)
{
var binIndices = Enumerable.Range(0, numInternalNodes).ToArray();
var internalNodeIndices = new List<int>();
var lteChild = new List<int>();
var gtChild = new List<int>();
var internalNodeId = numInternalNodes;
CreateBalancedTreeRecursive(
0, binIndices.Length - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
// internalNodeId should have been counted all the way down to 0 (root node)
Host.Assert(internalNodeId == 0);
var tree = (
thresholds: internalNodeIndices.Select(x => (float)binThresholds[binIndices[x]]).ToArray(),
lteChild: lteChild.ToArray(),
gtChild: gtChild.ToArray());
return tree;
}
private int CreateBalancedTreeRecursive(int lower, int upper,
List<int> internalNodeIndices, List<int> lteChild, List<int> gtChild, ref int internalNodeId)
{
if (lower > upper)
{
// Base case: we've reached a leaf node
Host.Assert(lower == upper + 1);
return ~lower;
}
else
{
// This is a postorder traversal that populates the internalNodeIndices/lteChild/gtChild lists in reverse.
// Postorder is the only option, because we need the results of both the left and right recursions before populating the lists.
// As a result, the lists are populated in reverse, because the root node should be the first item in the lists.
// A binary-search-tree construction (recursively splitting the range in half) is used to create the balanced tree.
var mid = (lower + upper) / 2;
var left = CreateBalancedTreeRecursive(
lower, mid - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
var right = CreateBalancedTreeRecursive(
mid + 1, upper, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
internalNodeIndices.Insert(0, mid);
lteChild.Insert(0, left);
gtChild.Insert(0, right);
return --internalNodeId;
}
}
private static InternalRegressionTree CreateRegressionTree(
int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues)
{
var numInternalNodes = numLeaves - 1;
return InternalRegressionTree.Create(
numLeaves: numLeaves,
splitFeatures: splitFeatures,
rawThresholds: rawThresholds,
lteChild: lteChild,
gtChild: gtChild.ToArray(),
leafValues: leafValues,
// Ignored arguments
splitGain: new double[numInternalNodes],
defaultValueForMissing: new float[numInternalNodes],
categoricalSplitFeatures: new int[numInternalNodes][],
categoricalSplit: new bool[numInternalNodes]);
}
/// <summary>
/// The GAM model visualization command. Because the data access commands must access private members of
/// <see cref="GamModelParametersBase"/>, it is convenient to have the command itself nested within the base
/// predictor class.
/// </summary>
internal sealed class VisualizationCommand : DataCommand.ImplBase<VisualizationCommand.Arguments>
{
public const string Summary = "Loads a model trained with a GAM learner, and starts an interactive web session to visualize it.";
public const string LoadName = "GamVisualization";
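// A sketch of a typical invocation from the maml.exe command line (hypothetical paths; the "gamviz" alias comes
// from the LoadableClass attribute at the top of this file, and the data/input-model argument names come from
// DataCommand.ArgumentsBase, so the exact spellings should be verified against the command-line help):
//   maml.exe gamviz data=breast-cancer.txt in=gamModel.zip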
public sealed class Arguments : DataCommand.ArgumentsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to open the GAM visualization page URL", ShortName = "o", SortOrder = 3)]
public bool Open = true;
internal Arguments SetServerIfNeeded(IHostEnvironment env)
{
// We assume that if someone invoked this, they really did mean to start the web server.
if (env != null && Server == null)
Server = ServerChannel.CreateDefaultServerFactoryOrNull(env);
return this;
}
}
private readonly string _inputModelPath;
private readonly bool _open;
public VisualizationCommand(IHostEnvironment env, Arguments args)
: base(env, args.SetServerIfNeeded(env), LoadName)
{
Host.CheckValue(args, nameof(args));
Host.CheckValue(args.Server, nameof(args.Server));
Host.CheckNonWhiteSpace(args.InputModelFile, nameof(args.InputModelFile));
_inputModelPath = args.InputModelFile;
_open = args.Open;
}
public override void Run()
{
using (var ch = Host.Start("Run"))
{
Run(ch);
}
}
private sealed class Context
{
private readonly GamModelParametersBase _pred;
private readonly RoleMappedData _data;
private readonly VBuffer<ReadOnlyMemory<char>> _featNames;
// The scores.
private readonly float[] _scores;
// The labels.
private readonly float[] _labels;
// For every feature, and for every bin, there is a list of documents with that feature.
private readonly List<int>[][] _binDocsList;
// Whenever the predictor is "modified," we increment this version. This value is returned for anything
// that is subject to change, and can be used by client web code to detect whenever something
// may have happened behind its back.
private long _version;
private long _saveVersion;
// Non-null if this object was created with an evaluator *and* the scores and labels are non-empty.
private readonly RoleMappedData _dataForEvaluator;
// Non-null in the same conditions that the above is non-null.
private readonly IEvaluator _eval;
// The map of categorical feature ranges, stored as flattened [start, end] index pairs (see AnnotationUtils.TryGetCategoricalFeatureIndices).
private readonly int[] _catsMap;
/// <summary>
/// This is the number of input features, as opposed to the number of features used within the GAM,
/// which may be lower.
/// </summary>
public int NumFeatures => _pred._inputType.Size;
public Context(IChannel ch, GamModelParametersBase pred, RoleMappedData data, IEvaluator eval)
{
Contracts.AssertValue(ch);
ch.AssertValue(pred);
ch.AssertValue(data);
ch.AssertValueOrNull(eval);
_saveVersion = -1;
_pred = pred;
_data = data;
var schema = _data.Schema;
var featCol = schema.Feature.Value;
int len = featCol.Type.GetValueCount();
ch.Check(len == _pred._numInputFeatures);
if (featCol.HasSlotNames(len))
featCol.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref _featNames);
else
_featNames = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(len);
var numFeatures = _pred._binEffects.Length;
_binDocsList = new List<int>[numFeatures][];
for (int f = 0; f < numFeatures; f++)
{
var binDocList = new List<int>[_pred._binEffects[f].Length];
for (int e = 0; e < _pred._binEffects[f].Length; e++)
binDocList[e] = new List<int>();
_binDocsList[f] = binDocList;
}
var labels = new List<float>();
var scores = new List<float>();
int[] bins = new int[numFeatures];
using (var cursor = new FloatLabelCursor(_data, CursOpt.Label | CursOpt.Features))
{
int doc = 0;
while (cursor.MoveNext())
{
labels.Add(cursor.Label);
var score = _pred.GetFeatureBinsAndScore(in cursor.Features, bins);
scores.Add((float)score);
for (int f = 0; f < numFeatures; f++)
_binDocsList[f][bins[f]].Add(doc);
++doc;
}
_labels = labels.ToArray();
labels = null;
_scores = scores.ToArray();
scores = null;
}
ch.Assert(_scores.Length == _labels.Length);
if (_labels.Length > 0 && eval != null)
{
_eval = eval;
var builder = new ArrayDataViewBuilder(pred.Host);
builder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, _labels);
builder.AddColumn(DefaultColumnNames.Score, NumberDataViewType.Single, _scores);
_dataForEvaluator = new RoleMappedData(builder.GetDataView(), opt: false,
RoleMappedSchema.ColumnRole.Label.Bind(DefaultColumnNames.Label),
new RoleMappedSchema.ColumnRole(AnnotationUtils.Const.ScoreValueKind.Score).Bind(DefaultColumnNames.Score));
}
var featureCol = _data.Schema.Schema[DefaultColumnNames.Features];
AnnotationUtils.TryGetCategoricalFeatureIndices(_data.Schema.Schema, featureCol.Index, out _catsMap);
}
public FeatureInfo GetInfoForIndex(int index) => FeatureInfo.GetInfoForIndex(this, index);
public IEnumerable<FeatureInfo> GetInfos() => FeatureInfo.GetInfos(this);
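// SetEffect edits a single bin's effect in place. Because _binDocsList records which documents fall into each
// bin, the cached scores can be updated incrementally by adding the effect delta to just those documents rather
// than rescoring the whole dataset. The returned version number lets the client detect the modification.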
public long SetEffect(int feat, int bin, double effect)
{
// Another version with multiple effects, perhaps?
int internalIndex;
if (!_pred._inputFeatureToShapeFunctionMap.TryGetValue(feat, out internalIndex))
return -1;
var effects = _pred._binEffects[internalIndex];
if (bin < 0 || bin >= effects.Length)
return -1;
lock (_pred)
{
var deltaEffect = effect - effects[bin];
effects[bin] = effect;
foreach (var docIndex in _binDocsList[internalIndex][bin])
_scores[docIndex] += (float)deltaEffect;
return checked(++_version);
}
}
public MetricsInfo GetMetrics()
{
if (_eval == null)
return null;
lock (_pred)
{
var metricDict = _eval.Evaluate(_dataForEvaluator);
IDataView metricsView;
if (!metricDict.TryGetValue(MetricKinds.OverallMetrics, out metricsView))
return null;
Contracts.AssertValue(metricsView);
return new MetricsInfo(_version, EvaluateUtils.GetMetrics(metricsView).ToArray());
}
}
/// <summary>
/// This will write out a model file, if needed. Whenever it returns a version number, the sign of that number
/// indicates whether anything was actually written in this call.
/// </summary>
/// <param name="host">The host from the command</param>
/// <param name="ch">The channel from the command</param>
/// <param name="outFile">The (optionally empty) output file</param>
/// <returns>Returns <c>null</c> if the model file could not be saved because <paramref name="outFile"/>
/// was <c>null</c> or whitespace. Otherwise, if the current version is newer than the last version saved,
/// it will save, and return that version. (In this case, the number is non-negative.) Otherwise, if the current
/// version was the last version saved, then it will return the bitwise not of that version number (in this case,
/// the number is negative).</returns>
public long? SaveIfNeeded(IHost host, IChannel ch, string outFile)
{
Contracts.AssertValue(ch);
ch.AssertValue(host);
ch.AssertValueOrNull(outFile);
if (string.IsNullOrWhiteSpace(outFile))
return null;
lock (_pred)
{
ch.Assert(_saveVersion <= _version);
if (_saveVersion == _version)
return ~_version;
// Note that this data pipe is the data pipe that was defined for the gam visualization
// command, which may not be quite the same thing as the data pipe in the original model,
// in the event that the user specified different loader settings, defined new transforms,
// etc.
using (var file = host.CreateOutputFile(outFile))
TrainUtils.SaveModel(host, ch, file, _pred, _data);
return _saveVersion = _version;
}
}
public sealed class MetricsInfo
{
public long Version { get; }
public KeyValuePair<string, double>[] Metrics { get; }
public MetricsInfo(long version, KeyValuePair<string, double>[] metrics)
{
Version = version;
Metrics = metrics;
}
}
public sealed class FeatureInfo
{
public int Index { get; }
public string Name { get; }
/// <summary>
/// The upper bounds of each bin.
/// </summary>
public IEnumerable<double> UpperBounds { get; }
/// <summary>
/// The amount added to the model for a document falling in a given bin.
/// </summary>
public IEnumerable<double> BinEffects { get; }
/// <summary>
/// The number of documents in each bin.
/// </summary>
public IEnumerable<int> DocCounts { get; }
/// <summary>
/// The version of the GAM context that has these values.
/// </summary>
public long Version { get; }
/// <summary>
/// For features belonging to the same categorical, this value will be the same.
/// Set to -1 for non-categorical features.
/// </summary>
public int CategoricalFeatureIndex { get; }
private FeatureInfo(Context context, int index, int internalIndex, int[] catsMap)
{
Contracts.AssertValue(context);
Contracts.Assert(context._pred._inputFeatureToShapeFunctionMap.ContainsKey(index)
&& context._pred._inputFeatureToShapeFunctionMap[index] == internalIndex);
Index = index;
var name = context._featNames.GetItemOrDefault(index).ToString();
Name = string.IsNullOrEmpty(name) ? $"f{index}" : name;
var up = context._pred._binUpperBounds[internalIndex];
UpperBounds = up.Take(up.Length - 1);
BinEffects = context._pred._binEffects[internalIndex];
DocCounts = context._binDocsList[internalIndex].Select(Utils.Size);
Version = context._version;
CategoricalFeatureIndex = -1;
if (catsMap != null && index <= catsMap[catsMap.Length - 1])
{
for (int i = 0; i < catsMap.Length; i += 2)
{
if (index >= catsMap[i] && index <= catsMap[i + 1])
{
CategoricalFeatureIndex = i;
break;
}
}
}
}
public static FeatureInfo GetInfoForIndex(Context context, int index)
{
Contracts.AssertValue(context);
Contracts.Assert(0 <= index && index < context._pred._inputType.Size);
lock (context._pred)
{
int internalIndex;
if (!context._pred._inputFeatureToShapeFunctionMap.TryGetValue(index, out internalIndex))
return null;
return new FeatureInfo(context, index, internalIndex, context._catsMap);
}
}
public static FeatureInfo[] GetInfos(Context context)
{
lock (context._pred)
{
return Utils.BuildArray(context._pred.NumberOfShapeFunctions,
i => new FeatureInfo(context, context._pred._shapeToInputMap[i], i, context._catsMap));
}
}
}
}
/// <summary>
/// Attempts to initialize the required items from the input model file. It could throw if something goes wrong.
/// </summary>
/// <param name="ch">The channel</param>
/// <returns>A structure containing essential information about the GAM model and dataset that enables
/// operations on top of it.</returns>
private Context Init(IChannel ch)
{
ILegacyDataLoader loader;
IPredictor rawPred;
RoleMappedSchema schema;
LoadModelObjects(ch, true, out rawPred, true, out schema, out loader);
bool hadCalibrator = false;
// The rawPred has two possible types:
// 1. CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>
// 2. GamRegressionModelParameters
// For (1), the trained model, GamModelParametersBase, is the SubModel we need to extract. For (2),
// we don't need to do anything, because GamRegressionModelParameters is derived from GamModelParametersBase.
var calibrated = rawPred as CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
while (calibrated != null)
{
hadCalibrator = true;
rawPred = calibrated.SubModel;
calibrated = rawPred as CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
}
var pred = rawPred as GamModelParametersBase;
ch.CheckUserArg(pred != null, nameof(ImplOptions.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase));
var data = new RoleMappedData(loader, schema.GetColumnRoleNames(), opt: true);
if (hadCalibrator && !string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
ch.Warning("If you save the GAM model, only the GAM model, not the wrapping calibrator, will be saved.");
return new Context(ch, pred, data, InitEvaluator(pred));
}
private IEvaluator InitEvaluator(GamModelParametersBase pred)
{
switch (pred.PredictionKind)
{
case PredictionKind.BinaryClassification:
return new BinaryClassifierEvaluator(Host, new BinaryClassifierEvaluator.Arguments());
case PredictionKind.Regression:
return new RegressionEvaluator(Host, new RegressionEvaluator.Arguments());
default:
return null;
}
}
private void Run(IChannel ch)
{
// First we're going to initialize a structure with lots of information about the predictor, trainer, etc.
var context = Init(ch);
// REVIEW: What to do with the data? Not sure. Take a sample? We could have
// a very compressed one, since we can just "bin" everything based on pred._binUpperBounds. Anyway
// whatever we choose to do, ultimately it will be exposed as some delegate on the server channel.
// Maybe binning actually isn't wise, *if* we want people to be able to set their own split points
// (which seems plausible). In the current version of the viz you can only set bin effects, but
// "splitting" a bin might be desirable in some cases, maybe. Or not.
// Now we have a GAM predictor.
AutoResetEvent ev = new AutoResetEvent(false);
using (var server = InitServer(ch))
using (var sch = Host.StartServerChannel("predictor/gam"))
{
// The number of features.
sch?.Register("numFeatures", () => context.NumFeatures);
// Info for a particular feature.
sch?.Register<int, Context.FeatureInfo>("info", context.GetInfoForIndex);
// Info for all features.
sch?.Register("infos", context.GetInfos);
// Modification of the model.
sch?.Register<int, int, double, long>("setEffect", context.SetEffect);
// Getting the metrics.
sch?.Register("metrics", context.GetMetrics);
sch?.Register("canSave", () => !string.IsNullOrEmpty(ImplOptions.OutputModelFile));
sch?.Register("save", () => context.SaveIfNeeded(Host, ch, ImplOptions.OutputModelFile));
sch?.Register("quit", () =>
{
var retVal = context.SaveIfNeeded(Host, ch, ImplOptions.OutputModelFile);
ev.Set();
return retVal;
});
// Targets and scores for data.
sch?.Publish();
if (sch != null)
{
ch.Info("GAM viz server is ready and waiting.");
Uri uri = server.BaseAddress;
// Believe it or not, this is actually the recommended procedure according to MSDN.
if (_open)
System.Diagnostics.Process.Start(uri.AbsoluteUri + "content/GamViz/");
ev.WaitOne();
ch.Info("Quit signal received. Quitter.");
}
else
ch.Info("No server, exiting immediately.");
}
}
}
}
}