|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Reflection;
using Microsoft.ML.Data.Conversion;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
{
/// <summary>
/// This applies the user provided ValueMapper to a column to produce a new column. It automatically
/// injects a standard conversion from the actual type of the source column to typeSrc (if needed).
/// </summary>
[BestFriend]
internal static class LambdaColumnMapper
{
// REVIEW: It would be nice to support propagation of select metadata.
public static IDataView Create<TSrc, TDst>(IHostEnvironment env, string name, IDataView input,
string src, string dst, DataViewType typeSrc, DataViewType typeDst, ValueMapper<TSrc, TDst> mapper,
ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter = null, ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonEmpty(name, nameof(name));
env.CheckValue(input, nameof(input));
env.CheckNonEmpty(src, nameof(src));
env.CheckNonEmpty(dst, nameof(dst));
env.CheckValue(typeSrc, nameof(typeSrc));
env.CheckValue(typeDst, nameof(typeDst));
env.CheckValue(mapper, nameof(mapper));
env.Check(keyValueGetter == null || typeDst.GetItemType() is KeyDataViewType);
env.Check(slotNamesGetter == null || typeDst.IsKnownSizeVector());
if (typeSrc.RawType != typeof(TSrc))
{
throw env.ExceptParam(nameof(mapper),
"The source column type '{0}' doesn't match the input type of the mapper", typeSrc);
}
if (typeDst.RawType != typeof(TDst))
{
throw env.ExceptParam(nameof(mapper),
"The destination column type '{0}' doesn't match the output type of the mapper", typeDst);
}
bool tmp = input.Schema.TryGetColumnIndex(src, out int colSrc);
if (!tmp)
throw env.ExceptParam(nameof(src), "The input data doesn't have a column named '{0}'", src);
var typeOrig = input.Schema[colSrc].Type;
// REVIEW: Ideally this should support vector-type conversion. It currently doesn't.
bool ident;
Delegate conv;
if (typeOrig.SameSizeAndItemType(typeSrc))
{
ident = true;
conv = null;
}
else if (!Conversions.DefaultInstance.TryGetStandardConversion(typeOrig, typeSrc, out conv, out ident))
{
throw env.ExceptParam(nameof(mapper),
"The type of column '{0}', '{1}', cannot be converted to the input type of the mapper '{2}'",
src, typeOrig, typeSrc);
}
var col = new Column(src, dst);
IDataView impl;
if (ident)
impl = new Impl<TSrc, TDst, TDst>(env, name, input, col, typeDst, mapper, keyValueGetter: keyValueGetter, slotNamesGetter: slotNamesGetter);
else
{
Func<IHostEnvironment, string, IDataView, Column, DataViewType, ValueMapper<int, int>,
ValueMapper<int, int>, ValueGetter<VBuffer<ReadOnlyMemory<char>>>, ValueGetter<VBuffer<ReadOnlyMemory<char>>>,
Impl<int, int, int>> del = CreateImpl<int, int, int>;
var meth = del.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeOrig.RawType, typeof(TSrc), typeof(TDst));
impl = (IDataView)meth.Invoke(null, new object[] { env, name, input, col, typeDst, conv, mapper, keyValueGetter, slotNamesGetter });
}
return new OpaqueDataView(impl);
}
private static Impl<T1, T2, T3> CreateImpl<T1, T2, T3>(
IHostEnvironment env, string name, IDataView input, Column col,
DataViewType typeDst, ValueMapper<T1, T2> map1, ValueMapper<T2, T3> map2,
ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter, ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter)
{
return new Impl<T1, T2, T3>(env, name, input, col, typeDst, map1, map2, keyValueGetter);
}
private sealed class Column : OneToOneColumn
{
public Column(string src, string dst)
{
Name = dst;
Source = src;
}
}
private sealed class Impl<T1, T2, T3> : OneToOneTransformBase
{
private readonly DataViewType _typeDst;
private readonly ValueMapper<T1, T2> _map1;
private readonly ValueMapper<T2, T3> _map2;
public Impl(IHostEnvironment env, string name, IDataView input, OneToOneColumn col,
DataViewType typeDst, ValueMapper<T1, T2> map1, ValueMapper<T2, T3> map2 = null,
ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter = null, ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter = null)
: base(env, name, new[] { col }, input, x => null)
{
Host.Assert(typeDst.RawType == typeof(T3));
Host.AssertValue(map1);
Host.Assert(map2 != null || typeof(T2) == typeof(T3));
_typeDst = typeDst;
_map1 = map1;
_map2 = map2;
if (keyValueGetter != null || slotNamesGetter != null)
{
using (var bldr = Metadata.BuildMetadata(0))
{
if (keyValueGetter != null)
{
AnnotationUtils.AnnotationGetter<VBuffer<ReadOnlyMemory<char>>> mdGetter =
(int c, ref VBuffer<ReadOnlyMemory<char>> dst) => keyValueGetter(ref dst);
bldr.AddGetter(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(TextDataViewType.Instance, _typeDst.GetItemType().GetKeyCountAsInt32(Host)), mdGetter);
}
if (slotNamesGetter != null)
{
int vectorSize = _typeDst.GetVectorSize();
Host.Assert(vectorSize > 0);
AnnotationUtils.AnnotationGetter<VBuffer<ReadOnlyMemory<char>>> mdGetter =
(int c, ref VBuffer<ReadOnlyMemory<char>> dst) => slotNamesGetter(ref dst);
bldr.AddGetter(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, vectorSize), mdGetter);
}
}
}
Metadata.Seal();
}
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.Assert(false, "Shouldn't serialize this!");
throw Host.ExceptNotSupp("Shouldn't serialize this");
}
protected override DataViewType GetColumnTypeCore(int iinfo)
{
Host.Assert(iinfo == 0);
return _typeDst;
}
protected override Delegate GetGetterCore(IChannel ch, DataViewRow input, int iinfo, out Action disposer)
{
Host.AssertValueOrNull(ch);
Host.AssertValue(input);
Host.Assert(iinfo == 0);
disposer = null;
if (_map2 == null)
{
var getSrc = GetSrcGetter<T1>(input, 0);
T1 v1 = default(T1);
ValueGetter<T2> getter =
(ref T2 v2) =>
{
getSrc(ref v1);
_map1(in v1, ref v2);
};
return getter;
}
else
{
var getSrc = GetSrcGetter<T1>(input, 0);
T1 v1 = default(T1);
T2 v2 = default(T2);
ValueGetter<T3> getter =
(ref T3 v3) =>
{
getSrc(ref v1);
_map1(in v1, ref v2);
_map2(in v2, ref v3);
};
return getter;
}
}
}
}
}
|