|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(MissingValueIndicatorTransformer.Summary, typeof(IDataTransform), typeof(MissingValueIndicatorTransformer), typeof(MissingValueIndicatorTransformer.Options), typeof(SignatureDataTransform),
MissingValueIndicatorTransformer.FriendlyName, MissingValueIndicatorTransformer.LoadName, "NAIndicator", MissingValueIndicatorTransformer.ShortName, DocName = "transform/NAHandle.md")]
[assembly: LoadableClass(MissingValueIndicatorTransformer.Summary, typeof(IDataTransform), typeof(MissingValueIndicatorTransformer), null, typeof(SignatureLoadDataTransform),
MissingValueIndicatorTransformer.FriendlyName, MissingValueIndicatorTransformer.LoadName)]
[assembly: LoadableClass(MissingValueIndicatorTransformer.Summary, typeof(MissingValueIndicatorTransformer), null, typeof(SignatureLoadModel),
MissingValueIndicatorTransformer.FriendlyName, MissingValueIndicatorTransformer.LoadName)]
[assembly: LoadableClass(typeof(IRowMapper), typeof(MissingValueIndicatorTransformer), null, typeof(SignatureLoadRowMapper),
MissingValueIndicatorTransformer.FriendlyName, MissingValueIndicatorTransformer.LoadName)]
namespace Microsoft.ML.Transforms
{
/// <summary>
/// <see cref="ITransformer"/> resulting from fitting a <see cref="MissingValueIndicatorEstimator"/>.
/// </summary>
public sealed class MissingValueIndicatorTransformer : OneToOneTransformerBase
{
/// <summary>
/// Describes one output/input column pair in its command-line form.
/// </summary>
internal sealed class Column : OneToOneColumn
{
    /// <summary>
    /// Builds a <see cref="Column"/> from its textual representation.
    /// </summary>
    /// <param name="str">The text to parse; asserted to be non-empty.</param>
    /// <returns>The parsed column, or <see langword="null"/> when <c>TryParse</c> rejects the text.</returns>
    internal static Column Parse(string str)
    {
        Contracts.AssertNonEmpty(str);

        var parsed = new Column();
        return parsed.TryParse(str) ? parsed : null;
    }

