|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(MissingValueReplacingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueReplacingTransformer), typeof(MissingValueReplacingTransformer.Options), typeof(SignatureDataTransform),
MissingValueReplacingTransformer.FriendlyName, MissingValueReplacingTransformer.LoadName, "NAReplace", MissingValueReplacingTransformer.ShortName, DocName = "transform/NAHandle.md")]
[assembly: LoadableClass(MissingValueReplacingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueReplacingTransformer), null, typeof(SignatureLoadDataTransform),
MissingValueReplacingTransformer.FriendlyName, MissingValueReplacingTransformer.LoadName)]
[assembly: LoadableClass(MissingValueReplacingTransformer.Summary, typeof(MissingValueReplacingTransformer), null, typeof(SignatureLoadModel),
MissingValueReplacingTransformer.FriendlyName, MissingValueReplacingTransformer.LoadName)]
[assembly: LoadableClass(typeof(IRowMapper), typeof(MissingValueReplacingTransformer), null, typeof(SignatureLoadRowMapper),
MissingValueReplacingTransformer.FriendlyName, MissingValueReplacingTransformer.LoadName)]
namespace Microsoft.ML.Transforms
{
/// <summary>
/// <see cref="ITransformer"/> resulting from fitting a <see cref="MissingValueReplacingEstimator"/>.
/// </summary>
// REVIEW: May make sense to implement the transform template interface.
public sealed partial class MissingValueReplacingTransformer : OneToOneTransformerBase
{
// Cached method-info wrapper around the generic ComputeDefaultSlots<T> instance method. It is handed to
// Utils.MarshalInvoke together with the runtime item type when the slot-default bit arrays are built.
// NOTE(review): the <int> type argument here looks like a placeholder only - confirm against MarshalInvoke.
private static readonly FuncInstanceMethodInfo1<MissingValueReplacingTransformer, DataViewType, Array, BitArray> _computeDefaultSlotsMethodInfo
= FuncInstanceMethodInfo1<MissingValueReplacingTransformer, DataViewType, Array, BitArray>.Create(target => target.ComputeDefaultSlots<int>);
// The ways in which a replacement value for missing entries can be chosen.
// The first six members are the canonical names; the members marked [HideEnumValue] are
// aliases that map onto the same underlying values.
internal enum ReplacementKind : byte
{
// REVIEW: What should the full list of options for this transform be?
// Replace with the default value of the item type.
DefaultValue = 0,
// Replace with a statistic computed from the data (see GetReplacementValues).
Mean = 1,
Minimum = 2,
Maximum = 3,
// Replace with a value parsed from the user supplied replacement string.
SpecifiedValue = 4,
Mode = 5,
// Aliases of DefaultValue.
[HideEnumValue]
Def = DefaultValue,
[HideEnumValue]
Default = DefaultValue,
// Alias of Minimum.
[HideEnumValue]
Min = Minimum,
// Alias of Maximum.
[HideEnumValue]
Max = Maximum,
// Aliases of SpecifiedValue.
[HideEnumValue]
Val = SpecifiedValue,
[HideEnumValue]
Value = SpecifiedValue,
}
// REVIEW: Need to add support for imputation modes for replacement values:
// *default: use default value
// *custom: use replacementValue string
// *mean: use domain value closest to the mean
// Potentially also min/max; probably will not include median due to its relatively low value and high computational cost.
// Note: Will need to support different replacement values for different slots to implement this.
// Per-column command-line arguments: an optional replacement string plus optional
// overrides of the replacement kind and of by-slot imputation.
internal sealed class Column : OneToOneColumn
{
    // REVIEW: Should flexibility for different replacement values for slots be introduced?
    [Argument(ArgumentType.AtMostOnce, HelpText = "Replacement value for NAs (uses default value if not given)", ShortName = "rep")]
    public string ReplacementString;

    [Argument(ArgumentType.AtMostOnce, HelpText = "The replacement method to utilize")]
    public ReplacementKind? Kind;

    // REVIEW: The default is to perform imputation by slot. If the input column is an unknown size vector type, then imputation
    // will be performed across columns. Should the default be changed/an imputation method required?
    [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to impute values by slot")]
    public bool? Slot;

    // Parses a column specification; yields null when the text cannot be parsed.
    internal static Column Parse(string str)
    {
        var column = new Column();
        return column.TryParse(str) ? column : null;
    }

    // Accepted form is N:R:S - N is the new column name, R the replacement string
    // and S the source column name.
    private protected override bool TryParse(string str)
        => base.TryParse(str, out ReplacementString);

    // Writes the short form back to the builder. The short form cannot express
    // Kind or Slot, so unparsing is refused when either one is set.
    internal bool TryUnparse(StringBuilder sb)
    {
        Contracts.AssertValue(sb);

        bool hasOverrides = Kind.HasValue || Slot.HasValue;
        if (hasOverrides)
            return false;

        return ReplacementString != null
            ? TryUnparseCore(sb, ReplacementString)
            : TryUnparseCore(sb);
    }
}
// Transform-level arguments: the column definitions plus the defaults applied to any
// column that does not set its own Kind / Slot (see the Create factory taking Options).
internal sealed class Options : TransformInputBase
{
// The columns to process; at least one is required.
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:rep:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
// Default replacement kind, taken from the estimator's defaults.
[Argument(ArgumentType.AtMostOnce, HelpText = "The replacement method to utilize", ShortName = "kind")]
public ReplacementKind ReplacementKind = (ReplacementKind)MissingValueReplacingEstimator.Defaults.Mode;
// Specifying by-slot imputation for vectors of unknown size will cause a warning, and the imputation will be global.
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to impute values by slot", ShortName = "slot")]
public bool ImputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot;
}
// Cached method-info wrapper for the generic static TestType<T>; used by TestType(DataViewType)
// through Utils.MarshalInvoke with the runtime item type.
private static readonly FuncStaticMethodInfo1<DataViewType, string> _testTypeMethodInfo
= new FuncStaticMethodInfo1<DataViewType, string>(TestType<int>);
// Cached method-info wrapper for the generic instance GetIsNADelegate<T>; used by
// GetIsNADelegate(DataViewType) through Utils.MarshalInvoke with the runtime item type.
private static readonly FuncInstanceMethodInfo1<MissingValueReplacingTransformer, DataViewType, Delegate> _getIsNADelegateMethodInfo
= FuncInstanceMethodInfo1<MissingValueReplacingTransformer, DataViewType, Delegate>.Create(target => target.GetIsNADelegate<int>);
// Loader signature of this transform; also used in user facing error messages and as the host name on load.
internal const string LoadName = "NAReplaceTransform";
// Describes the serialization format of this transform: signature, written/readable
// versions and the loader that can read it back.
private static VersionInfo GetVersionInfo()
    => new VersionInfo(
        // REVIEW: temporary name
        modelSignature: "NAREP TF",
        // Version history:
        //   0x00010001 - Initial
        //   0x00010002 - Added imputation methods.
        verWrittenCur: 0x00010002,
        verReadableCur: 0x00010002,
        verWeCanReadBack: 0x00010001,
        loaderSignature: LoadName,
        loaderAssemblyName: typeof(MissingValueReplacingTransformer).Assembly.FullName);
// Descriptive strings referenced by the LoadableClass assembly attributes of this transform.
internal const string Summary = "Create an output column of the same type and size of the input column, where missing values "
+ "are replaced with either the default value or the mean/min/max value (for non-text columns only).";
internal const string FriendlyName = "NA Replace Transform";
internal const string ShortName = "NARep";
/// <summary>
/// Checks whether <paramref name="type"/> can be handled by this transform by dispatching to the
/// generic overload for the item's raw type.
/// </summary>
/// <returns>Null when the type is supported; otherwise a message describing why it is not.</returns>
internal static string TestType(DataViewType type)
{
    // Item type must have an NA value that exists and is not equal to its default value.
    DataViewType elementType = type.GetItemType();
    Type elementRawType = elementType.RawType;
    return Utils.MarshalInvoke(_testTypeMethodInfo, elementRawType, elementType);
}
// Typed worker behind TestType(DataViewType). Returns null when T has an NA value that differs
// from default(T); otherwise returns the reason the type is rejected.
private static string TestType<T>(DataViewType type)
{
    Contracts.Assert(type.GetItemType().RawType == typeof(T));

    var conversions = Data.Conversion.Conversions.DefaultInstance;
    bool hasNaPredicate = conversions.TryGetIsNAPredicate(type.GetItemType(), out InPredicate<T> naPredicate);
    if (!hasNaPredicate)
    {
        return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value",
            type, LoadName);
    }

