|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
namespace Microsoft.ML.Trainers
{
/// <summary>
/// Options for creating a <see cref="TrainingCursorBase"/> from a <see cref="RoleMappedData"/> with specified standard columns active.
/// </summary>
[Flags]
[BestFriend]
internal enum CursOpt : uint
{
Weight = 0x01,
Group = 0x02,
Id = 0x04,
Label = 0x08,
Features = 0x10,
// Row filtering options.
AllowBadWeights = 0x0100,
AllowBadGroups = 0x0200,
AllowBadLabels = 0x0800,
AllowBadFeatures = 0x1000,
// Bad to the bone.
AllowBadEverything = AllowBadWeights | AllowBadGroups | AllowBadLabels | AllowBadFeatures,
AllWeights = Weight | AllowBadWeights,
AllGroups = Group | AllowBadGroups,
AllLabels = Label | AllowBadLabels,
AllFeatures = Features | AllowBadFeatures,
}
[BestFriend]
internal static class TrainerUtils
{
/// <summary>
/// Check for a standard (known-length vector of float) feature column.
/// </summary>
public static void CheckFeatureFloatVector(this RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
if (!data.Schema.Feature.HasValue)
throw Contracts.ExceptParam(nameof(data), "Training data must specify a feature column.");
var col = data.Schema.Feature.Value;
Contracts.Assert(!col.IsHidden);
if (!(col.Type is VectorDataViewType vecType && vecType.Size > 0 && vecType.ItemType == NumberDataViewType.Single))
throw Contracts.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of R4, but has type: {1}.", col.Name, col.Type);
}
/// <summary>
/// Check for a standard (known-length vector of float) feature column and determine its length.
/// </summary>
public static void CheckFeatureFloatVector(this RoleMappedData data, out int length)
{
CheckFeatureFloatVector(data);
// If the above function is generalized, this needs to be as well.
Contracts.AssertValue(data);
Contracts.Assert(data.Schema.Feature.HasValue);
var col = data.Schema.Feature.Value;
Contracts.Assert(!col.IsHidden);
var colType = col.Type as VectorDataViewType;
Contracts.Assert(colType != null && colType.IsKnownSize);
Contracts.Assert(colType.ItemType == NumberDataViewType.Single);
length = colType.Size;
}
/// <summary>
/// Check for a standard binary classification label.
/// </summary>
public static void CheckBinaryLabel(this RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
if (!data.Schema.Label.HasValue)
throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column.");
var col = data.Schema.Label.Value;
Contracts.Assert(!col.IsHidden);
if (col.Type != BooleanDataViewType.Instance && col.Type != NumberDataViewType.Single && col.Type != NumberDataViewType.Double && !(col.Type is KeyDataViewType keyType && keyType.Count == 2))
{
KeyDataViewType colKeyType = col.Type as KeyDataViewType;
if (colKeyType != null)
{
if (colKeyType.Count == 1)
{
throw Contracts.ExceptParam(nameof(data),
"The label column '{0}' of the training data has only one class. Two classes are required for binary classification.",
col.Name);
}
else if (colKeyType.Count > 2)
{
throw Contracts.ExceptParam(nameof(data),
"The label column '{0}' of the training data has more than two classes. Only two classes are allowed for binary classification.",
col.Name);
}
}
throw Contracts.ExceptParam(nameof(data),
"The label column '{0}' of the training data has a data type not suitable for binary classification: {1}. Type must be Boolean, Single, Double or Key with two classes.",
col.Name, col.Type);
}
}
/// <summary>
/// Check for a standard regression label.
/// </summary>
public static void CheckRegressionLabel(this RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
if (!data.Schema.Label.HasValue)
throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column.");
var col = data.Schema.Label.Value;
Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden);
if (col.Type != NumberDataViewType.Single && col.Type != NumberDataViewType.Double)
{
throw Contracts.ExceptParam(nameof(data),
"Training label column '{0}' type isn't suitable for regression: {1}. Type must be Single or Double.", col.Name, col.Type);
}
}
/// <summary>
/// Check for a standard multi-class label and determine its cardinality. If the column is a
/// key type, it must have known cardinality. For other numeric types, this scans the data
/// to determine the cardinality.
/// </summary>
public static void CheckMulticlassLabel(this RoleMappedData data, out int count)
{
Contracts.CheckValue(data, nameof(data));
if (!data.Schema.Label.HasValue)
throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column.");
var col = data.Schema.Label.Value;
Contracts.Assert(!col.IsHidden);
if (col.Type is KeyDataViewType keyType && keyType.Count > 0)
{
if (keyType.Count >= Utils.ArrayMaxSize)
throw Contracts.ExceptParam(nameof(data), "Maximum label is too large for multi-class: {0}.", keyType.Count);
count = (int)keyType.Count;
return;
}
// REVIEW: Support other numeric types.
if (col.Type != NumberDataViewType.Single && col.Type != NumberDataViewType.Double)
throw Contracts.ExceptParam(nameof(data), "Training label column '{0}' type is not valid for multi-class: {1}. Type must be Single or Double.", col.Name, col.Type);
int max = -1;
using (var cursor = new FloatLabelCursor(data))
{
while (cursor.MoveNext())
{
int cls = (int)cursor.Label;
if (cls != cursor.Label || cls < 0)
{
throw Contracts.ExceptParam(nameof(data),
"Training label column '{0}' contains invalid values for multi-class: {1}.", col.Name, cursor.Label);
}
if (max < cls)
max = cls;
}
}
if (max < 0)
throw Contracts.ExceptParam(nameof(data), "Training label column '{0}' contains no valid values for multi-class.", col.Name);
// REVIEW: Should we impose some smaller limit on the max?
if (max >= Utils.ArrayMaxSize)
throw Contracts.ExceptParam(nameof(data), "Maximum label is too large for multi-class: {0}.", max);
count = max + 1;
}
/// <summary>
/// Check for a standard regression label.
/// </summary>
public static void CheckMultiOutputRegressionLabel(this RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
if (!data.Schema.Label.HasValue)
throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column.");
var col = data.Schema.Label.Value;
Contracts.Assert(!col.IsHidden);
if (!(col.Type is VectorDataViewType vectorType
&& vectorType.IsKnownSize
&& vectorType.ItemType == NumberDataViewType.Single))
throw Contracts.ExceptParam(nameof(data), "Training label column '{0}' must be a known-size vector of Single, but has type: {1}.", col.Name, col.Type);
}
public static void CheckOptFloatWeight(this RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
if (!data.Schema.Weight.HasValue)
return;
var col = data.Schema.Weight.Value;
Contracts.Assert(!col.IsHidden);
if (col.Type != NumberDataViewType.Single && col.Type != NumberDataViewType.Double)
throw Contracts.ExceptParam(nameof(data), "Training weight column '{0}' must be of floating point numeric type, but has type: {1}.", col.Name, col.Type);
}
public static void CheckOptGroup(this RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
if (!data.Schema.Group.HasValue)
return;
var col = data.Schema.Group.Value;
Contracts.Assert(!col.IsHidden);
if (col.Type is KeyDataViewType)
return;
throw Contracts.ExceptParam(nameof(data), "Training group column '{0}' type is invalid: {1}. Must be Key type.", col.Name, col.Type);
}
private static IEnumerable<DataViewSchema.Column> CreatePredicate(RoleMappedData data, CursOpt opt, IEnumerable<int> extraCols)
{
Contracts.AssertValue(data);
Contracts.AssertValueOrNull(extraCols);
var columns = extraCols == null ?
new List<DataViewSchema.Column>() :
data.Data.Schema.Where(c => extraCols.Contains(c.Index)).ToList();
if ((opt & CursOpt.Label) != 0 && data.Schema.Label.HasValue)
columns.Add(data.Schema.Label.Value);
if ((opt & CursOpt.Features) != 0 && data.Schema.Feature.HasValue)
columns.Add(data.Schema.Feature.Value);
if ((opt & CursOpt.Weight) != 0 && data.Schema.Weight.HasValue)
columns.Add(data.Schema.Weight.Value);
if ((opt & CursOpt.Group) != 0 && data.Schema.Group.HasValue)
columns.Add(data.Schema.Group.Value);
return columns;
}
/// <summary>
/// Create a row cursor for the RoleMappedData with the indicated standard columns active.
/// This does not verify that the columns exist, but merely activates the ones that do exist.
/// </summary>
public static DataViewRowCursor CreateRowCursor(this RoleMappedData data, CursOpt opt, Random rand, IEnumerable<int> extraCols = null)
=> data.Data.GetRowCursor(CreatePredicate(data, opt, extraCols), rand);
/// <summary>
/// Create a row cursor set for the <see cref="RoleMappedData"/> with the indicated standard columns active.
/// This does not verify that the columns exist, but merely activates the ones that do exist.
/// </summary>
public static DataViewRowCursor[] CreateRowCursorSet(this RoleMappedData data,
CursOpt opt, int n, Random rand, IEnumerable<int> extraCols = null)
=> data.Data.GetRowCursorSet(CreatePredicate(data, opt, extraCols), n, rand);
/// <summary>
/// Get the getter for the feature column, assuming it is a vector of float.
/// </summary>
public static ValueGetter<VBuffer<float>> GetFeatureFloatVectorGetter(this DataViewRow row, RoleMappedSchema schema)
{
Contracts.CheckValue(row, nameof(row));
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!");
Contracts.CheckParam(schema.Feature.HasValue, nameof(schema), "Missing feature column");
return row.GetGetter<VBuffer<float>>(schema.Feature.Value);
}
/// <summary>
/// Get the getter for the feature column, assuming it is a vector of float.
/// </summary>
public static ValueGetter<VBuffer<float>> GetFeatureFloatVectorGetter(this DataViewRow row, RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
return GetFeatureFloatVectorGetter(row, data.Schema);
}
/// <summary>
/// Get a getter for the label as a float. This assumes that the label column type
/// has already been validated as appropriate for the kind of training being done.
/// </summary>
public static ValueGetter<float> GetLabelFloatGetter(this DataViewRow row, RoleMappedSchema schema)
{
Contracts.CheckValue(row, nameof(row));
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!");
Contracts.CheckParam(schema.Label.HasValue, nameof(schema), "Missing label column");
return RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
}
/// <summary>
/// Get a getter for the label as a float. This assumes that the label column type
/// has already been validated as appropriate for the kind of training being done.
/// </summary>
public static ValueGetter<float> GetLabelFloatGetter(this DataViewRow row, RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
return GetLabelFloatGetter(row, data.Schema);
}
/// <summary>
/// Get the getter for the weight column, or null if there is no weight column.
/// </summary>
public static ValueGetter<float> GetOptWeightFloatGetter(this DataViewRow row, RoleMappedSchema schema)
{
Contracts.CheckValue(row, nameof(row));
Contracts.CheckValue(schema, nameof(schema));
Contracts.Check(schema.Schema == row.Schema, "schemas don't match!");
var col = schema.Weight;
if (!col.HasValue)
return null;
return RowCursorUtils.GetGetterAs<float>(NumberDataViewType.Single, row, col.Value.Index);
}
public static ValueGetter<float> GetOptWeightFloatGetter(this DataViewRow row, RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
return GetOptWeightFloatGetter(row, data.Schema);
}
/// <summary>
/// Get the getter for the group column, or null if there is no group column.
/// </summary>
public static ValueGetter<ulong> GetOptGroupGetter(this DataViewRow row, RoleMappedSchema schema)
{
Contracts.CheckValue(row, nameof(row));
Contracts.CheckValue(schema, nameof(schema));
Contracts.Check(schema.Schema == row.Schema, "schemas don't match!");
var col = schema.Group;
if (!col.HasValue)
return null;
return RowCursorUtils.GetGetterAs<ulong>(NumberDataViewType.UInt64, row, col.Value.Index);
}
public static ValueGetter<ulong> GetOptGroupGetter(this DataViewRow row, RoleMappedData data)
{
Contracts.CheckValue(data, nameof(data));
return GetOptGroupGetter(row, data.Schema);
}
/// <summary>
/// The <see cref="SchemaShape.Column"/> for the label column for binary classification tasks.
/// </summary>
/// <param name="labelColumn">name of the label column</param>
public static SchemaShape.Column MakeBoolScalarLabel(string labelColumn)
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false);
/// <summary>
/// The <see cref="SchemaShape.Column"/> for the float type columns.
/// </summary>
/// <param name="columnName">name of the column</param>
public static SchemaShape.Column MakeR4ScalarColumn(string columnName)
=> new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);
/// <summary>
/// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
/// </summary>
/// <param name="columnName">name of the weight column</param>
public static SchemaShape.Column MakeU4ScalarColumn(string columnName)
{
if (columnName == null)
return default;
return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
}
/// <summary>
/// The <see cref="SchemaShape.Column"/> for the feature column.
/// </summary>
/// <param name="featureColumn">name of the feature column</param>
public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
=> new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
/// <summary>
/// The <see cref="SchemaShape.Column"/> for the weight column.
/// </summary>
/// <param name="weightColumn">name of the weight column</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
{
if (weightColumn == null)
return default;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);
}
/// <summary>
/// This is a shim class to translate the more contemporaneous <see cref="ITrainerEstimator{TTransformer, TPredictor}"/>
/// style transformers into the older now disfavored <see cref="ITrainer{TPredictor}"/> idiom, for components that still
/// need to operate via that older mechanism. (Mostly command line invocations, and so on.).
/// </summary>
/// <typeparam name="TModel">The type of the new model parameters.</typeparam>
/// <typeparam name="TPredictor">The type corresponding to the legacy predictor.</typeparam>
private sealed class TrainerEstimatorToTrainerShim<TModel, TPredictor> : ITrainer<TPredictor>
where TModel : class, TPredictor
where TPredictor : IPredictor
{
public TrainerInfo Info { get; }
public PredictionKind PredictionKind { get; }
private readonly ITrainerEstimator<ISingleFeaturePredictionTransformer<TModel>, TModel> _trainer;
private readonly IHostEnvironment _env;
public TrainerEstimatorToTrainerShim(IHostEnvironment env, ITrainerEstimator<ISingleFeaturePredictionTransformer<TModel>, TModel> trainer)
{
Contracts.AssertValue(env);
_env = env;
_env.AssertValue(trainer);
_env.Assert(trainer is ITrainer);
var oldTrainer = (ITrainer)trainer;
Info = oldTrainer.Info;
PredictionKind = oldTrainer.PredictionKind;
_trainer = trainer;
}
public TPredictor Train(TrainContext context)
{
_env.CheckValue(context, nameof(context));
// For the purpose of mapping into the estimator, we assume that the input estimator does not have
// any custom overrides for the column names defined.
var tschema = context.TrainingSet.Schema;
var nameMap = new List<(string outName, string inName)>();
if (tschema.Feature?.Name is string fname && fname != DefaultColumnNames.Features)
nameMap.Add((DefaultColumnNames.Features, fname));
if (tschema.Label?.Name is string lname && lname != DefaultColumnNames.Label)
nameMap.Add((DefaultColumnNames.Label, lname));
if (tschema.Weight?.Name is string wname && wname != DefaultColumnNames.Weight)
nameMap.Add((DefaultColumnNames.Weight, wname));
if (tschema.Group?.Name is string gname && gname != DefaultColumnNames.GroupId)
nameMap.Add((DefaultColumnNames.GroupId, gname));
if (tschema.Group?.Name is string iname && iname != DefaultColumnNames.Item)
nameMap.Add((DefaultColumnNames.Item, iname));
if (tschema.Group?.Name is string uname && uname != DefaultColumnNames.User)
nameMap.Add((DefaultColumnNames.User, uname));
var data = context.TrainingSet.Data;
if (nameMap.Count > 0)
{
var estimator = new ColumnCopyingEstimator(_env, nameMap.ToArray());
data = estimator.Fit(data).Transform(data);
}
var predictionTransformer = _trainer.Fit(data);
var model = predictionTransformer.Model;
if (model is TPredictor pred)
return pred;
throw _env.Except($"Training resulted in a model of type {model.GetType().Name}.");
}
IPredictor ITrainer.Train(TrainContext context) => Train(context);
}
/// <summary>
/// This is a shim for legacy code that takes the more modern <see cref="ITrainerEstimator{TTransformer, TPredictor}"/>
/// interface, and maps it to the legacy code that wants an <see cref="ITrainer{TPredictor}"/>. The goal should be to
/// remove reliance on that interface if possible, but this may not be practical in the immediate term, so for the benefit
/// of scenarios like this we have this convenience function.
/// </summary>
/// <typeparam name="T">The trainer estimator type.</typeparam>
/// <typeparam name="TModel">The type of the model produced by the estimator.</typeparam>
/// <typeparam name="TPredictor">The type of the predictor to be produced by the predictor.</typeparam>
/// <param name="env">The host environment.</param>
/// <param name="trainer">The trainer estimator.</param>
/// <returns>An implementation of the legacy trainer interface.</returns>
public static ITrainer<TPredictor> MapTrainerEstimatorToTrainer<T, TModel, TPredictor>(IHostEnvironment env, T trainer)
where T : ITrainerEstimator<ISingleFeaturePredictionTransformer<TModel>, TModel>, ITrainer
where TModel : class, TPredictor
where TPredictor : IPredictor
{
return new TrainerEstimatorToTrainerShim<TModel, TPredictor>(env, trainer);
}
}
/// <summary>
/// This is the base class for a data cursor. Data cursors are specially typed
/// "convenience" cursor-like objects, less general than a <see cref="DataViewRowCursor"/> but
/// more convenient for common access patterns that occur in machine learning. For
/// example, the common idiom of iterating over features/labels/weights while skipping
/// "bad" features, labels, and weights. There will be two typical access patterns for
/// users of the cursor. The first is just creation of the cursor using a constructor;
/// this is best for one-off accesses of the data. The second access pattern, best for
/// repeated accesses, is to use a cursor factory (usually a nested class of the cursor
/// class). This keeps track of what filtering options were actually useful.
/// </summary>
[BestFriend]
internal abstract class TrainingCursorBase : IDisposable
{
public DataViewRow Row => _cursor;
private readonly DataViewRowCursor _cursor;
private readonly Action<CursOpt> _signal;
public long SkippedRowCount { get; private set; }
public long KeptRowCount { get; private set; }
/// <summary>
/// The base constructor class for the factory-based cursor creation.
/// </summary>
/// <param name="input"></param>
/// <param name="signal">This method is called </param>
protected TrainingCursorBase(DataViewRowCursor input, Action<CursOpt> signal)
{
Contracts.AssertValue(input);
Contracts.AssertValueOrNull(signal);
_cursor = input;
_signal = signal;
}
protected static DataViewRowCursor CreateCursor(RoleMappedData data, CursOpt opt, Random rand, params int[] extraCols)
{
Contracts.AssertValue(data);
Contracts.AssertValueOrNull(rand);
return data.CreateRowCursor(opt, rand, extraCols);
}
/// <summary>
/// This method is called by <see cref="MoveNext"/> in the event we have reached the end
/// of the cursoring. The intended usage is that it returns what flags will be passed to the signal
/// delegate of the cursor, indicating what additional options should be specified on subsequent
/// passes over the data. The base implementation checks if any rows were skipped, and if none were
/// skipped, it signals the context that it needn't bother with any filtering checks.
///
/// Because the result will be "or"-red, a perfectly acceptable implementation is that this
/// return the default <see cref="CursOpt"/>, in which case the flags will not ever change.
///
/// If the cursor was created with a signal delegate, the return value of this method will be sent
/// to that delegate.
/// </summary>
protected virtual CursOpt CursoringCompleteFlags()
{
return SkippedRowCount == 0 ? CursOpt.AllowBadEverything : default(CursOpt);
}
/// <summary>
/// Calls Cursor.MoveNext() and this.Accept() repeatedly until this.Accept() returns true.
/// Returns false if Cursor.MoveNext() returns false. If you call Cursor.MoveNext() directly,
/// also call this.Accept() to fetch the values of the current row. Note that if this.Accept()
/// returns false, it's possible that not all values were fetched.
/// </summary>
public bool MoveNext()
{
for (; ; )
{
if (!_cursor.MoveNext())
{
if (_signal != null)
_signal(CursoringCompleteFlags());
return false;
}
if (Accept())
{
KeptRowCount++;
return true;
}
SkippedRowCount++;
}
}
/// <summary>
/// This fetches and validates values for the standard active columns.
/// It is called automatically by MoveNext(). Client code should only need
/// to deal with this if it calls MoveNext() on the underlying
/// <see cref="DataViewRowCursor"/> directly. That is, this is only for very advanced scenarios.
/// </summary>
public virtual bool Accept()
{
return true;
}
public void Dispose()
{
_cursor.Dispose();
}
/// <summary>
/// This is the base class for a data cursor factory. The factory is a reusable object,
/// created with data and cursor options. From external non-implementing users it will
/// appear to be more or less stateless, but internally it is keeping track of what sorts
/// of filtering it needs to perform. For example, if we construct the factory with the
/// option that it needs to filter out rows with bad feature values, but on the first
/// iteration it is revealed there are no bad feature values, then it would be a complete
/// waste of time to check on subsequent iterations over the data whether there are bad
/// feature values again.
/// </summary>
public abstract class FactoryBase<TCurs>
where TCurs : TrainingCursorBase
{
private readonly RoleMappedData _data;
private readonly CursOpt _initOpts;
private readonly object _lock;
private CursOpt _opts;
public RoleMappedData Data => _data;
protected FactoryBase(RoleMappedData data, CursOpt opt)
{
Contracts.CheckValue(data, nameof(data));
_data = data;
_opts = _initOpts = opt;
_lock = new object();
}
private void SignalCore(CursOpt opt)
{
lock (_lock)
_opts |= opt;
}
/// <summary>
/// The typed analog to <see cref="IDataView.GetRowCursor(IEnumerable{DataViewSchema.Column},Random)"/>.
/// </summary>
/// <param name="rand">Non-null if we are requesting a shuffled cursor.</param>
/// <param name="extraCols">The extra columns to activate on the row cursor
/// in addition to those required by the factory's options.</param>
/// <returns>The wrapping typed cursor.</returns>
public TCurs Create(Random rand = null, params int[] extraCols)
{
CursOpt opt;
lock (_lock)
opt = _opts;
var input = _data.CreateRowCursor(opt, rand, extraCols);
return CreateCursorCore(input, _data, opt, SignalCore);
}
/// <summary>
/// The typed analog to <see cref="IDataView.GetRowCursorSet"/>, this provides a
/// partitioned cursoring of the data set, appropriate to multithreaded algorithms
/// that want to consume parallel cursors without any consolidation.
/// </summary>
/// <param name="n">Suggested degree of parallelism.</param>
/// <param name="rand">Non-null if we are requesting a shuffled cursor.</param>
/// <param name="extraCols">The extra columns to activate on the row cursor
/// in addition to those required by the factory's options.</param>
/// <returns>The cursor set. Note that this needn't necessarily be of size
/// <paramref name="n"/>.</returns>
public TCurs[] CreateSet(int n, Random rand = null, params int[] extraCols)
{
CursOpt opt;
lock (_lock)
opt = _opts;
// Users of this method will tend to consume the cursors in the set in separate
// threads, and so gain benefit from the parallel transformation of the data.
var inputs = _data.CreateRowCursorSet(opt, n, rand, extraCols);
Contracts.Assert(Utils.Size(inputs) > 0);
Action<CursOpt> signal;
if (inputs.Length > 1)
signal = new AndAccumulator(SignalCore, inputs.Length).Signal;
else
signal = SignalCore;
var res = new TCurs[inputs.Length];
for (int i = 0; i < res.Length; i++)
res[i] = CreateCursorCore(inputs[i], _data, opt, signal);
return res;
}
/// <summary>
/// Called by both the <see cref="Create"/> and <see cref="CreateSet"/> factory methods. Implementors
/// should instantiate the particular wrapping cursor.
/// </summary>
/// <param name="input">The row cursor we will wrap.</param>
/// <param name="data">The data from which the row cursor was instantiated.</param>
/// <param name="opt">The cursor options this row cursor was created with.</param>
/// <param name="signal">The action that our wrapping cursor will call. Implementors of the cursor
/// do not usually call it directly, but instead override
/// <see cref="TrainingCursorBase.CursoringCompleteFlags"/>, whose return value is used to call
/// this action.</param>
/// <returns></returns>
protected abstract TCurs CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal);
/// <summary>
/// Accumulates signals from cursors, anding them together. Once it has
/// all of the information it needs to signal the factory itself, it will
/// do so.
/// </summary>
private sealed class AndAccumulator
{
private readonly Action<CursOpt> _signal;
private readonly int _lim;
private int _count;
private CursOpt _opts;
public AndAccumulator(Action<CursOpt> signal, int lim)
{
Contracts.AssertValue(signal);
Contracts.Assert(lim > 0);
_signal = signal;
_lim = lim;
_opts = ~default(CursOpt);
}
public void Signal(CursOpt opt)
{
lock (this)
{
Contracts.Assert(_count < _lim);
_opts &= opt;
if (++_count == _lim)
_signal(_opts);
}
}
}
}
}
/// <summary>
/// This supports Weight (float), Group (ulong), and Id (RowId) columns.
/// </summary>
[BestFriend]
internal class StandardScalarCursor : TrainingCursorBase
{
private readonly ValueGetter<float> _getWeight;
private readonly ValueGetter<ulong> _getGroup;
private readonly ValueGetter<DataViewRowId> _getId;
private readonly bool _keepBadWeight;
private readonly bool _keepBadGroup;
public long BadWeightCount { get; private set; }
public long BadGroupCount { get; private set; }
public float Weight;
public ulong Group;
public DataViewRowId Id;
public StandardScalarCursor(RoleMappedData data, CursOpt opt, Random rand = null, params int[] extraCols)
: this(CreateCursor(data, opt, rand, extraCols), data, opt)
{
}
protected StandardScalarCursor(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal = null)
: base(input, signal)
{
Contracts.AssertValue(data);
if ((opt & CursOpt.Weight) != 0)
{
_getWeight = Row.GetOptWeightFloatGetter(data);
_keepBadWeight = (opt & CursOpt.AllowBadWeights) != 0;
}
if ((opt & CursOpt.Group) != 0)
{
_getGroup = Row.GetOptGroupGetter(data);
_keepBadGroup = (opt & CursOpt.AllowBadGroups) != 0;
}
if ((opt & CursOpt.Id) != 0)
_getId = Row.GetIdGetter();
Weight = 1;
Group = 0;
}
protected override CursOpt CursoringCompleteFlags()
{
CursOpt opt = base.CursoringCompleteFlags();
if (BadWeightCount == 0)
opt |= CursOpt.AllowBadWeights;
if (BadGroupCount == 0)
opt |= CursOpt.AllowBadGroups;
return opt;
}
public override bool Accept()
{
if (!base.Accept())
return false;
if (_getWeight != null)
{
_getWeight(ref Weight);
if (!_keepBadWeight && !(0 < Weight && Weight < float.PositiveInfinity))
{
BadWeightCount++;
return false;
}
}
if (_getGroup != null)
{
_getGroup(ref Group);
if (!_keepBadGroup && Group == 0)
{
BadGroupCount++;
return false;
}
}
if (_getId != null)
_getId(ref Id);
return true;
}
public sealed class Factory : FactoryBase<StandardScalarCursor>
{
public Factory(RoleMappedData data, CursOpt opt)
: base(data, opt)
{
}
protected override StandardScalarCursor CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
=> new StandardScalarCursor(input, data, opt, signal);
}
}
/// <summary>
/// This derives from <see cref="StandardScalarCursor"/> and adds the feature column
/// as a <see cref="VBuffer{Float}"/>.
/// </summary>
[BestFriend]
internal class FeatureFloatVectorCursor : StandardScalarCursor
{
private readonly ValueGetter<VBuffer<float>> _get;
private readonly bool _keepBad;
public long BadFeaturesRowCount { get; private set; }
public VBuffer<float> Features;
public FeatureFloatVectorCursor(RoleMappedData data, CursOpt opt = CursOpt.Features,
Random rand = null, params int[] extraCols)
: this(CreateCursor(data, opt, rand, extraCols), data, opt)
{
}
protected FeatureFloatVectorCursor(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal = null)
: base(input, data, opt, signal)
{
if ((opt & CursOpt.Features) != 0 && data.Schema.Feature != null)
{
_get = Row.GetFeatureFloatVectorGetter(data);
_keepBad = (opt & CursOpt.AllowBadFeatures) != 0;
}
}
protected override CursOpt CursoringCompleteFlags()
{
var opt = base.CursoringCompleteFlags();
if (BadFeaturesRowCount == 0)
opt |= CursOpt.AllowBadFeatures;
return opt;
}
public override bool Accept()
{
if (!base.Accept())
return false;
if (_get != null)
{
_get(ref Features);
if (!_keepBad && !FloatUtils.IsFinite(Features.GetValues()))
{
BadFeaturesRowCount++;
return false;
}
}
return true;
}
public new sealed class Factory : FactoryBase<FeatureFloatVectorCursor>
{
public Factory(RoleMappedData data, CursOpt opt = CursOpt.Features)
: base(data, opt)
{
}
protected override FeatureFloatVectorCursor CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
{
return new FeatureFloatVectorCursor(input, data, opt, signal);
}
}
}
/// <summary>
/// This derives from the FeatureFloatVectorCursor and adds the Label (float) column.
/// </summary>
[BestFriend]
internal class FloatLabelCursor : FeatureFloatVectorCursor
{
private readonly ValueGetter<float> _get;
private readonly bool _keepBad;
public long BadLabelCount { get; private set; }
public float Label;
public FloatLabelCursor(RoleMappedData data, CursOpt opt = CursOpt.Label,
Random rand = null, params int[] extraCols)
: this(CreateCursor(data, opt, rand, extraCols), data, opt)
{
}
protected FloatLabelCursor(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal = null)
: base(input, data, opt, signal)
{
if ((opt & CursOpt.Label) != 0 && data.Schema.Label != null)
{
_get = Row.GetLabelFloatGetter(data);
_keepBad = (opt & CursOpt.AllowBadLabels) != 0;
}
}
protected override CursOpt CursoringCompleteFlags()
{
var opt = base.CursoringCompleteFlags();
if (BadLabelCount == 0)
opt |= CursOpt.AllowBadLabels;
return opt;
}
public override bool Accept()
{
// Get the label first since base includes the features (the expensive part).
if (_get != null)
{
_get(ref Label);
if (!_keepBad && !FloatUtils.IsFinite(Label))
{
BadLabelCount++;
return false;
}
}
return base.Accept();
}
public new sealed class Factory : FactoryBase<FloatLabelCursor>
{
public Factory(RoleMappedData data, CursOpt opt = CursOpt.Label)
: base(data, opt)
{
}
protected override FloatLabelCursor CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
{
return new FloatLabelCursor(input, data, opt, signal);
}
}
}
/// <summary>
/// This derives from the FeatureFloatVectorCursor and adds the Label (int) column,
/// enforcing multi-class semantics.
/// </summary>
[BestFriend]
internal class MulticlassLabelCursor : FeatureFloatVectorCursor
{
private readonly int _classCount;
private readonly ValueGetter<float> _get;
private readonly bool _keepBad;
public long BadLabelCount { get; private set; }
private float _raw;
public int Label;
public MulticlassLabelCursor(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label,
Random rand = null, params int[] extraCols)
: this(classCount, CreateCursor(data, opt, rand, extraCols), data, opt)
{
}
protected MulticlassLabelCursor(int classCount, DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal = null)
: base(input, data, opt, signal)
{
Contracts.Assert(classCount >= 0);
_classCount = classCount;
if ((opt & CursOpt.Label) != 0 && data.Schema.Label != null)
{
_get = Row.GetLabelFloatGetter(data);
_keepBad = (opt & CursOpt.AllowBadLabels) != 0;
}
}
protected override CursOpt CursoringCompleteFlags()
{
var opt = base.CursoringCompleteFlags();
if (BadLabelCount == 0)
opt |= CursOpt.AllowBadLabels;
return opt;
}
public override bool Accept()
{
// Get the label first since base includes the features (the expensive part).
if (_get != null)
{
_get(ref _raw);
Label = (int)_raw;
if (!_keepBad && !(Label == _raw && (0 <= _raw && (_raw < _classCount || _classCount == 0))))
{
BadLabelCount++;
return false;
}
}
return base.Accept();
}
public new sealed class Factory : FactoryBase<MulticlassLabelCursor>
{
private readonly int _classCount;
public Factory(int classCount, RoleMappedData data, CursOpt opt = CursOpt.Label)
: base(data, opt)
{
// Zero means that any non-negative integer value is fine.
Contracts.CheckParamValue(classCount >= 0, classCount, nameof(classCount), "Must be non-negative");
_classCount = classCount;
}
protected override MulticlassLabelCursor CreateCursorCore(DataViewRowCursor input, RoleMappedData data, CursOpt opt, Action<CursOpt> signal)
{
return new MulticlassLabelCursor(_classCount, input, data, opt, signal);
}
}
}
}
|