    /// <summary>
    /// Writes the textual representation of this column into <paramref name="sb"/>.
    /// </summary>
    /// <param name="sb">The destination builder; asserted to be non-null.</param>
    /// <returns>Whatever <c>TryUnparseCore</c> reports.</returns>
    internal bool TryUnparse(StringBuilder sb)
    {
        Contracts.AssertValue(sb);

        bool succeeded = TryUnparseCore(sb);
        return succeeded;
    }
}
/// <summary>
/// Arguments consumed by the <see cref="SignatureDataTransform"/> factory method of
/// <see cref="MissingValueIndicatorTransformer"/>.
/// </summary>
internal sealed class Options : TransformInputBase
{
/// <summary>
/// The column definitions to process. Each entry names an output column and, optionally, its source column.
/// </summary>
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
}
/// <summary>
/// Loader signature used both in the loadable-class registrations and in the saved model's version information.
/// </summary>
internal const string LoadName = "NaIndicatorTransform";

/// <summary>
/// Builds the version information that is written on save and verified on load.
/// </summary>
private static VersionInfo GetVersionInfo()
    => new VersionInfo(
        modelSignature: "NAIND TF",
        verWrittenCur: 0x00010001, // Initial
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
        loaderSignature: LoadName,
        loaderAssemblyName: typeof(MissingValueIndicatorTransformer).Assembly.FullName);
// Description of the transform; passed to the LoadableClass registrations at the top of the file.
internal const string Summary = "Create a boolean output column with the same number of slots as the input column, where the output value"
+ " is true if the value in the input column is missing.";
// User-facing name passed to the LoadableClass registrations.
internal const string FriendlyName = "NA Indicator Transform";
// Short alias passed to the SignatureDataTransform registration.
internal const string ShortName = "NAInd";
// NOTE(review): not referenced in the visible code; the constructors call nameof(...) directly.
private const string RegistrationName = nameof(MissingValueIndicatorTransformer);
/// <summary>
/// The names of the output and input column pairs for the transformation,
/// exposed as a read-only view over <c>ColumnPairs</c>.
/// </summary>
internal IReadOnlyList<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();
/// <summary>
/// Initializes a new instance of <see cref="MissingValueIndicatorTransformer"/>
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="columns">The names of the input columns of the transformation and the corresponding names for the output columns.</param>
internal MissingValueIndicatorTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), columns)
{
}
/// <summary>
/// Initializes a new instance of <see cref="MissingValueIndicatorTransformer"/> from <see cref="Options"/>.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="options">The options whose <c>Columns</c> are converted into column pairs.
/// Note that <paramref name="options"/> is dereferenced without a prior null check.</param>
internal MissingValueIndicatorTransformer(IHostEnvironment env, Options options)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), GetColumnPairs(options.Columns))
{
}
/// <summary>
/// Initializes a new instance of <see cref="MissingValueIndicatorTransformer"/> from a saved model.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="ctx">The load context handed to the base constructor.</param>
private MissingValueIndicatorTransformer(IHostEnvironment env, ModelLoadContext ctx)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), ctx)
{
// NOTE(review): ctx has already been passed to the base constructor by the time this check runs.
Host.CheckValue(ctx, nameof(ctx));
}
/// <summary>
/// Converts column definitions into (output, input) name pairs.
/// A definition without an explicit source uses its own name as the input column.
/// </summary>
private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(Column[] columns)
{
    return columns
        .Select(column => (outputColumnName: column.Name, inputColumnName: column.Source ?? column.Name))
        .ToArray();
}
// Factory method for SignatureLoadModel
/// <summary>
/// Loads a <see cref="MissingValueIndicatorTransformer"/> from a saved model.
/// </summary>
/// <param name="env">The environment to use; must not be null.</param>
/// <param name="ctx">The load context positioned at this transform's model; must not be null.</param>
/// <returns>The transformer reconstructed from <paramref name="ctx"/>.</returns>
internal static MissingValueIndicatorTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    // Validate ctx before it is dereferenced, so a null context fails the argument check
    // instead of surfacing as a NullReferenceException on the next line.
    Contracts.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());
    return new MissingValueIndicatorTransformer(env, ctx);
}
// Factory method for SignatureDataTransform.
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
    var transformer = new MissingValueIndicatorTransformer(env, options);
    return transformer.MakeDataTransform(input);
}

// Factory method for SignatureLoadDataTransform.
internal static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
    MissingValueIndicatorTransformer transformer = Create(env, ctx);
    return transformer.MakeDataTransform(input);
}