    T defaultValue = default(T);
    if (!naPredicate(in defaultValue))
        return null;

    // REVIEW: Key values will be handled in a "new key value" transform.
    return string.Format("Type '{0}' is not supported by {1} since its NA value is equivalent to its default value",
        type, LoadName);
}
// Projects the column options onto (output name, input name) pairs, in the same order.
private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(MissingValueReplacingEstimator.ColumnOptions[] columns)
{
    Contracts.CheckValue(columns, nameof(columns));

    var pairs = new (string outputColumnName, string inputColumnName)[columns.Length];
    for (int index = 0; index < columns.Length; index++)
    {
        var options = columns[index];
        pairs[index] = (options.Name, options.InputColumnName);
    }
    return pairs;
}
// The output column types, parallel to Infos.
// For vector columns this is the vector type; the item type is obtained through GetItemType() where needed.
private readonly DataViewType[] _replaceTypes;
// The replacementValues for the columns, parallel to Infos.
// The elements of this array can be either primitive values or arrays of primitive values. When replacing a scalar valued column in Infos,
// this array will hold a primitive value. When replacing a vector valued column in Infos, this array will either hold a primitive
// value, indicating that NAs in all slots will be replaced with this value, or an array of primitives holding the value that each slot
// will have its NA values replaced with respectively. The case with an array of primitives can only occur when dealing with a
// vector of known size.
private readonly object[] _repValues;
// Marks if the replacement values in given slots of _repValues are the default value.
// An entry is non-null only for columns whose replacement value is an array (per-slot replacement);
// the rest of the class uses "entry != null" as the test for per-slot replacement.
// REVIEW: Currently these arrays are constructed on load but could be changed to being constructed lazily.
private readonly BitArray[] _repIsDefault;
// Rejects input columns whose type cannot be handled by this transform (see TestType).
private protected override void CheckInputColumn(DataViewSchema inputSchema, int col, int srcCol)
{
    string failure = TestType(inputSchema[srcCol].Type);
    if (failure == null)
        return;

    throw Host.ExceptParam(nameof(inputSchema), failure);
}
// Training constructor: validates every input column against the schema of the data view and
// then computes the replacement values from the data.
internal MissingValueReplacingTransformer(IHostEnvironment env, IDataView input, params MissingValueReplacingEstimator.ColumnOptions[] columns)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueReplacingTransformer)), GetColumnPairs(columns))
{
    // Every input column has to exist and to have a supported type.
    for (int index = 0; index < ColumnPairs.Length; index++)
    {
        string sourceName = ColumnPairs[index].inputColumnName;
        if (input.Schema.TryGetColumnIndex(sourceName, out int sourceIndex))
        {
            CheckInputColumn(input.Schema, index, sourceIndex);
            continue;
        }

        throw Host.ExceptSchemaMismatch(nameof(input), "input", sourceName);
    }

    GetReplacementValues(input, columns, out _repValues, out _repIsDefault, out _replaceTypes);
}
// Deserialization constructor. After the base class has read the column pairs, one
// (type, value) record per column is read from the model stream, in column order.
private MissingValueReplacingTransformer(IHost host, ModelLoadContext ctx)
: base(host, ctx)
{
var columnsLength = ColumnPairs.Length;
_repValues = new object[columnsLength];
_repIsDefault = new BitArray[columnsLength];
_replaceTypes = new DataViewType[columnsLength];
var saver = new BinarySaver(Host, new BinarySaver.Arguments());
for (int i = 0; i < columnsLength; i++)
{
// A record that cannot be read means the model stream is corrupt or of an unknown format.
if (!saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out DataViewType savedType, out object repValue))
throw Host.ExceptDecode();
_replaceTypes[i] = savedType;
if (savedType is VectorDataViewType savedVectorType)
{
// REVIEW: The current implementation takes the serialized VBuffer, densifies it, and stores the values array.
// It might be of value to consider storing the VBuffer in order to possibly benefit from sparsity. However, this would
// necessitate a reimplementation of the FillValues code to accommodate sparse VBuffers.
object[] args = new object[] { repValue, savedVectorType, i };
// GetValuesArray<int> is only used to obtain the generic method definition; it is re-instantiated
// for the actual item type and invoked through reflection. It also fills _repIsDefault[i].
Func<VBuffer<int>, VectorDataViewType, int, int[]> func = GetValuesArray<int>;
var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(savedVectorType.ItemType.RawType);
_repValues[i] = meth.Invoke(this, args);
}
else
// Scalar replacement value: stored as is, and _repIsDefault[i] stays null.
_repValues[i] = repValue;
Host.Assert(repValue.GetType() == _replaceTypes[i].RawType || repValue.GetType() == _replaceTypes[i].GetItemType().RawType);
}
}
// Densifies a deserialized per-slot replacement vector, records in _repIsDefault[iinfo] which
// slots hold the default value, and returns the values as a plain array.
private T[] GetValuesArray<T>(VBuffer<T> src, VectorDataViewType srcType, int iinfo)
{
    Host.Assert(srcType != null);
    Host.Assert(srcType.Size == src.Length);

    VBufferUtils.Densify<T>(ref src);
    InPredicate<T> isDefault = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate<T>(srcType.ItemType);

    var defaultFlags = new BitArray(srcType.Size);
    _repIsDefault[iinfo] = defaultFlags;

    var denseValues = src.GetValues();
    int slot = 0;
    while (slot < denseValues.Length)
    {
        if (isDefault(in denseValues[slot]))
            defaultFlags[slot] = true;
        slot++;
    }

    // copy the result array out. Copying is OK because this method is only called on model load.
    T[] copied = denseValues.ToArray();
    Host.Assert(copied.Length == src.Length);
    return copied;
}
/// <summary>
/// Fill the repValues array with the correct replacement values based on the user-given replacement kinds.
/// Vectors default to by-slot imputation unless otherwise specified, except for unknown sized vectors
/// which force across-slot imputation.
/// </summary>
/// <param name="input">The data used to look up column types and to compute statistics.</param>
/// <param name="columns">The per-column options; every output array is parallel to this array.</param>
/// <param name="repValues">Receives the replacement value (scalar or per-slot array) of every column.</param>
/// <param name="slotIsDefault">Receives, for columns with per-slot replacement values, which slots hold the default value;
/// entries of all other columns stay null.</param>
/// <param name="types">Receives the output type of every column.</param>
private void GetReplacementValues(IDataView input, MissingValueReplacingEstimator.ColumnOptions[] columns, out object[] repValues, out BitArray[] slotIsDefault, out DataViewType[] types)
{
    repValues = new object[columns.Length];
    slotIsDefault = new BitArray[columns.Length];
    types = new DataViewType[columns.Length];
    var sources = new int[columns.Length];
    ReplacementKind[] imputationModes = new ReplacementKind[columns.Length];
    List<int> columnsToImpute = null;
    // REVIEW: Would like to get rid of the sourceColumns list but seems to be the best way to provide
    // the cursor with what columns to cursor through.
    var sourceColumns = new List<DataViewSchema.Column>();
    for (int iinfo = 0; iinfo < columns.Length; iinfo++)
    {
        // Do not silently continue with column index 0 when the lookup fails.
        if (!input.Schema.TryGetColumnIndex(columns[iinfo].InputColumnName, out int colSrc))
            throw Host.ExceptSchemaMismatch(nameof(input), "input", columns[iinfo].InputColumnName);
        sources[iinfo] = colSrc;
        var type = input.Schema[colSrc].Type;
        if (type is VectorDataViewType vectorType)
            type = new VectorDataViewType(vectorType.ItemType, vectorType.Dimensions);
        Delegate isNa = GetIsNADelegate(type);
        types[iinfo] = type;
        var kind = (ReplacementKind)columns[iinfo].Replacement;
        switch (kind)
        {
            case ReplacementKind.SpecifiedValue:
                // Use the type computed above. Reading the _replaceTypes field here would only work
                // as long as the caller passes that very field as the 'types' out argument.
                repValues[iinfo] = GetSpecifiedValue(columns[iinfo].ReplacementString, type, isNa);
                break;
            case ReplacementKind.DefaultValue:
                repValues[iinfo] = GetDefault(type);
                break;
            case ReplacementKind.Mean:
            case ReplacementKind.Minimum:
            case ReplacementKind.Maximum:
            case ReplacementKind.Mode:
                if (!(type.GetItemType() is NumberDataViewType))
                    throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", type.GetItemType());
                imputationModes[iinfo] = kind;
                Utils.Add(ref columnsToImpute, iinfo);
                sourceColumns.Add(input.Schema[colSrc]);
                break;
            default:
                Host.Assert(false);
                throw Host.Except("Internal error, undefined ReplacementKind '{0}' assigned in NAReplaceTransform.", columns[iinfo].Replacement);
        }
    }

    // Exit if there are no columns needing a replacement value imputed.
    if (Utils.Size(columnsToImpute) == 0)
        return;

    // Impute values.
    using (var ch = Host.Start("Computing Statistics"))
    using (var cursor = input.GetRowCursor(sourceColumns))
    {
        StatAggregator[] statAggregators = new StatAggregator[columnsToImpute.Count];
        for (int ii = 0; ii < columnsToImpute.Count; ii++)
        {
            // 'ii' indexes the subset of columns that need imputation, 'iinfo' indexes 'columns'.
            // Everything that is parallel to 'columns' has to be indexed with 'iinfo'.
            int iinfo = columnsToImpute[ii];
            bool bySlot = columns[iinfo].ImputeBySlot;
            if (types[iinfo] is VectorDataViewType vectorType && !vectorType.IsKnownSize && bySlot)
            {
                ch.Warning("By-slot imputation can not be done on variable-length column");
                bySlot = false;
            }
            statAggregators[ii] = CreateStatAggregator(ch, types[iinfo], imputationModes[iinfo], bySlot,
                cursor, sources[iinfo]);
        }

        while (cursor.MoveNext())
        {
            for (int ii = 0; ii < statAggregators.Length; ii++)
                statAggregators[ii].ProcessRow();
        }

        for (int ii = 0; ii < statAggregators.Length; ii++)
            repValues[columnsToImpute[ii]] = statAggregators[ii].GetStat();
    }

    // Construct the slotIsDefault bit arrays.
    for (int ii = 0; ii < columnsToImpute.Count; ii++)
    {
        int slot = columnsToImpute[ii];
        if (repValues[slot] is Array)
        {
            slotIsDefault[slot] = Utils.MarshalInvoke(_computeDefaultSlotsMethodInfo, this, types[slot].GetItemType().RawType, types[slot], (Array)repValues[slot]);
        }
    }
}
// Builds a bit array, parallel to 'values', in which a set bit marks a slot whose replacement
// value is the default value of the item type.
private BitArray ComputeDefaultSlots<T>(DataViewType type, Array values)
{
    int slotCount = values.Length;
    Host.Assert(slotCount == type.GetVectorSize());

    var flags = new BitArray(slotCount);
    InPredicate<T> isDefault = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate<T>(type.GetItemType());
    var typed = (T[])values;

    int slot = 0;
    while (slot < slotCount)
    {
        if (isDefault(in typed[slot]))
            flags[slot] = true;
        slot++;
    }
    return flags;
}
// Returns the (boxed) default value of the item type: a fresh instance for value types,
// null for reference types.
private object GetDefault(DataViewType type)
{
    Type itemRawType = type.GetItemType().RawType;
    return itemRawType.IsValueType ? Activator.CreateInstance(itemRawType) : null;
}
/// <summary>
/// Looks up the isNA predicate for the item type of <paramref name="type"/> by dispatching to the
/// generic overload for the item's raw type.
/// </summary>
private Delegate GetIsNADelegate(DataViewType type)
    => Utils.MarshalInvoke(_getIsNADelegateMethodInfo, this, type.GetItemType().RawType, type);
