|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
using Newtonsoft.Json.Linq;
[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), typeof(ColumnConcatenatingTransformer.TaggedOptions), typeof(SignatureDataTransform),
ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoadName, "ConcatTransform", DocName = "transform/ConcatTransform.md")]
[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), null, typeof(SignatureLoadDataTransform),
ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoaderSignature, ColumnConcatenatingTransformer.LoaderSignatureOld)]
[assembly: LoadableClass(typeof(ColumnConcatenatingTransformer), null, typeof(SignatureLoadModel),
ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoaderSignature)]
[assembly: LoadableClass(typeof(IRowMapper), typeof(ColumnConcatenatingTransformer), null, typeof(SignatureLoadRowMapper),
ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoaderSignature)]
namespace Microsoft.ML.Data
{
using PfaType = PfaUtils.Type;
/// <summary>
/// <see cref="ITransformer"/> resulting from fitting an <see cref="ColumnConcatenatingEstimator"/>.
/// </summary>
public sealed class ColumnConcatenatingTransformer : RowToRowTransformerBase
{
internal const string Summary = "Concatenates one or more columns of the same item type.";
internal const string UserName = "Concat Transform";
internal const string LoadName = "Concat";
internal const string LoaderSignature = "ConcatTransform";
internal const string LoaderSignatureOld = "ConcatFunction";
internal sealed class Column : ManyToOneColumn
{
internal static Column Parse(string str)
{
Contracts.AssertNonEmpty(str);
var res = new Column();
if (res.TryParse(str))
return res;
return null;
}
internal bool TryUnparse(StringBuilder sb)
{
Contracts.AssertValue(sb);
return TryUnparseCore(sb);
}
}
[BestFriend]
internal sealed class TaggedColumn
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the new column", ShortName = "name")]
public string Name;
// The tag here (the key of the KeyValuePair) is the string that will be the prefix of the slot name
// in the output column. For non-vector columns, the slot name will be either the column name or the
// tag if it is non empty. For vector columns, the slot names will be 'ColumnName.SlotName' if the
// tag is empty, 'Tag.SlotName' if tag is non empty, and simply the slot name if tag is non empty
// and equal to the column name.
[Argument(ArgumentType.Multiple, HelpText = "Names of the source columns", ShortName = "src")]
public KeyValuePair<string, string>[] Source;
internal static TaggedColumn Parse(string str)
{
Contracts.AssertNonEmpty(str);
// REVIEW: Support a short form for aliases.
var res = Column.Parse(str);
if (res == null)
return null;
Contracts.AssertValue(res.Source);
var taggedColumn = new TaggedColumn();
taggedColumn.Name = res.Name;
taggedColumn.Source = res.Source.Select(s => new KeyValuePair<string, string>(null, s)).ToArray();
return taggedColumn;
}
internal bool TryUnparse(StringBuilder sb)
{
Contracts.AssertValue(sb);
if (Source == null || Source.Any(kvp => !string.IsNullOrEmpty(kvp.Key)))
return false;
var column = new Column();
column.Name = Name;
column.Source = Source.Select(kvp => kvp.Value).ToArray();
return column.TryUnparse(sb);
}
}
internal sealed class Options : TransformInputBase
{
public Options()
{
}
public Options(string name, params string[] source)
{
Columns = new[] { new Column()
{
Name = name,
Source = source
}};
}
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)",
Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
}
[BestFriend]
internal sealed class TaggedOptions
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)",
Name = "Column", ShortName = "col", SortOrder = 1)]
public TaggedColumn[] Columns;
}
[BestFriend]
internal sealed class ColumnOptions
{
public readonly string Name;
private readonly (string name, string alias)[] _sources;
public IReadOnlyList<(string name, string alias)> Sources => _sources.AsReadOnly();
/// <summary>
/// This denotes a concatenation of all <paramref name="inputColumnNames"/> into column called <paramref name="name"/>.
/// </summary>
public ColumnOptions(string name, params string[] inputColumnNames)
: this(name, GetPairs(inputColumnNames))
{
}
private static IEnumerable<(string name, string alias)> GetPairs(string[] inputColumnNames)
{
Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames));
return inputColumnNames.Select(name => (name, (string)null));
}
/// <summary>
/// This denotes a concatenation of input columns into one column called <paramref name="name"/>.
/// For each input column, an 'alias' can be specified, to be used in constructing the resulting slot names.
/// If the alias is not specified, it defaults to be column name.
/// </summary>
public ColumnOptions(string name, IEnumerable<(string name, string alias)> inputColumnNames)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames));
Contracts.CheckParam(inputColumnNames.Any(), nameof(inputColumnNames), "Can not be empty");
foreach (var (output, alias) in inputColumnNames)
{
Contracts.CheckNonEmpty(output, nameof(inputColumnNames));
Contracts.CheckValueOrNull(alias);
}
Name = name;
_sources = inputColumnNames.ToArray();
}
public void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
// *** Binary format ***
// int: id of output
// int: number of inputs
// for each input
// int: id of name
// int: id of alias
ctx.SaveNonEmptyString(Name);
Contracts.Assert(_sources.Length > 0);
ctx.Writer.Write(_sources.Length);
foreach (var (name, alias) in _sources)
{
ctx.SaveNonEmptyString(name);
ctx.SaveStringOrNull(alias);
}
}
internal ColumnOptions(ModelLoadContext ctx)
{
Contracts.AssertValue(ctx);
// *** Binary format ***
// int: id of output
// int: number of inputs
// for each input
// int: id of name
// int: id of alias
Name = ctx.LoadNonEmptyString();
int n = ctx.Reader.ReadInt32();
Contracts.CheckDecode(n > 0);
_sources = new (string name, string alias)[n];
for (int i = 0; i < n; i++)
{
var name = ctx.LoadNonEmptyString();
var alias = ctx.LoadStringOrNull();
_sources[i] = (name, alias);
}
}
}
private readonly ColumnOptions[] _columns;
/// <summary>
/// The names of the output and input column pairs for the transformation.
/// </summary>
internal IReadOnlyCollection<(string outputColumnName, string[] inputColumnNames)> Columns
=> _columns.Select(col => (outputColumnName: col.Name, inputColumnNames: col.Sources.Select(source => source.name).ToArray())).ToArray().AsReadOnly();
/// <summary>
/// Concatename columns in <paramref name="inputColumnNames"/> into one column <paramref name="outputColumnName"/>.
/// Original columns are also preserved.
/// The column types must match, and the output column type is always a vector.
/// </summary>
internal ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
: this(env, new ColumnOptions(outputColumnName, inputColumnNames))
{
}
/// <summary>
/// Concatenates multiple groups of columns, each group is denoted by one of <paramref name="columns"/>.
/// </summary>
internal ColumnConcatenatingTransformer(IHostEnvironment env, params ColumnOptions[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnConcatenatingTransformer)))
{
Contracts.CheckValue(columns, nameof(columns));
_columns = columns.ToArray();
}
private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "CONCAT F",
//verWrittenCur: 0x00010001, // Initial
//verWrittenCur: 0x00010002, // Added aliases
verWrittenCur: 0x00010003, // Converted to transformer
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderSignatureAlt: LoaderSignatureOld,
loaderAssemblyName: typeof(ColumnConcatenatingTransformer).Assembly.FullName);
}
private const int VersionAddedAliases = 0x00010002;
private const int VersionTransformer = 0x00010003;
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
// int: number of columns
// for each column:
// columnOptions
Contracts.Assert(_columns.Length > 0);
ctx.Writer.Write(_columns.Length);
foreach (var col in _columns)
col.Save(ctx);
}
/// <summary>
/// Factory method for SignatureLoadModel.
/// </summary>
private ColumnConcatenatingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnConcatenatingTransformer)))
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
if (ctx.Header.ModelVerWritten >= VersionTransformer)
{
// *** Binary format ***
// int: number of columns
// for each column:
// columnOptions
int n = ctx.Reader.ReadInt32();
Contracts.CheckDecode(n > 0);
_columns = new ColumnOptions[n];
for (int i = 0; i < n; i++)
_columns[i] = new ColumnOptions(ctx);
}
else
_columns = LoadLegacy(ctx);
}
private ColumnOptions[] LoadLegacy(ModelLoadContext ctx)
{
// *** Legacy binary format ***
// int: sizeof(Float).
// int: number of added columns
// for each added column
// int: id of output column name
// int: number of input column names
// int[]: ids of input column names
// if version >= VersionAddedAliases
// foreach column:
// foreach non-null alias
// int: index of the alias
// int: string id of the alias
// int: -1, marks the end of the list
var sizeofFloat = ctx.Reader.ReadInt32();
Contracts.CheckDecode(sizeofFloat == sizeof(float));
int n = ctx.Reader.ReadInt32();
Contracts.CheckDecode(n > 0);
var names = new string[n];
var inputs = new string[n][];
for (int i = 0; i < n; i++)
{
names[i] = ctx.LoadNonEmptyString();
int numSources = ctx.Reader.ReadInt32();
Contracts.CheckDecode(numSources > 0);
inputs[i] = new string[numSources];
for (int j = 0; j < numSources; j++)
inputs[i][j] = ctx.LoadNonEmptyString();
}
var aliases = new string[n][];
if (ctx.Header.ModelVerReadable >= VersionAddedAliases)
{
for (int i = 0; i < n; i++)
{
var length = inputs[i].Length;
aliases[i] = new string[length];
if (ctx.Header.ModelVerReadable >= VersionAddedAliases)
{
for (; ; )
{
var j = ctx.Reader.ReadInt32();
if (j == -1)
break;
Contracts.CheckDecode(0 <= j && j < length);
Contracts.CheckDecode(aliases[i][j] == null);
aliases[i][j] = ctx.LoadNonEmptyString();
}
}
}
}
var result = new ColumnOptions[n];
for (int i = 0; i < n; i++)
result[i] = new ColumnOptions(names[i],
inputs[i].Zip(aliases[i], (name, alias) => (name, alias)));
return result;
}
///<summary>
/// Factory method for SignatureDataTransform.
/// </summary>
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));
for (int i = 0; i < options.Columns.Length; i++)
env.CheckUserArg(Utils.Size(options.Columns[i].Source) > 0, nameof(options.Columns));
var cols = options.Columns
.Select(c => new ColumnOptions(c.Name, c.Source))
.ToArray();
var transformer = new ColumnConcatenatingTransformer(env, cols);
return transformer.MakeDataTransform(input);
}
/// <summary>
/// Factory method corresponding to SignatureDataTransform.
/// </summary>
[BestFriend]
internal static IDataTransform Create(IHostEnvironment env, TaggedOptions options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));
for (int i = 0; i < options.Columns.Length; i++)
env.CheckUserArg(Utils.Size(options.Columns[i].Source) > 0, nameof(options.Columns));
var cols = options.Columns
.Select(c => new ColumnOptions(c.Name, c.Source.Select(kvp => (kvp.Value, kvp.Key != "" ? kvp.Key : null))))
.ToArray();
var transformer = new ColumnConcatenatingTransformer(env, cols);
return transformer.MakeDataTransform(input);
}
private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema);
/// <summary>
/// Factory method for SignatureLoadDataTransform.
/// </summary>
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
=> new ColumnConcatenatingTransformer(env, ctx).MakeDataTransform(input);
/// <summary>
/// Factory method for SignatureLoadRowMapper.
/// </summary>
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
=> new ColumnConcatenatingTransformer(env, ctx).MakeRowMapper(inputSchema);
private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa
{
private readonly ColumnConcatenatingTransformer _parent;
private readonly BoundColumn[] _columns;
public bool CanSaveOnnx(OnnxContext ctx) => true;
public bool CanSavePfa => true;
public Mapper(ColumnConcatenatingTransformer parent, DataViewSchema inputSchema) :
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
{
_parent = parent;
_columns = new BoundColumn[_parent._columns.Length];
for (int i = 0; i < _parent._columns.Length; i++)
{
_columns[i] = MakeColumn(inputSchema, i);
}
}
private BoundColumn MakeColumn(DataViewSchema inputSchema, int iinfo)
{
Contracts.AssertValue(inputSchema);
Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length);
DataViewType itemType = null;
int[] sources = new int[_parent._columns[iinfo].Sources.Count];
// Go through the columns, and establish the following:
// - indices of input columns in the input schema. Throw if they are not there.
// - output type. Throw if the types of inputs are not the same.
// - how many slots are there in the output vector (or variable). Denoted by totalSize.
// - total size of CategoricalSlotRanges metadata, if present. Denoted by catCount.
// - whether the column is normalized.
// It is true when ALL inputs are normalized (and of numeric type).
// - whether the column has slot names.
// It is true if ANY input is a scalar, or has slot names.
// - whether the column has categorical slot ranges.
// It is true if ANY input has this metadata.
int totalSize = 0;
int catCount = 0;
bool isNormalized = true;
bool hasSlotNames = false;
bool hasCategoricals = false;
for (int i = 0; i < _parent._columns[iinfo].Sources.Count; i++)
{
var (srcName, srcAlias) = _parent._columns[iinfo].Sources[i];
if (!inputSchema.TryGetColumnIndex(srcName, out int srcCol))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName);
sources[i] = srcCol;
var curType = inputSchema[srcCol].Type;
VectorDataViewType curVectorType = curType as VectorDataViewType;
DataViewType currentItemType = curVectorType?.ItemType ?? curType;
int currentValueCount = curVectorType?.Size ?? 1;
if (itemType == null)
{
itemType = currentItemType;
totalSize = currentValueCount;
}
else if (currentItemType.Equals(itemType))
{
// If any one input is variable length, then the output is variable length.
if (totalSize == 0 || currentValueCount == 0)
totalSize = 0;
else
totalSize += currentValueCount;
}
else
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, itemType.ToString(), curType.ToString());
if (isNormalized && !inputSchema[srcCol].IsNormalized())
isNormalized = false;
if (AnnotationUtils.TryGetCategoricalFeatureIndices(inputSchema, srcCol, out int[] typeCat))
{
Contracts.Assert(typeCat.Length > 0);
catCount += typeCat.Length;
hasCategoricals = true;
}
if ((!hasSlotNames && curVectorType == null)
|| (curVectorType != null && inputSchema[srcCol].HasSlotNames(curVectorType.Size)))
hasSlotNames = true;
}
if (!(itemType is NumberDataViewType))
isNormalized = false;
if (totalSize == 0)
{
hasCategoricals = false;
hasSlotNames = false;
}
return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorDataViewType((PrimitiveDataViewType)itemType, totalSize),
isNormalized, hasSlotNames, hasCategoricals, totalSize, catCount);
}
/// <summary>
/// This represents the column information bound to the schema.
/// </summary>
private sealed class BoundColumn
{
private static readonly FuncInstanceMethodInfo1<BoundColumn, DataViewRow, Delegate> _makeIdentityGetterMethodInfo
= FuncInstanceMethodInfo1<BoundColumn, DataViewRow, Delegate>.Create(target => target.MakeIdentityGetter<int>);
private static readonly FuncInstanceMethodInfo1<BoundColumn, DataViewRow, Delegate> _makeGetterMethodInfo
= FuncInstanceMethodInfo1<BoundColumn, DataViewRow, Delegate>.Create(target => target.MakeGetter<int>);
public readonly int[] SrcIndices;
private readonly ColumnOptions _columnOptions;
private readonly DataViewType[] _srcTypes;
public readonly VectorDataViewType OutputType;
// Fields pertaining to column metadata.
private readonly bool _isIdentity;
private readonly bool _isNormalized;
private readonly bool _hasSlotNames;
private readonly bool _hasCategoricals;
private readonly VectorDataViewType _slotNamesType;
private readonly DataViewType _categoricalRangeType;
private readonly DataViewSchema _inputSchema;
public BoundColumn(DataViewSchema inputSchema, ColumnOptions columnOptions, int[] sources, VectorDataViewType outputType,
bool isNormalized, bool hasSlotNames, bool hasCategoricals, int slotCount, int catCount)
{
_columnOptions = columnOptions;
SrcIndices = sources;
_srcTypes = sources.Select(c => inputSchema[c].Type).ToArray();
OutputType = outputType;
_inputSchema = inputSchema;
_isIdentity = SrcIndices.Length == 1 && _inputSchema[SrcIndices[0]].Type is VectorDataViewType;
_isNormalized = isNormalized;
_hasSlotNames = hasSlotNames;
if (_hasSlotNames)
_slotNamesType = AnnotationUtils.GetNamesType(slotCount);
_hasCategoricals = hasCategoricals;
if (_hasCategoricals)
_categoricalRangeType = AnnotationUtils.GetCategoricalType(catCount / 2);
}
public DataViewSchema.DetachedColumn MakeSchemaColumn()
{
if (_isIdentity)
{
var inputCol = _inputSchema[SrcIndices[0]];
return new DataViewSchema.DetachedColumn(_columnOptions.Name, inputCol.Type, inputCol.Annotations);
}
var metadata = new DataViewSchema.Annotations.Builder();
if (_isNormalized)
metadata.Add(AnnotationUtils.Kinds.IsNormalized, BooleanDataViewType.Instance, (ValueGetter<bool>)GetIsNormalized);
if (_hasSlotNames)
metadata.AddSlotNames(_slotNamesType.Size, GetSlotNames);
if (_hasCategoricals)
metadata.Add(AnnotationUtils.Kinds.CategoricalSlotRanges, _categoricalRangeType, (ValueGetter<VBuffer<int>>)GetCategoricalSlotRanges);
return new DataViewSchema.DetachedColumn(_columnOptions.Name, OutputType, metadata.ToAnnotations());
}
private void GetIsNormalized(ref bool value) => value = _isNormalized;
private void GetCategoricalSlotRanges(ref VBuffer<int> dst)
{
List<int> allValues = new List<int>();
int slotCount = 0;
for (int i = 0; i < SrcIndices.Length; i++)
{
Contracts.Assert(_srcTypes[i].GetValueCount() > 0);
if (i > 0)
slotCount += _srcTypes[i - 1].GetValueCount();
if (AnnotationUtils.TryGetCategoricalFeatureIndices(_inputSchema, SrcIndices[i], out int[] values))
{
Contracts.Assert(values.Length > 0 && values.Length % 2 == 0);
for (int j = 0; j < values.Length; j++)
allValues.Add(values[j] + slotCount);
}
}
Contracts.Assert(allValues.Count > 0);
dst = new VBuffer<int>(allValues.Count, allValues.ToArray());
}
private void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
{
Contracts.Assert(!_isIdentity);
Contracts.Assert(OutputType.Size > 0);
Contracts.AssertValue(_slotNamesType);
Contracts.Assert(_slotNamesType.Size == OutputType.Size);
var bldr = BufferBuilder<ReadOnlyMemory<char>>.CreateDefault();
bldr.Reset(_slotNamesType.Size, dense: false);
var sb = new StringBuilder();
var names = default(VBuffer<ReadOnlyMemory<char>>);
int slot = 0;
for (int i = 0; i < _srcTypes.Length; i++)
{
int colSrc = SrcIndices[i];
var typeSrc = _srcTypes[i];
Contracts.Assert(_columnOptions.Sources[i].alias != "");
var colName = _inputSchema[colSrc].Name;
var nameSrc = _columnOptions.Sources[i].alias ?? colName;
if (!(typeSrc is VectorDataViewType vectorTypeSrc))
{
bldr.AddFeature(slot++, nameSrc.AsMemory());
continue;
}
Contracts.Assert(vectorTypeSrc.IsKnownSize);
VectorDataViewType typeNames = null;
var inputMetadata = _inputSchema[colSrc].Annotations;
if (inputMetadata != null && inputMetadata.Schema.TryGetColumnIndex(AnnotationUtils.Kinds.SlotNames, out int idx))
typeNames = inputMetadata.Schema[idx].Type as VectorDataViewType;
if (typeNames != null && typeNames.Size == vectorTypeSrc.Size && typeNames.ItemType is TextDataViewType)
{
inputMetadata.GetValue(AnnotationUtils.Kinds.SlotNames, ref names);
sb.Clear();
if (_columnOptions.Sources[i].alias != colName)
sb.Append(nameSrc).Append(".");
int len = sb.Length;
foreach (var kvp in names.Items())
{
if (kvp.Value.IsEmpty)
continue;
sb.Length = len;
sb.AppendMemory(kvp.Value);
bldr.AddFeature(slot + kvp.Key, sb.ToString().AsMemory());
}
}
slot += vectorTypeSrc.Size;
}
Contracts.Assert(slot == OutputType.Size);
bldr.GetResult(ref dst);
}
public Delegate MakeGetter(DataViewRow input)
{
if (_isIdentity)
return Utils.MarshalInvoke(_makeIdentityGetterMethodInfo, this, OutputType.RawType, input);
return Utils.MarshalInvoke(_makeGetterMethodInfo, this, OutputType.ItemType.RawType, input);
}
private Delegate MakeIdentityGetter<T>(DataViewRow input)
{
Contracts.Assert(SrcIndices.Length == 1);
return input.GetGetter<T>(input.Schema[SrcIndices[0]]);
}
private Delegate MakeGetter<T>(DataViewRow input)
{
var srcGetterOnes = new ValueGetter<T>[SrcIndices.Length];
var srcGetterVecs = new ValueGetter<VBuffer<T>>[SrcIndices.Length];
for (int j = 0; j < SrcIndices.Length; j++)
{
var column = input.Schema[SrcIndices[j]];
if (_srcTypes[j] is VectorDataViewType)
srcGetterVecs[j] = input.GetGetter<VBuffer<T>>(column);
else
srcGetterOnes[j] = input.GetGetter<T>(column);
}
T tmp = default(T);
VBuffer<T>[] tmpBufs = new VBuffer<T>[SrcIndices.Length];
ValueGetter<VBuffer<T>> result = (ref VBuffer<T> dst) =>
{
int dstLength = 0;
int dstCount = 0;
for (int i = 0; i < SrcIndices.Length; i++)
{
var type = _srcTypes[i];
if (type is VectorDataViewType vectorType)
{
srcGetterVecs[i](ref tmpBufs[i]);
if (vectorType.Size != 0 && vectorType.Size != tmpBufs[i].Length)
{
throw Contracts.Except("Column '{0}': expected {1} slots, but got {2}",
input.Schema[SrcIndices[i]].Name, vectorType.Size, tmpBufs[i].Length)
.MarkSensitive(MessageSensitivity.Schema);
}
dstLength = checked(dstLength + tmpBufs[i].Length);
dstCount = checked(dstCount + tmpBufs[i].GetValues().Length);
}
else
{
dstLength = checked(dstLength + 1);
dstCount = checked(dstCount + 1);
}
}
if (dstCount <= dstLength / 2)
{
// Concatenate into a sparse representation.
var editor = VBufferEditor.Create(ref dst, dstLength, dstCount);
int offset = 0;
int count = 0;
for (int j = 0; j < SrcIndices.Length; j++)
{
Contracts.Assert(offset < dstLength);
if (_srcTypes[j] is VectorDataViewType)
{
var buffer = tmpBufs[j];
var bufferValues = buffer.GetValues();
Contracts.Assert(bufferValues.Length <= dstCount - count);
Contracts.Assert(buffer.Length <= dstLength - offset);
if (buffer.IsDense)
{
for (int i = 0; i < bufferValues.Length; i++)
{
editor.Values[count] = bufferValues[i];
editor.Indices[count++] = offset + i;
}
}
else
{
var bufferIndices = buffer.GetIndices();
for (int i = 0; i < bufferValues.Length; i++)
{
editor.Values[count] = bufferValues[i];
editor.Indices[count++] = offset + bufferIndices[i];
}
}
offset += buffer.Length;
}
else
{
Contracts.Assert(count < dstCount);
srcGetterOnes[j](ref tmp);
editor.Values[count] = tmp;
editor.Indices[count++] = offset;
offset++;
}
}
Contracts.Assert(count <= dstCount);
Contracts.Assert(offset == dstLength);
dst = editor.CommitTruncated(count);
}
else
{
// Concatenate into a dense representation.
var editor = VBufferEditor.Create(ref dst, dstLength);
int offset = 0;
for (int j = 0; j < SrcIndices.Length; j++)
{
Contracts.Assert(tmpBufs[j].Length <= dstLength - offset);
if (_srcTypes[j] is VectorDataViewType)
{
tmpBufs[j].CopyTo(editor.Values, offset);
offset += tmpBufs[j].Length;
}
else
{
srcGetterOnes[j](ref tmp);
editor.Values[offset++] = tmp;
}
}
Contracts.Assert(offset == dstLength);
dst = editor.Commit();
}
};
return result;
}
public KeyValuePair<string, JToken> SavePfaInfo(BoundPfaContext ctx)
{
Contracts.AssertValue(ctx);
string outName = _columnOptions.Name;
if (!OutputType.IsKnownSize) // Do not attempt variable length.
return new KeyValuePair<string, JToken>(outName, null);
string[] srcTokens = new string[SrcIndices.Length];
bool[] srcPrimitive = new bool[SrcIndices.Length];
for (int i = 0; i < SrcIndices.Length; ++i)
{
var srcName = _columnOptions.Sources[i].name;
if ((srcTokens[i] = ctx.TokenOrNullForName(srcName)) == null)
return new KeyValuePair<string, JToken>(outName, null);
srcPrimitive[i] = _srcTypes[i] is PrimitiveDataViewType;
}
Contracts.Assert(srcTokens.All(tok => tok != null));
var itemColumnType = OutputType.ItemType;
var itemType = PfaType.PfaTypeOrNullForColumnType(itemColumnType);
if (itemType == null)
return new KeyValuePair<string, JToken>(outName, null);
JObject jobj = null;
var arrType = PfaType.Array(itemType);
// The "root" object will be the concatenation of all the initial scalar objects into an
// array, or else, if the first object is not scalar, just that first object.
JToken result;
int min;
if (srcPrimitive[0])
{
JArray rootObjects = new JArray();
for (int i = 0; i < srcTokens.Length && srcPrimitive[i]; ++i)
rootObjects.Add(srcTokens[i]);
result = jobj.AddReturn("type", arrType).AddReturn("new", new JArray(rootObjects));
min = rootObjects.Count;
}
else
{
result = srcTokens[0];
min = 1;
}
for (int i = min; i < srcTokens.Length; ++i)
result = PfaUtils.Call(srcPrimitive[i] ? "a.append" : "a.concat", result, srcTokens[i]);
Contracts.AssertValue(result);
return new KeyValuePair<string, JToken>(outName, result);
}
}
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
var active = new bool[InputSchema.Count];
for (int i = 0; i < _columns.Length; i++)
{
if (activeOutput(i))
{
foreach (var src in _columns[i].SrcIndices)
active[src] = true;
}
}
return col => active[col];
}
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() => _columns.Select(x => x.MakeSchemaColumn()).ToArray();
private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
{
disposer = null;
return _columns[iinfo].MakeGetter(input);
}
public void SaveAsPfa(BoundPfaContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
var toHide = new List<string>();
var toDeclare = new List<KeyValuePair<string, JToken>>();
for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)
{
var toSave = _columns[iinfo].SavePfaInfo(ctx);
if (toSave.Value == null)
toHide.Add(toSave.Key);
else
toDeclare.Add(toSave);
}
ctx.Hide(toHide.ToArray());
ctx.DeclareVar(toDeclare.ToArray());
}
public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Contracts.Assert(CanSaveOnnx(ctx));
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)
{
var colInfo = _parent._columns[iinfo];
var boundCol = _columns[iinfo];
string outName = colInfo.Name;
var outColType = boundCol.OutputType;
if ((!outColType.IsKnownSize) || (!(outColType.GetItemType() is NumberDataViewType)))
{
ctx.RemoveColumn(outName, false);
continue;
}
List<KeyValuePair<string, long>> inputList = new List<KeyValuePair<string, long>>();
for (int i = 0; i < boundCol.SrcIndices.Length; ++i)
{
var srcName = colInfo.Sources[i].name;
if (!ctx.ContainsColumn(srcName))
{
ctx.RemoveColumn(outName, false);
return;
}
var srcIndex = boundCol.SrcIndices[i];
inputList.Add(new KeyValuePair<string, long>(ctx.GetVariableName(srcName),
InputSchema[srcIndex].Type.GetValueCount()));
}
string opType = "FeatureVectorizer";
int outVectorSize = (int)inputList.Sum(x => x.Value);
var vectorizerOutputType = new VectorDataViewType(NumberDataViewType.Single, outVectorSize);
var vectorizerOutputName = ctx.AddIntermediateVariable(vectorizerOutputType, "VectorFeaturizerOutput");
var node = ctx.CreateNode(opType, inputList.Select(t => t.Key),
new[] { vectorizerOutputName }, ctx.GetNodeName(opType));
node.AddAttribute("inputdimensions", inputList.Select(x => x.Value));
opType = "Cast";
var dstVectorType = new VectorDataViewType(outColType.GetItemType() as PrimitiveDataViewType, outVectorSize);
var dstVariableName = ctx.AddIntermediateVariable(dstVectorType, outName);
var castNode = ctx.CreateNode(opType, vectorizerOutputName, dstVariableName, ctx.GetNodeName(opType), "");
castNode.AddAttribute("to", outColType.ItemType.RawType);
}
}
}
}
}
|