// Factory method for SignatureLoadRowMapper.
internal static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
{
    MissingValueIndicatorTransformer transformer = Create(env, ctx);
    return transformer.MakeRowMapper(inputSchema);
}
/// <summary>
/// Saves the transform: validates the context, records the version information
/// returned by <c>GetVersionInfo</c>, and delegates the rest of the serialization to <c>SaveColumns</c>.
/// </summary>
/// <param name="ctx">The model save context; checked for null.</param>
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
// The version must match what Create(env, ctx) verifies via CheckAtModel on load.
ctx.SetVersionInfo(GetVersionInfo());
SaveColumns(ctx);
}
/// <summary>
/// Creates the row mapper that computes the indicator columns over <paramref name="schema"/>.
/// </summary>
private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
{
    return new Mapper(this, schema);
}
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
// Handle to the generic GetIsNADelegate<T>. Utils.MarshalInvoke is given the actual item type at the
// call site, so the <int> type argument here is presumably only a placeholder.
private static readonly FuncStaticMethodInfo1<DataViewType, Delegate> _getIsNADelegateMethodInfo
= new FuncStaticMethodInfo1<DataViewType, Delegate>(GetIsNADelegate<int>);
// Handle to the generic ComposeGetterOne<T>, used for scalar input columns.
private static readonly FuncInstanceMethodInfo1<Mapper, DataViewRow, int, ValueGetter<bool>> _composeGetterOneMethodInfo
= FuncInstanceMethodInfo1<Mapper, DataViewRow, int, ValueGetter<bool>>.Create(target => target.ComposeGetterOne<int>);
// Handle to the generic ComposeGetterVec<T>, used for vector input columns.
private static readonly FuncInstanceMethodInfo1<Mapper, DataViewRow, int, ValueGetter<VBuffer<bool>>> _composeGetterVecMethodInfo
= FuncInstanceMethodInfo1<Mapper, DataViewRow, int, ValueGetter<VBuffer<bool>>>.Create(target => target.ComposeGetterVec<int>);
// The transformer that created this mapper; supplies the column pairs.
private readonly MissingValueIndicatorTransformer _parent;
// One entry per column pair, built once in the constructor by CreateInfos.
private readonly ColInfo[] _infos;
/// <summary>
/// Immutable description of one mapped column: its names, its input and output types,
/// and the isNA predicate matching the input item type.
/// </summary>
private sealed class ColInfo
{
    /// <summary>Name of the output column.</summary>
    public readonly string Name;

    /// <summary>Name of the input column.</summary>
    public readonly string InputColumnName;

    /// <summary>Type of the output column.</summary>
    public readonly DataViewType OutputType;

    /// <summary>Type of the input column.</summary>
    public readonly DataViewType InputType;

    /// <summary>The isNA predicate for the input type, stored untyped; callers cast it to <c>InPredicate&lt;T&gt;</c>.</summary>
    public readonly Delegate InputIsNA;

    public ColInfo(string name, string inputColumnName, DataViewType inType, DataViewType outType)
    {
        Name = name;
        InputColumnName = inputColumnName;
        InputType = inType;
        OutputType = outType;
        // Resolved once here so that getters do not repeat the lookup.
        InputIsNA = GetIsNADelegate(inType);
    }
}
/// <summary>
/// Creates a mapper for <paramref name="parent"/> bound to <paramref name="inputSchema"/>.
/// </summary>
/// <param name="parent">The owning transformer; its <c>Host</c> is used to register this mapper.</param>
/// <param name="inputSchema">The schema of the data the mapper will read from.</param>
public Mapper(MissingValueIndicatorTransformer parent, DataViewSchema inputSchema)
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
{
// _parent must be assigned before CreateInfos runs, because CreateInfos reads _parent.ColumnPairs.
_parent = parent;
_infos = CreateInfos(inputSchema);
}
/// <summary>
/// Resolves every configured column pair against <paramref name="inputSchema"/> and records,
/// for each one, the input type together with the matching Boolean output type.
/// </summary>
/// <param name="inputSchema">The schema to look the input columns up in; asserted to be non-null.</param>
/// <returns>One <see cref="ColInfo"/> per column pair, in the same order as the pairs.</returns>
private ColInfo[] CreateInfos(DataViewSchema inputSchema)
{
    Host.AssertValue(inputSchema);

    var pairs = _parent.ColumnPairs;
    var result = new ColInfo[pairs.Length];
    for (int iinfo = 0; iinfo < pairs.Length; iinfo++)
    {
        var (outputName, inputName) = pairs[iinfo];

        if (!inputSchema.TryGetColumnIndex(inputName, out int srcIndex))
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
        _parent.CheckInputColumn(inputSchema, iinfo, srcIndex);

        DataViewType srcType = inputSchema[srcIndex].Type;
        // Scalars map to a Boolean; vectors map to a Boolean vector with the same dimensions.
        DataViewType dstType = srcType is VectorDataViewType srcVectorType
            ? new VectorDataViewType(BooleanDataViewType.Instance, srcVectorType.Dimensions)
            : (DataViewType)BooleanDataViewType.Instance;

        result[iinfo] = new ColInfo(outputName, inputName, srcType, dstType);
    }

    return result;
}
/// <summary>
/// Describes the output columns: each one carries the input column's slot names (if any)
/// and an IsNormalized annotation whose value is always <see langword="true"/>.
/// </summary>
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
    var columns = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
    for (int i = 0; i < _infos.Length; i++)
    {
        ColInfo info = _infos[i];

        InputSchema.TryGetColumnIndex(info.InputColumnName, out int srcIndex);
        Host.Assert(srcIndex >= 0);

        var annotations = new DataViewSchema.Annotations.Builder();
        // Of the input column's annotations, only the slot names are carried over.
        annotations.Add(InputSchema[srcIndex].Annotations, kind => kind == AnnotationUtils.Kinds.SlotNames);

        ValueGetter<bool> isNormalized = (ref bool dst) =>
        {
            dst = true;
        };
        annotations.Add(AnnotationUtils.Kinds.IsNormalized, BooleanDataViewType.Instance, isNormalized);

        columns[i] = new DataViewSchema.DetachedColumn(info.Name, info.OutputType, annotations.ToAnnotations());
    }

    return columns;
}
/// <summary>
/// Returns the isNA predicate for the item type of <paramref name="type"/>, as an untyped delegate.
/// Dispatches to the generic overload using the raw type of the item type.
/// </summary>
private static Delegate GetIsNADelegate(DataViewType type)
    => Utils.MarshalInvoke(_getIsNADelegateMethodInfo, type.GetItemType().RawType, type);