// Typed worker behind GetIsNADelegate(DataViewType).
private Delegate GetIsNADelegate<T>(DataViewType type)
{
    var itemType = type.GetItemType();
    return Data.Conversion.Conversions.DefaultInstance.GetIsNAPredicate<T>(itemType);
}
/// <summary>
/// Converts a string to its respective value in the corresponding type, by invoking the generic
/// overload for the item's raw type through reflection.
/// </summary>
private object GetSpecifiedValue(string srcStr, DataViewType dstType, Delegate isNA)
{
    // The <int> instantiation only serves to get hold of the generic method definition.
    Func<string, DataViewType, InPredicate<int>, object> template = GetSpecifiedValue<int>;
    MethodInfo definition = template.GetMethodInfo().GetGenericMethodDefinition();
    MethodInfo typedMethod = definition.MakeGenericMethod(dstType.GetItemType().RawType);

    var arguments = new object[] { srcStr, dstType, isNA };
    return typedMethod.Invoke(this, arguments);
}
// Typed worker behind GetSpecifiedValue. A null or empty string yields default(T); any other string
// is converted to T, and a conversion whose result is NA is reported as an error.
private object GetSpecifiedValue<T>(string srcStr, DataViewType dstType, InPredicate<T> isNA)
{
    T parsed = default(T);
    if (string.IsNullOrEmpty(srcStr))
        return parsed;

    // Handles converting input strings to correct types.
    ReadOnlyMemory<char> text = srcStr.AsMemory();
    var convert = Data.Conversion.Conversions.DefaultInstance.GetStandardConversion<ReadOnlyMemory<char>, T>(TextDataViewType.Instance, dstType.GetItemType(), out bool isIdentity);
    convert(in text, ref parsed);

    // Make sure that the text can legitimately be converted to dstType, throw error otherwise.
    if (isNA(in parsed))
        throw Contracts.Except("No conversion of '{0}' to '{1}'", srcStr, dstType.GetItemType());

    return parsed;
}
// Factory method for SignatureDataTransform.
// Translates the command-line style Options into estimator column options (a column level Kind / Slot
// takes precedence over the transform level setting), trains the transformer on 'input' and wraps it
// into a data transform over 'input'.
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(options, nameof(options));
    env.CheckValue(input, nameof(input));
    env.CheckValue(options.Columns, nameof(options.Columns));

    var cols = new MissingValueReplacingEstimator.ColumnOptions[options.Columns.Length];
    for (int i = 0; i < cols.Length; i++)
    {
        var item = options.Columns[i];
        var kind = item.Kind ?? options.ReplacementKind;
        // The message used to talk about a "sorting criteria", which has nothing to do with this transform.
        if (!Enum.IsDefined(typeof(ReplacementKind), kind))
            throw env.ExceptUserArg(nameof(options.ReplacementKind), "Undefined replacement kind '{0}' detected for column '{1}'", kind, item.Name);

        // Pass on the kind that was just validated instead of computing it a second time.
        cols[i] = new MissingValueReplacingEstimator.ColumnOptions(
            item.Name,
            item.Source,
            (MissingValueReplacingEstimator.ReplacementMode)kind,
            item.Slot ?? options.ImputeBySlot,
            item.ReplacementString);
    }
    return new MissingValueReplacingTransformer(env, input, cols).MakeDataTransform(input);
}
// Trains a transformer for the given columns on 'input' and wraps it into a data transform over 'input'.
internal static IDataTransform Create(IHostEnvironment env, IDataView input, params MissingValueReplacingEstimator.ColumnOptions[] columns)
{
    var transformer = new MissingValueReplacingTransformer(env, input, columns);
    return transformer.MakeDataTransform(input);
}
// Factory method for SignatureLoadModel.
// Verifies that the context is positioned at a model of this transform and deserializes it.
private static MissingValueReplacingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    IHost loaderHost = env.Register(LoadName);

    loaderHost.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    var loaded = new MissingValueReplacingTransformer(loaderHost, ctx);
    return loaded;
}
// Factory method for SignatureLoadDataTransform.
// Loads the transformer from the model context and applies it to 'input'.
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
    MissingValueReplacingTransformer loaded = Create(env, ctx);
    return loaded.MakeDataTransform(input);
}
// Factory method for SignatureLoadRowMapper.
// Loads the transformer from the model context and creates its row mapper for 'inputSchema'.
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
{
    MissingValueReplacingTransformer loaded = Create(env, ctx);
    return loaded.MakeRowMapper(inputSchema);
}
// Wraps an array into a VBuffer of the same length (no index array is passed); used when saving
// per-slot replacement values.
private VBuffer<T> CreateVBuffer<T>(T[] array)
{
    Host.AssertValue(array);

    int length = array.Length;
    var buffer = new VBuffer<T>(length, array);
    return buffer;
}
// Serializes one (type, value) record through the binary saver; fails when the saver cannot
// write values of the given type.
private void WriteTypeAndValue<T>(Stream stream, BinarySaver saver, DataViewType type, T rep)
{
    Host.AssertValue(stream);
    Host.AssertValue(saver);
    Host.Assert(type.RawType == typeof(T) || type.GetItemType().RawType == typeof(T));

    bool written = saver.TryWriteTypeAndValue<T>(stream, type, ref rep, out int byteCount);
    if (written)
        return;

    throw Host.Except("We do not know how to serialize terms of type '{0}'", type);
}
// Serializes the transform: version info, the column pairs, and then one (type, value) record
// per column, in column order. The load constructor reads the records back in the same order.
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
SaveColumns(ctx);
var saver = new BinarySaver(Host, new BinarySaver.Arguments());
for (int iinfo = 0; iinfo < _replaceTypes.Length; iinfo++)
{
var repValue = _repValues[iinfo];
// A single replacement value is written with the item type...
var repType = _replaceTypes[iinfo].GetItemType();
if (_repIsDefault[iinfo] != null)
{
// ...whereas per-slot replacement values are wrapped into a VBuffer and written with the
// full (vector) type. CreateVBuffer<int> only provides the generic method definition.
Host.Assert(repValue is Array);
Func<int[], VBuffer<int>> function = CreateVBuffer<int>;
var method = function.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repType.RawType);
repValue = method.Invoke(this, new object[] { _repValues[iinfo] });
repType = _replaceTypes[iinfo];
}
Host.Assert(!(repValue is Array));
object[] args = new object[] { ctx.Writer.BaseStream, saver, repType, repValue };
// WriteTypeAndValue<int> only provides the generic method definition; it is re-instantiated
// for the runtime type of the value and invoked through reflection.
Action<Stream, BinarySaver, DataViewType, int> func = WriteTypeAndValue<int>;
Host.Assert(repValue.GetType() == _replaceTypes[iinfo].RawType || repValue.GetType() == _replaceTypes[iinfo].GetItemType().RawType);
var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repValue.GetType());
meth.Invoke(this, args);
}
}
// Creates the row mapper that applies the replacement values to rows of the given schema.
private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
{
    return new Mapper(this, schema);
}
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
// Immutable description of one mapped column: output name, input name and input type.
private sealed class ColInfo
{
    // Name of the output column.
    public readonly string Name;
    // Name of the input column.
    public readonly string InputColumnName;
    // Type of the input column, as found in the input schema.
    public readonly DataViewType TypeSrc;