/// <summary>
/// Looks up the isNA predicate over <typeparamref name="T"/> for the item type of <paramref name="type"/>
/// in the default conversions instance.
/// </summary>
private static Delegate GetIsNADelegate<T>(DataViewType type)
{
    var itemType = type.GetItemType();
    return Data.Conversion.Conversions.DefaultInstance.GetIsNAPredicate<T>(itemType);
}
/// <summary>
/// Creates the getter for output column <paramref name="iinfo"/>: a vector getter when the input
/// column is a vector, a scalar getter otherwise. No disposer is needed.
/// </summary>
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
{
    Host.AssertValue(input);
    Host.Assert(0 <= iinfo && iinfo < _infos.Length);
    disposer = null;

    if (_infos[iinfo].InputType is VectorDataViewType)
        return ComposeGetterVec(input, iinfo);

    return ComposeGetterOne(input, iinfo);
}
/// <summary>
/// Getter generator for single valued inputs. Dispatches to the generic overload
/// using the raw type of the input column.
/// </summary>
private ValueGetter<bool> ComposeGetterOne(DataViewRow input, int iinfo)
{
    Type rawType = _infos[iinfo].InputType.RawType;
    return Utils.MarshalInvoke(_composeGetterOneMethodInfo, this, rawType, input, iinfo);
}

/// <summary>
/// Builds a getter that reads the source value and reports whether it is missing.
/// </summary>
private ValueGetter<bool> ComposeGetterOne<T>(DataViewRow input, int iinfo)
{
    var readSource = input.GetGetter<T>(input.Schema[ColMapNewToOld[iinfo]]);
    // Reused by every call of the returned getter.
    T value = default(T);
    var isMissing = (InPredicate<T>)_infos[iinfo].InputIsNA;

    return (ref bool dst) =>
    {
        readSource(ref value);
        dst = isMissing(in value);
    };
}
/// <summary>
/// Getter generator for vector valued inputs. Dispatches to the generic overload
/// using the raw type of the input column's item type.
/// </summary>
private ValueGetter<VBuffer<bool>> ComposeGetterVec(DataViewRow input, int iinfo)
{
    Type itemRawType = _infos[iinfo].InputType.GetItemType().RawType;
    return Utils.MarshalInvoke(_composeGetterVecMethodInfo, this, itemRawType, input, iinfo);
}