    public ColInfo(string outputColumnName, string inputColumnName, DataViewType type)
    {
        TypeSrc = type;
        InputColumnName = inputColumnName;
        Name = outputColumnName;
    }
}
// Cached method-info wrappers for the generic getter factories; they are handed to
// Utils.MarshalInvoke together with the runtime (item) type of the column.
private static readonly FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delegate> _composeGetterOneMethodInfo
= FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delegate>.Create(target => target.ComposeGetterOne<int>);
private static readonly FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delegate> _composeGetterVecMethodInfo
= FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delegate>.Create(target => target.ComposeGetterVec<int>);
// The transformer that owns the replacement values used by this mapper.
private readonly MissingValueReplacingTransformer _parent;
// Column descriptions, parallel to the parent's ColumnPairs.
private readonly ColInfo[] _infos;
// The output column types, parallel to Infos.
private readonly DataViewType[] _types;
// The isNA delegates, parallel to Infos.
private readonly Delegate[] _isNAs;
// Always reports that ONNX export can be attempted; columns that cannot be exported are
// dropped individually in SaveAsOnnx.
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
// Binds the parent's replacement values to a concrete input schema. For every column the
// input type is checked against the type of the stored replacement value, and the output type
// and isNA predicate are cached.
public Mapper(MissingValueReplacingTransformer parent, DataViewSchema inputSchema)
    : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
{
    _parent = parent;
    _infos = CreateInfos(inputSchema);
    _types = new DataViewType[_parent.ColumnPairs.Length];
    _isNAs = new Delegate[_parent.ColumnPairs.Length];

    for (int i = 0; i < _parent.ColumnPairs.Length; i++)
    {
        ColInfo info = _infos[i];
        DataViewType type = info.TypeSrc;

        // Vector inputs get a fresh vector type built from the same item type and dimensions.
        VectorDataViewType vectorType = null;
        if (type is VectorDataViewType sourceVector)
        {
            vectorType = new VectorDataViewType(sourceVector.ItemType, sourceVector.Dimensions);
            type = vectorType;
        }

        // Per-slot replacement values are described by the full stored type, a single
        // replacement value by the stored item type.
        DataViewType repType;
        if (_parent._repIsDefault[i] == null)
            repType = _parent._replaceTypes[i].GetItemType();
        else
            repType = _parent._replaceTypes[i];

        if (!type.GetItemType().Equals(repType.GetItemType()))
        {
            throw Host.ExceptParam(nameof(InputSchema), "Column '{0}' item type '{1}' does not match expected ColumnType of '{2}'",
                info.InputColumnName, _parent._replaceTypes[i].GetItemType().ToString(), info.TypeSrc);
        }

        // If type is a vector and the value is not either a scalar or a vector of the same size, throw an error.
        if (repType is VectorDataViewType repVectorType)
        {
            if (vectorType == null)
            {
                throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' item type '{1}' cannot be a vector when Columntype is a scalar of type '{2}'",
                    info.InputColumnName, repType, type);
            }
            if (!vectorType.IsKnownSize)
            {
                throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' is unknown size vector '{1}' must be a scalar instead of type '{2}'", info.InputColumnName, type, parent._replaceTypes[i]);
            }
            if (vectorType.Size != repVectorType.Size)
            {
                throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' item type '{1}' must be a scalar or a vector of the same size as Columntype '{2}'",
                    info.InputColumnName, repType, type);
            }
        }

        _types[i] = type;
        _isNAs[i] = _parent.GetIsNADelegate(type);
    }
}
// Resolves every input column of the parent in the given schema, validates its type and
// returns the resulting column descriptions (parallel to the parent's ColumnPairs).
private ColInfo[] CreateInfos(DataViewSchema inputSchema)
{
    Host.AssertValue(inputSchema);

    int count = _parent.ColumnPairs.Length;
    var result = new ColInfo[count];
    for (int index = 0; index < count; index++)
    {
        var pair = _parent.ColumnPairs[index];
        bool found = inputSchema.TryGetColumnIndex(pair.inputColumnName, out int sourceIndex);
        if (!found)
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", pair.inputColumnName);

        _parent.CheckInputColumn(inputSchema, index, sourceIndex);
        result[index] = new ColInfo(pair.outputColumnName, pair.inputColumnName, inputSchema[sourceIndex].Type);
    }
    return result;
}
// Describes the output columns: each one has the cached output type and carries over the
// SlotNames and IsNormalized annotations of its input column.
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
    int count = _parent.ColumnPairs.Length;
    var columns = new DataViewSchema.DetachedColumn[count];
    for (int index = 0; index < count; index++)
    {
        var pair = _parent.ColumnPairs[index];
        InputSchema.TryGetColumnIndex(pair.inputColumnName, out int sourceIndex);
        Host.Assert(sourceIndex >= 0);

        var annotations = new DataViewSchema.Annotations.Builder();
        annotations.Add(InputSchema[sourceIndex].Annotations, x => x == AnnotationUtils.Kinds.SlotNames || x == AnnotationUtils.Kinds.IsNormalized);

        columns[index] = new DataViewSchema.DetachedColumn(pair.outputColumnName, _types[index], annotations.ToAnnotations());
    }
    return columns;
}
// Creates the getter of output column 'iinfo': the vector variant for vector inputs, the scalar
// variant otherwise. Nothing needs to be disposed.
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
{
    Host.AssertValue(input);
    Host.Assert(0 <= iinfo && iinfo < _infos.Length);
    disposer = null;

    bool isVector = _infos[iinfo].TypeSrc is VectorDataViewType;
    return isVector
        ? ComposeGetterVec(input, iinfo)
        : ComposeGetterOne(input, iinfo);
}
/// <summary>
/// Getter generator for single valued inputs; dispatches to the generic overload for the raw
/// type of the input column.
/// </summary>
private Delegate ComposeGetterOne(DataViewRow input, int iinfo)
{
    Type rawType = _infos[iinfo].TypeSrc.RawType;
    return Utils.MarshalInvoke(_composeGetterOneMethodInfo, this, rawType, input, iinfo);
}
/// <summary>
/// Replaces NA values for scalars: the produced getter yields the replacement value whenever the
/// source value is NA, and the source value otherwise.
/// </summary>
private Delegate ComposeGetterOne<T>(DataViewRow input, int iinfo)
{
    var sourceColumn = input.Schema[ColMapNewToOld[iinfo]];
    var getSrc = input.GetGetter<T>(sourceColumn);
    // Scratch value, captured by the getter and reused for every row.
    T current = default(T);
    var isNA = (InPredicate<T>)_isNAs[iinfo];

    Host.Assert(_parent._repValues[iinfo] is T);
    T replacement = (T)_parent._repValues[iinfo];

    ValueGetter<T> getter = (ref T dst) =>
    {
        getSrc(ref current);
        if (isNA(in current))
            dst = replacement;
        else
            dst = current;
    };
    return getter;
}
/// <summary>
/// Getter generator for vector valued inputs; dispatches to the generic overload for the raw
/// item type of the input column.
/// </summary>
private Delegate ComposeGetterVec(DataViewRow input, int iinfo)
{
    Type itemRawType = _infos[iinfo].TypeSrc.GetItemType().RawType;
    return Utils.MarshalInvoke(_composeGetterVecMethodInfo, this, itemRawType, input, iinfo);
}
/// <summary>
/// Replaces NA values for vectors. Depending on whether the parent holds a single replacement
/// value or per-slot replacement values, the matching FillValues overload is used.
/// </summary>
private Delegate ComposeGetterVec<T>(DataViewRow input, int iinfo)
{
    var sourceColumn = input.Schema[ColMapNewToOld[iinfo]];
    var getSrc = input.GetGetter<VBuffer<T>>(sourceColumn);
    var isNA = (InPredicate<T>)_isNAs[iinfo];
    var isDefault = Data.Conversion.Conversions.DefaultInstance.GetIsDefaultPredicate<T>(_infos[iinfo].TypeSrc.GetItemType());
    // Scratch buffer, captured by the getter and reused for every row.
    var buffer = default(VBuffer<T>);
    ValueGetter<VBuffer<T>> getter;

    bool perSlot = _parent._repIsDefault[iinfo] != null;
    if (perSlot)
    {
        // Replacement values by slot.
        Host.Assert(_parent._repValues[iinfo] is T[]);
        T[] slotReplacements = (T[])_parent._repValues[iinfo];

        getter = (ref VBuffer<T> dst) =>
        {
            getSrc(ref buffer);
            Host.Check(buffer.Length == slotReplacements.Length);
            FillValues(in buffer, ref dst, isNA, slotReplacements, _parent._repIsDefault[iinfo]);
        };
        return getter;
    }

    // One replacement value for all slots.
    Host.Assert(_parent._repValues[iinfo] is T);
    T replacement = (T)_parent._repValues[iinfo];
    bool replacementIsDefault = isDefault(in replacement);

    getter = (ref VBuffer<T> dst) =>
    {
        getSrc(ref buffer);
        FillValues(in buffer, ref dst, isNA, replacement, replacementIsDefault);
    };
    return getter;
}
/// <summary>
/// Fills values for vectors where there is one replacement value.
/// Dense inputs produce dense outputs. For sparse inputs only the explicitly stored entries are
/// inspected; an NA entry is dropped from the output when the replacement value is the default value.
/// </summary>
private void FillValues<T>(in VBuffer<T> src, ref VBuffer<T> dst, InPredicate<T> isNA, T rep, bool repIsDefault)
{
Host.AssertValue(isNA);
int srcSize = src.Length;
var srcValues = src.GetValues();
int srcCount = srcValues.Length;
// REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays
// does is over-allocate space if the replacement value is the default value in a dataset with a
// significant amount of NA values -- is it worth handling allocation of memory for this case?
var dstEditor = VBufferEditor.Create(ref dst, srcSize, srcCount);
// Number of entries written to the destination so far.
int iivDst = 0;
if (src.IsDense)
{
// The source vector is dense.
Host.Assert(srcSize == srcCount);
for (int ivSrc = 0; ivSrc < srcCount; ivSrc++)
{
var srcVal = srcValues[ivSrc];
// The output for dense inputs is always dense.
// Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with
// the default value, resulting in more than half of the indices being the default value.
// In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision
// is it is not worth handling this case at the expense of running checks that will almost always not be triggered.
dstEditor.Values[ivSrc] = isNA(in srcVal) ? rep : srcVal;
}
iivDst = srcCount;
}
else
{
// The source vector is sparse.
Host.Assert(srcCount < srcSize);
var srcIndices = src.GetIndices();
// Note: ivPrev is only used for asserts.
int ivPrev = -1;
for (int iivSrc = 0; iivSrc < srcCount; iivSrc++)
{
Host.Assert(iivDst <= iivSrc);
var srcVal = srcValues[iivSrc];
int iv = srcIndices[iivSrc];
Host.Assert(ivPrev < iv && iv < srcSize);
ivPrev = iv;
if (!isNA(in srcVal))
{
// Regular value: copied together with its index.
dstEditor.Values[iivDst] = srcVal;
dstEditor.Indices[iivDst++] = iv;
}
else if (!repIsDefault)
{
// Allow for further sparsification.
dstEditor.Values[iivDst] = rep;
dstEditor.Indices[iivDst++] = iv;
}
}
Host.Assert(iivDst <= srcCount);
}
Host.Assert(0 <= iivDst);
Host.Assert(repIsDefault || iivDst == srcCount);
// Commit only the entries that were actually written.
dst = dstEditor.CommitTruncated(iivDst);
}
/// <summary>
/// Fills values for vectors where there is slot-wise replacement values.
/// Dense inputs produce dense outputs. For sparse inputs only the explicitly stored entries are
/// inspected; an NA entry is dropped from the output when the replacement value of its slot is
/// marked as default in <paramref name="repIsDefault"/>.
/// </summary>
private void FillValues<T>(in VBuffer<T> src, ref VBuffer<T> dst, InPredicate<T> isNA, T[] rep, BitArray repIsDefault)
{
Host.AssertValue(rep);
Host.Assert(rep.Length == src.Length);
Host.AssertValue(repIsDefault);
Host.Assert(repIsDefault.Length == src.Length);
Host.AssertValue(isNA);
int srcSize = src.Length;
var srcValues = src.GetValues();
int srcCount = srcValues.Length;
// REVIEW: One thing that changing the code to simply ensure that there are srcCount indices in the arrays
// does is over-allocate space if the replacement value is the default value in a dataset with a
// significant amount of NA values -- is it worth handling allocation of memory for this case?
var dstEditor = VBufferEditor.Create(ref dst, srcSize, srcCount);
// Number of entries written to the destination so far.
int iivDst = 0;
if (src.IsDense)
{
// The source vector is dense.
Host.Assert(srcSize == srcCount);
for (int ivSrc = 0; ivSrc < srcCount; ivSrc++)
{
var srcVal = srcValues[ivSrc];
// The output for dense inputs is always dense.
// Note: Theoretically, one could imagine a dataset with NA values that one wished to replace with
// the default value, resulting in more than half of the indices being the default value.
// In this case, changing the dst vector to be sparse would be more memory efficient -- the current decision
// is it is not worth handling this case at the expense of running checks that will almost always not be triggered.
dstEditor.Values[ivSrc] = isNA(in srcVal) ? rep[ivSrc] : srcVal;
}
iivDst = srcCount;
}
else
{
// The source vector is sparse.
Host.Assert(srcCount < srcSize);
var srcIndices = src.GetIndices();
// Note: ivPrev is only used for asserts.
int ivPrev = -1;
for (int iivSrc = 0; iivSrc < srcCount; iivSrc++)
{
Host.Assert(iivDst <= iivSrc);
var srcVal = srcValues[iivSrc];
int iv = srcIndices[iivSrc];
Host.Assert(ivPrev < iv && iv < srcSize);
ivPrev = iv;
if (!isNA(in srcVal))
{
// Regular value: copied together with its index.
dstEditor.Values[iivDst] = srcVal;
dstEditor.Indices[iivDst++] = iv;
}
else if (!repIsDefault[iv])
{
// Allow for further sparsification.
// Replacement values are looked up by slot index (iv), not by the position in the sparse arrays.
dstEditor.Values[iivDst] = rep[iv];
dstEditor.Indices[iivDst++] = iv;
}
}
Host.Assert(iivDst <= srcCount);
}
Host.Assert(0 <= iivDst);
// Commit only the entries that were actually written.
dst = dstEditor.CommitTruncated(iivDst);
}
/// <summary>
/// Exports every column handled by this mapper to the given ONNX context.
/// A column whose input is not present in the context, or that
/// <see cref="SaveAsOnnxCore"/> declines to export, is removed from the context.
/// </summary>
/// <param name="ctx">The ONNX context that receives the generated nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));

    for (int i = 0; i < _infos.Length; i++)
    {
        ColInfo column = _infos[i];
        string sourceName = column.InputColumnName;

        if (ctx.ContainsColumn(sourceName))
        {
            // Resolve the source variable first, then register the output variable,
            // in the same order the arguments were previously evaluated.
            string sourceVariable = ctx.GetVariableName(sourceName);
            string targetVariable = ctx.AddIntermediateVariable(_parent._replaceTypes[i], column.Name);

            bool exported = SaveAsOnnxCore(ctx, i, column, sourceVariable, targetVariable);
            if (!exported)
            {
                ctx.RemoveColumn(column.Name, true);
            }
        }
        else
        {
            // The input column is unknown to the context, so the output cannot be produced.
            ctx.RemoveColumn(column.Name, false);
        }
    }
}
/// <summary>
/// Adds an ONNX "Imputer" node for the column at index <paramref name="iinfo"/>.
/// </summary>
/// <param name="ctx">The ONNX context in which the node is created.</param>
/// <param name="iinfo">Index of the column in the mapper's column infos.</param>
/// <param name="info">Column info supplied by the caller (not read by this method).</param>
/// <param name="srcVariableName">Name of the ONNX variable that feeds the node.</param>
/// <param name="dstVariableName">Name of the ONNX variable the node writes to.</param>
/// <returns>
/// <see langword="true"/> when a node was emitted; <see langword="false"/> when the
/// column's (item) raw type is not <see cref="float"/>, in which case no node is created.
/// </returns>
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoadName);

    // For vectors the element type decides exportability; for scalars the type itself does.
    var sourceType = _infos[iinfo].TypeSrc;
    Type elementRawType = sourceType is VectorDataViewType asVector
        ? asVector.ItemType.RawType
        : sourceType.RawType;

    if (elementRawType != typeof(float))
    {
        return false;
    }

    string opType = "Imputer";
    var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
    node.AddAttribute("replaced_value_float", float.NaN);

    // A vector column with a non-null entry in _repIsDefault stores its replacement values
    // as a float array; every other case stores a single float, emitted as a one-element sequence.
    bool hasArrayReplacement = sourceType is VectorDataViewType && _parent._repIsDefault[iinfo] != null;
    if (hasArrayReplacement)
    {
        node.AddAttribute("imputed_value_floats", (float[])_parent._repValues[iinfo]);
    }
    else
    {
        node.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_parent._repValues[iinfo], 1));
    }

    return true;
}
}
}
/// <summary>
/// <see cref="IEstimator{TTransformer}"/> for the <see cref="MissingValueReplacingTransformer"/>.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
///
/// ###  Estimator Characteristics
/// |  |  |
/// | -- | -- |
/// | Does this estimator need to look at the data to train its parameters? | Yes |
/// | Input column data type | Vector or scalar of <xref:System.Single> or <xref:System.Double> |
/// | Output column data type | The same as the data type in the input column |
/// | Exportable to ONNX | Yes |
///
/// The resulting <xref:Microsoft.ML.Transforms.MissingValueReplacingTransformer> creates a new column, named as specified in the output column name parameters, and
/// copies the data from the input column to this new column, with the exception that missing values in the data are replaced according to the chosen strategy.
///
/// Check the See Also section for links to usage examples.
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="ExtensionsCatalog.ReplaceMissingValues(TransformsCatalog, string, string, ReplacementMode, bool)" />
/// <seealso cref="ExtensionsCatalog.ReplaceMissingValues(TransformsCatalog, InputOutputColumnPair[], ReplacementMode, bool)" />
public sealed class MissingValueReplacingEstimator : IEstimator<MissingValueReplacingTransformer>
{
    /// <summary>
    /// The possible ways to replace missing values.
    /// </summary>
    public enum ReplacementMode : byte
    {
        /// <summary>
        /// Replace with the default value of the column based on its type.
        /// </summary>
        DefaultValue = 0,
        /// <summary>
        /// Replace with the mean value of the column.
        /// </summary>
        Mean = 1,
        /// <summary>
        /// Replace with the minimum value of the column.
        /// </summary>
        Minimum = 2,
        /// <summary>
        /// Replace with the maximum value of the column.
        /// </summary>
        Maximum = 3,
        // NOTE(review): the value 4 is skipped; presumably it is reserved to stay aligned with
        // another enum's numbering -- confirm before assigning it to a new member.
        /// <summary>
        /// Replace with the most frequent value of the column.
        /// </summary>
        Mode = 5
    }