/// <summary>
/// Builds a getter that reads the source vector and produces a Boolean vector of the same length
/// marking the missing slots.
/// </summary>
private ValueGetter<VBuffer<bool>> ComposeGetterVec<T>(DataViewRow input, int iinfo)
{
    var readSource = input.GetGetter<VBuffer<T>>(input.Schema[ColMapNewToOld[iinfo]]);
    var isMissing = (InPredicate<T>)_infos[iinfo].InputIsNA;

    // Whether default(T) counts as missing decides how FindNAs scans sparse vectors.
    T defaultValue = default(T);
    bool defaultIsMissing = isMissing(in defaultValue);

    // Both buffers are reused by every call of the returned getter.
    var buffer = default(VBuffer<T>);
    var scratch = new List<int>();

    return (ref VBuffer<bool> dst) =>
    {
        readSource(ref buffer);
        // Sense indicates if the values added to the scratch list represent NAs or non-NAs.
        FindNAs(in buffer, isMissing, defaultIsMissing, scratch, out bool sense);
        FillValues(buffer.Length, ref dst, scratch, sense);
    };
}
/// <summary>
/// Collects into <paramref name="indices"/> either the positions of all NAs or the positions of all
/// non-NAs of <paramref name="src"/>; <paramref name="sense"/> tells which of the two was collected.
/// </summary>
/// <param name="src">The vector to scan.</param>
/// <param name="isNA">Predicate deciding whether a value is missing.</param>
/// <param name="defaultIsNA">Whether the default value of <typeparamref name="T"/> is itself missing.</param>
/// <param name="indices">Cleared on entry, then filled with the collected positions.</param>
/// <param name="sense"><see langword="true"/> when <paramref name="indices"/> holds NAs,
/// <see langword="false"/> when it holds non-NAs.</param>
private void FindNAs<T>(in VBuffer<T> src, InPredicate<T> isNA, bool defaultIsNA, List<int> indices, out bool sense)
{
    Host.AssertValue(isNA);
    Host.AssertValue(indices);

    indices.Clear();
    var values = src.GetValues();
    int count = values.Length;

    if (src.IsDense)
    {
        // Dense input: every slot is explicit, so collect the NAs directly.
        sense = true;
        for (int slot = 0; slot < count; slot++)
        {
            if (isNA(in values[slot]))
                indices.Add(slot);
        }
        return;
    }

    // Sparse input: only the explicit values are scanned. When the default value is itself NA the
    // non-NAs are collected instead, which is signalled by sense being false.
    sense = !defaultIsNA;
    var slots = src.GetIndices();
    for (int ii = 0; ii < count; ii++)
    {
        if (isNA(in values[ii]) == sense)
            indices.Add(slots[ii]);
    }
}
/// <summary>
/// Fills indicator values for vectors. The indices is a list that either holds all of the NAs or all
/// of the non-NAs, indicated by sense being true or false respectively.
/// </summary>
/// <param name="srcLength">Logical length of the source vector; the output gets the same length.</param>
/// <param name="dst">The output buffer, rewritten so that <c>true</c> marks a missing slot.</param>
/// <param name="indices">Positions collected from the source. The second sparse branch and the dense
/// branch append <paramref name="srcLength"/> to this list as a sentinel, so the list is modified.</param>
/// <param name="sense"><see langword="true"/> when <paramref name="indices"/> holds the NAs,
/// <see langword="false"/> when it holds the non-NAs.</param>
private void FillValues(int srcLength, ref VBuffer<bool> dst, List<int> indices, bool sense)
{
// Nothing was collected: the output is uniformly false (sense) or uniformly true (!sense).
if (indices.Count == 0)
{
if (sense)
{
// Return empty VBuffer.
VBufferUtils.Resize(ref dst, srcLength, 0);
return;
}
// Return VBuffer filled with 1's.
var editor = VBufferEditor.Create(ref dst, srcLength);
for (int i = 0; i < srcLength; i++)
editor.Values[i] = true;
dst = editor.Commit();
return;
}
// indices holds the NAs and they are fewer than half of the slots (integer division).
if (sense && indices.Count < srcLength / 2)
{
// Will produce sparse output.
int dstCount = indices.Count;
var editor = VBufferEditor.Create(ref dst, srcLength, dstCount);
indices.CopyTo(editor.Indices);
for (int ii = 0; ii < dstCount; ii++)
editor.Values[ii] = true;
Host.Assert(dstCount <= srcLength);
dst = editor.Commit();
}
// indices holds the non-NAs; the NAs are the remaining slots and are fewer than half of all slots.
else if (!sense && srcLength - indices.Count < srcLength / 2)
{
// Will produce sparse output.
int dstCount = srcLength - indices.Count;
var editor = VBufferEditor.Create(ref dst, srcLength, dstCount);
// Appends the length of the src to make the loop simpler,
// as the length of src will never be reached in the loop.
indices.Add(srcLength);
// iiDst counts the NA slots written so far, iiSrc the non-NA entries consumed so far;
// iNext is the next non-NA position (or the sentinel).
int iiDst = 0;
int iiSrc = 0;
int iNext = indices[iiSrc];
for (int i = 0; i < srcLength; i++)
{
Host.Assert(0 <= i && i <= iNext);
Host.Assert(iiSrc + iiDst == i);
if (i < iNext)
{
// Slot i is not in the non-NA list, so it is an NA: emit it.
Host.Assert(iiDst < dstCount);
editor.Values[iiDst] = true;
editor.Indices[iiDst++] = i;
}
else
{
// Slot i is a non-NA: skip it and advance to the next listed position.
Host.Assert(iiSrc + 1 < indices.Count);
Host.Assert(iNext < indices[iiSrc + 1]);
iNext = indices[++iiSrc];
}
}
Host.Assert(srcLength == iiSrc + iiDst);
Host.Assert(iiDst == dstCount);
dst = editor.Commit();
}
else
{
// Will produce dense output.
var editor = VBufferEditor.Create(ref dst, srcLength);
// Appends the length of the src to make the loop simpler,
// as the length of src will never be reached in the loop.
indices.Add(srcLength);
int ii = 0;
for (int i = 0; i < srcLength; i++)
{
Host.Assert(0 <= i && i <= indices[ii]);
if (i == indices[ii])
{
// Listed slot: it is an NA when sense is true, a non-NA when sense is false.
editor.Values[i] = sense;
ii++;
Host.Assert(ii < indices.Count);
Host.Assert(indices[ii - 1] < indices[ii]);
}
else
editor.Values[i] = !sense;
}
dst = editor.Commit();
}
}
/// <summary>
/// Always returns <see langword="true"/>; columns that cannot be exported are dropped by
/// <see cref="SaveAsOnnx"/> instead.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx) => true;