    /// <summary>
    /// Default settings used when a caller does not specify them explicitly.
    /// </summary>
    [BestFriend]
    internal static class Defaults
    {
        public const ReplacementMode Mode = ReplacementMode.DefaultValue;
        public const bool ImputeBySlot = true;
    }

    /// <summary>
    /// Describes how the transformer handles one column pair.
    /// </summary>
    [BestFriend]
    internal sealed class ColumnOptions
    {
        /// <summary> Name of the column resulting from the transformation of <see cref="InputColumnName"/>.</summary>
        public readonly string Name;
        /// <summary> Name of column to transform. </summary>
        public readonly string InputColumnName;
        /// <summary>
        /// If true, per-slot imputation of replacement is performed.
        /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors,
        /// where imputation is always for the entire column.
        /// </summary>
        public readonly bool ImputeBySlot;
        /// <summary> How to replace the missing values.</summary>
        public readonly ReplacementMode Replacement;
        /// <summary> Replacement value for missing values (only used in the entry point and command line API).</summary>
        internal readonly string ReplacementString;

        /// <summary>
        /// Describes how the transformer handles one column pair.
        /// </summary>
        /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
        /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param>
        /// <param name="replacementMode">How to replace the missing values.</param>
        /// <param name="imputeBySlot">If true, per-slot imputation of replacement is performed.
        /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors,
        /// where imputation is always for the entire column.</param>
        public ColumnOptions(string name, string inputColumnName = null, ReplacementMode replacementMode = Defaults.Mode,
            bool imputeBySlot = Defaults.ImputeBySlot)
        {
            Contracts.CheckNonWhiteSpace(name, nameof(name));
            Name = name;
            InputColumnName = inputColumnName ?? name;
            ImputeBySlot = imputeBySlot;
            Replacement = replacementMode;
        }

        /// <summary>
        /// This constructor is used internally to convert from <see cref="MissingValueReplacingTransformer.Options"/> to <see cref="ColumnOptions"/>
        /// as we support <paramref name="replacementString"/> in command line and entrypoint API only.
        /// </summary>
        internal ColumnOptions(string name, string inputColumnName, ReplacementMode replacementMode, bool imputeBySlot, string replacementString)
            : this(name, inputColumnName, replacementMode, imputeBySlot)
        {
            ReplacementString = replacementString;
        }
    }

    private readonly IHost _host;
    private readonly ColumnOptions[] _columns;

    /// <summary>
    /// Creates an estimator for a single column pair.
    /// </summary>
    /// <param name="env">The host environment.</param>
    /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
    /// <param name="inputColumnName">Name of the column to transform. If <see langword="null"/>, <paramref name="outputColumnName"/> is used as the source.</param>
    /// <param name="replacementKind">How to replace the missing values.</param>
    internal MissingValueReplacingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, ReplacementMode replacementKind = Defaults.Mode)
        : this(env, new ColumnOptions(outputColumnName, inputColumnName ?? outputColumnName, replacementKind))
    {
    }

    /// <summary>
    /// Creates an estimator for the given column pairs.
    /// </summary>
    /// <param name="env">The host environment.</param>
    /// <param name="columns">The column pairs to transform. Must not be <see langword="null"/>.</param>
    [BestFriend]
    internal MissingValueReplacingEstimator(IHostEnvironment env, params ColumnOptions[] columns)
    {
        Contracts.CheckValue(env, nameof(env));
        _host = env.Register(nameof(MissingValueReplacingEstimator));
        // Fail fast on an explicit null array instead of surfacing a NullReferenceException
        // later when the columns are enumerated.
        _host.CheckValue(columns, nameof(columns));
        _columns = columns;
    }

    /// <summary>
    /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
    /// Used for schema propagation and verification in a pipeline.
    /// </summary>
    /// <param name="inputSchema">The shape of the input schema. Must contain every configured input column.</param>
    /// <returns>The input shape with one output column added (or replaced) per configured column pair.</returns>
    public SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        _host.CheckValue(inputSchema, nameof(inputSchema));
        var result = inputSchema.ToDictionary(x => x.Name);
        foreach (var colInfo in _columns)
        {
            if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);

            // TestType yields a non-null explanation when the item type is not supported.
            string reason = MissingValueReplacingTransformer.TestType(col.ItemType);
            if (reason != null)
                throw _host.ExceptParam(nameof(inputSchema), reason);

            // Only the slot-names and is-normalized annotations are carried over to the output column.
            var metadata = new List<SchemaShape.Column>();
            if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
                metadata.Add(slotMeta);
            if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.IsNormalized, out var normalized))
                metadata.Add(normalized);

            var type = !(col.ItemType is VectorDataViewType vectorType) ?
                col.ItemType :
                new VectorDataViewType(vectorType.ItemType, vectorType.Dimensions);
            result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, col.Kind, type, false, new SchemaShape(metadata.ToArray()));
        }
        return new SchemaShape(result.Values);
    }

    /// <summary>
    /// Trains and returns a <see cref="MissingValueReplacingTransformer"/>.
    /// </summary>
    /// <param name="input">The data the transformer is fitted on.</param>
    public MissingValueReplacingTransformer Fit(IDataView input) => new MissingValueReplacingTransformer(_host, input, _columns);
}
}
|