/// <summary>
/// Exports every mapped column to ONNX. A column whose input is not present in the context, or whose
/// export is declined by <c>SaveAsOnnxCore</c>, is removed from the context.
/// </summary>
/// <param name="ctx">The ONNX context to write to; checked for null.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));

    for (int iinfo = 0; iinfo < _infos.Length; ++iinfo)
    {
        ColInfo info = _infos[iinfo];
        if (!ctx.ContainsColumn(info.InputColumnName))
        {
            ctx.RemoveColumn(info.Name, false);
            continue;
        }

        // The source name is resolved before the output variable is registered.
        string srcVariable = ctx.GetVariableName(info.InputColumnName);
        string dstVariable = ctx.AddIntermediateVariable(info.OutputType, info.Name);

        bool exported = SaveAsOnnxCore(ctx, iinfo, info, srcVariable, dstVariable);
        if (!exported)
            ctx.RemoveColumn(info.Name, true);
    }
}
/// <summary>
/// Emits an ONNX <c>IsNaN</c> node from <paramref name="srcVariableName"/> to
/// <paramref name="dstVariableName"/> for column <paramref name="iinfo"/>.
/// </summary>
/// <returns><see langword="false"/> when the input item type is not <see cref="float"/> and nothing
/// was emitted; <see langword="true"/> otherwise.</returns>
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
    const int minimumOpSetVersion = 9;
    const string opType = "IsNaN";
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoadName);

    DataViewType inputType = _infos[iinfo].InputType;
    Type rawType = inputType is VectorDataViewType vectorType
        ? vectorType.ItemType.RawType
        : inputType.RawType;

    // Only inputs whose item type is float are exported.
    if (rawType != typeof(float))
        return false;

    // NOTE(review): this variable is registered but not referenced by the node created below;
    // confirm whether the registration is still needed.
    ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsNaNOutput", true);
    ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
    return true;
}
}
}
/// <summary>
/// <see cref="IEstimator{TTransformer}"/> for the <see cref="MissingValueIndicatorTransformer"/>.
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// ### Estimator Characteristics
/// | | |
/// | -- | -- |
/// | Does this estimator need to look at the data to train its parameters? | No |
/// | Input column data type | Vector or scalar value of <xref:System.Single> or <xref:System.Double> |
/// | Output column data type | If input column was scalar then <xref:System.Boolean> otherwise vector of <xref:System.Boolean>. |
/// | Exportable to ONNX | Yes |
///
/// The resulting <xref:Microsoft.ML.Transforms.MissingValueIndicatorTransformer> creates a new column, named as specified in the output column name parameters.
/// For a scalar input the new column holds a single bool that is `true` when the input value is missing. For a vector input it holds a vector of bools
/// where `true` in the i-th position indicates that the i-th element of the input column is missing, and `false` otherwise.
///
/// Check the See Also section for links to usage examples.
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="ExtensionsCatalog.IndicateMissingValues(TransformsCatalog, string, string)" />
/// <seealso cref="ExtensionsCatalog.IndicateMissingValues(TransformsCatalog, InputOutputColumnPair[])" />
public sealed class MissingValueIndicatorEstimator : TrivialEstimator<MissingValueIndicatorTransformer>
{
    /// <summary>
    /// Initializes a new instance of <see cref="MissingValueIndicatorEstimator"/>
    /// </summary>
    /// <param name="env">The environment to use.</param>
    /// <param name="columns">The names of the input columns of the transformation and the corresponding names for the output columns.</param>
    internal MissingValueIndicatorEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
        : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), new MissingValueIndicatorTransformer(env, columns))
    {
        // env has already been validated by Contracts.CheckRef in the base-constructor call above.
    }

    /// <summary>
    /// Initializes a new instance of <see cref="MissingValueIndicatorEstimator"/>
    /// </summary>
    /// <param name="env">The environment to use.</param>
    /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
    /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
    internal MissingValueIndicatorEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null)
        : this(env, (outputColumnName, inputColumnName ?? outputColumnName))
    {
    }

    /// <summary>
    /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
    /// Used for schema propagation and verification in a pipeline.
    /// </summary>
    /// <param name="inputSchema">The shape of the incoming schema; must not be null.</param>
    /// <returns>The input shape with every output column added (or replaced) as a Boolean column that
    /// keeps the vector kind of its input column.</returns>
    public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        Host.CheckValue(inputSchema, nameof(inputSchema));

        var result = inputSchema.ToDictionary(x => x.Name);
        foreach (var colPair in Transformer.Columns)
        {
            // The input column must exist and its item type must have an isNA predicate.
            if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col)
                || !Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(col.ItemType, out Delegate _))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName);
            }

            // Slot names (when present) are propagated; the output is always marked as normalized.
            var metadata = new List<SchemaShape.Column>();
            if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
                metadata.Add(slotMeta);
            metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));

            // A SchemaShape column carries its vector-ness in Kind while ItemType is the element type,
            // so the output item type is always Boolean and the input's Kind is passed through unchanged.
            result[colPair.outputColumnName] = new SchemaShape.Column(
                colPair.outputColumnName,
                col.Kind,
                BooleanDataViewType.Instance,
                false,
                new SchemaShape(metadata.ToArray()));
        }

        return new SchemaShape(result.Values);
    }
}
}
|