|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
{
[BestFriend]
internal static class InvertHashUtils
{
    /// <summary>
    /// Resets <paramref name="dst"/> to an empty builder, allocating a new one when it is null.
    /// </summary>
    private static void ClearDst(ref StringBuilder dst)
    {
        Contracts.AssertValueOrNull(dst);
        // StringBuilder.Clear returns the same instance, so an existing builder is reused.
        dst = dst?.Clear() ?? new StringBuilder();
    }

    /// <summary>
    /// Builds a mapper that renders values of column <paramref name="col"/> as text. The
    /// resulting text becomes one component of the composed KeyValues for the hash outputs.
    /// Three strategies are tried, in this order:
    /// non-key item types use the standard string conversion; key types that have KeyValues
    /// metadata emit the corresponding key name; other key types use the key string conversion.
    /// </summary>
    /// <param name="schema">The schema containing the column.</param>
    /// <param name="col">Index of the column within <paramref name="schema"/>.</param>
    /// <returns>A mapper from <typeparamref name="T"/> into a <see cref="StringBuilder"/>.</returns>
    public static ValueMapper<T, StringBuilder> GetSimpleMapper<T>(DataViewSchema schema, int col)
    {
        Contracts.AssertValue(schema);
        Contracts.Assert(0 <= col && col < schema.Count);

        var itemType = schema[col].Type.GetItemType();
        Contracts.Assert(itemType.RawType == typeof(T));
        var conversions = Conversion.Conversions.DefaultInstance;

        if (itemType is KeyDataViewType keyType)
        {
            // Key without KeyValues metadata: fall back to the key string conversion.
            if (!schema[col].HasKeyValues())
                return conversions.GetKeyStringConversion<T>(keyType);

            // REVIEW: Non-textual KeyValues are certainly possible. Should we handle them?
            // Key with KeyValues metadata: fetch the names once, and look them up per value.
            VBuffer<ReadOnlyMemory<char>> names = default;
            schema[col].GetKeyValues(ref names);
            ReadOnlyMemory<char> name = default;

            // REVIEW: We could optimize for identity, but it's probably not worthwhile.
            var toUInt32 = conversions.GetStandardConversion<T, uint>(itemType, NumberDataViewType.UInt32, out bool isIdentity);
            return
                (in T src, ref StringBuilder dst) =>
                {
                    ClearDst(ref dst);
                    uint key = 0;
                    toUInt32(in src, ref key);
                    // A converted value of zero produces empty text; otherwise the name
                    // sits at index (key - 1).
                    if (key != 0)
                    {
                        names.GetItemOrDefault((int)(key - 1), ref name);
                        dst.AppendMemory(name);
                    }
                };
        }

        // Not a key at all: the standard string conversion is sufficient.
        return conversions.GetStringConversion<T>(itemType);
    }

    /// <summary>
    /// Wraps <paramref name="submap"/> into a mapper over (index, value) pairs. The output is
    /// the pair's key, then a ':' character, then whatever <paramref name="submap"/> produces
    /// for the pair's value.
    /// </summary>
    public static ValueMapper<KeyValuePair<int, T>, StringBuilder> GetPairMapper<T>(ValueMapper<T, StringBuilder> submap)
    {
        // Scratch state owned by the returned delegate and reused across invocations.
        StringBuilder valueText = null;
        char[] scratch = null;
        return
            (in KeyValuePair<int, T> pair, ref StringBuilder dst) =>
            {
                ClearDst(ref dst);
                dst.Append(pair.Key);
                dst.Append(':');
                var item = pair.Value;
                submap(in item, ref valueText);
                AppendToEnd(valueText, dst, ref scratch);
            };
    }

    /// <summary>
    /// Appends the content of <paramref name="src"/> to <paramref name="dst"/>, staging the
    /// characters through <paramref name="buffer"/>. Nothing is appended when
    /// <c>Utils.Size(src)</c> is not positive.
    /// </summary>
    /// <param name="src">The builder to read from.</param>
    /// <param name="dst">The builder to append to.</param>
    /// <param name="buffer">Reusable scratch array, passed through <c>Utils.EnsureSize</c>
    /// before use.</param>
    public static void AppendToEnd(StringBuilder src, StringBuilder dst, ref char[] buffer)
    {
        // A direct sb -> sb copy sure would be nice...
        if (Utils.Size(src) <= 0)
            return;

        int length = src.Length;
        Utils.EnsureSize(ref buffer, length);
        src.CopyTo(0, buffer, 0, length);
        dst.Append(buffer, 0, length);
    }
}
/// <summary>
/// Collects, per destination slot, up to a fixed number of distinct values, and then renders
/// the collected values of every slot as text via <see cref="GetMetadata"/>.
/// </summary>
[BestFriend]
internal sealed class InvertHashCollector<T>
{
    /// <summary>
    /// This is a small struct that is meant to compare akin to the value,
    /// but also maintain the order in which it was inserted, assuming that
    /// we're using something like a hashset where order is not preserved.
    /// </summary>
    private readonly struct Pair
    {
        public readonly T Value;
        public readonly int Order;

        public Pair(T value, int order)
        {
            Contracts.Assert(order >= 0);
            Value = value;
            Order = order;
        }
    }

    /// <summary>
    /// Compares and hashes <see cref="Pair"/> instances by their value only, so the insertion
    /// order does not participate in uniqueness.
    /// </summary>
    private sealed class PairEqualityComparer : IEqualityComparer<Pair>
    {
        private readonly IEqualityComparer<T> _tComparer;

        public PairEqualityComparer(IEqualityComparer<T> tComparer)
        {
            _tComparer = tComparer;
        }

        public bool Equals(Pair x, Pair y)
        {
            return _tComparer.Equals(x.Value, y.Value);
        }

        public int GetHashCode(Pair obj)
        {
            return _tComparer.GetHashCode(obj.Value);
        }
    }

    // Sorts pairs by insertion order. Created once rather than on every Textify call.
    private static readonly Comparer<Pair> _insertionOrder =
        Comparer<Pair>.Create((x, y) => x.Order.CompareTo(y.Order));

    // The maximum number of distinct keys to accumulate per slot.
    private readonly int _maxCount;
    // The maximum number of slots.
    private readonly int _slots;
    private readonly ValueMapper<T, StringBuilder> _stringifyMapper;
    // REVIEW: The following is very general but inefficient. If perf is a problem, then this
    // is one clear place where it should be helped.
    private readonly Dictionary<int, HashSet<Pair>> _slotToValueSet;
    private readonly IEqualityComparer<Pair> _comparer;
    private readonly ValueMapper<T, T> _copier;

    /// <summary>
    /// Constructs an invert hash collector that collects unique keys per slot, then is able
    /// to build a textual description out of that.
    /// </summary>
    /// <param name="slots">The maximum number of slots</param>
    /// <param name="maxCount">The number of distinct keys we can accumulate per slot</param>
    /// <param name="mapper">Utilized in composing the final description, once we have done
    /// collecting the distinct keys.</param>
    /// <param name="comparer">For detecting uniqueness of the keys we're collecting per slot.</param>
    /// <param name="copier">For copying input values into a value to actually store. Useful for
    /// types of objects where it is possible to do a comparison relatively quickly on some sort
    /// of "unsafe" object, but for which when we decide to actually store it we need to provide
    /// a "safe" version of the object. Utilized in the n-gram hash transform, for example.
    /// When null, values are stored by plain assignment.</param>
    public InvertHashCollector(int slots, int maxCount, ValueMapper<T, StringBuilder> mapper,
        IEqualityComparer<T> comparer, ValueMapper<T, T> copier = null)
    {
        Contracts.Assert(slots > 0);
        Contracts.Assert(maxCount > 0);
        Contracts.AssertValue(mapper);
        Contracts.AssertValue(comparer);

        _slots = slots;
        _maxCount = maxCount;
        _stringifyMapper = mapper;
        _comparer = new PairEqualityComparer(comparer);
        _slotToValueSet = new Dictionary<int, HashSet<Pair>>();
        _copier = copier ?? ((in T src, ref T dst) => dst = src);
    }

    /// <summary>
    /// Renders the values in <paramref name="pairs"/> as text and clears the set. A single
    /// value is rendered as-is; multiple values are rendered in insertion order as a
    /// comma-separated list enclosed in braces. The remaining parameters are reusable scratch
    /// objects, each of which may be null on entry.
    /// </summary>
    private ReadOnlyMemory<char> Textify(ref StringBuilder sb, ref StringBuilder temp, ref char[] cbuffer, ref Pair[] buffer, HashSet<Pair> pairs)
    {
        Contracts.AssertValueOrNull(sb);
        Contracts.AssertValueOrNull(temp);
        Contracts.AssertValueOrNull(cbuffer);
        Contracts.AssertValueOrNull(buffer);
        Contracts.Assert(Utils.Size(pairs) > 0);
        int count = pairs.Count;

        // Keep things in the same order they were inserted, by sorting on order.
        Utils.EnsureSize(ref buffer, count);
        pairs.CopyTo(buffer);
        // The set is emptied here, so its content is consumed by this call.
        pairs.Clear();

        // Optimize the one value case, where we don't have to use the string builder.
        if (count == 1)
        {
            var value = buffer[0].Value;
            _stringifyMapper(in value, ref temp);
            return Utils.Size(temp) > 0 ? temp.ToString().AsMemory() : String.Empty.AsMemory();
        }

        Array.Sort(buffer, 0, count, _insertionOrder);
        if (sb == null)
            sb = new StringBuilder();
        Contracts.Assert(sb.Length == 0);

        // The more general collision case.
        sb.Append('{');
        for (int i = 0; i < count; ++i)
        {
            if (i > 0)
                sb.Append(',');
            var value = buffer[i].Value;
            _stringifyMapper(in value, ref temp);
            InvertHashUtils.AppendToEnd(temp, sb, ref cbuffer);
        }
        sb.Append('}');

        var retval = sb.ToString().AsMemory();
        // Leave the builder empty for the next call.
        sb.Clear();
        return retval;
    }

    /// <summary>
    /// Builds the textual description of every slot that received at least one value. The
    /// result has logical length equal to the slot count, and is sparse when at most half of
    /// the slots are populated, dense otherwise.
    /// NOTE: producing the text clears each slot's value set, so the collected values are
    /// consumed by this call.
    /// </summary>
    public VBuffer<ReadOnlyMemory<char>> GetMetadata()
    {
        int count = _slotToValueSet.Count;
        Contracts.Assert(count <= _slots);

        // Scratch objects shared across all Textify calls.
        StringBuilder sb = null;
        StringBuilder temp = null;
        Pair[] pairs = null;
        char[] cbuffer = null;

        bool sparse = count <= _slots / 2;
        if (sparse)
        {
            // Sparse
            var indices = new int[count];
            var values = new ReadOnlyMemory<char>[count];
            int i = 0;
            foreach (var p in _slotToValueSet)
            {
                Contracts.Assert(0 <= p.Key && p.Key < _slots);
                indices[i] = p.Key;
                values[i++] = Textify(ref sb, ref temp, ref cbuffer, ref pairs, p.Value);
            }
            Contracts.Assert(i == count);
            // Dictionary enumeration order is unspecified; sort the indices, carrying the values along.
            Array.Sort(indices, values);
            return new VBuffer<ReadOnlyMemory<char>>(_slots, count, values, indices);
        }
        else
        {
            // Dense
            var values = new ReadOnlyMemory<char>[_slots];
            foreach (var p in _slotToValueSet)
            {
                Contracts.Assert(0 <= p.Key && p.Key < _slots);
                values[p.Key] = Textify(ref sb, ref temp, ref cbuffer, ref pairs, p.Value);
            }
            return new VBuffer<ReadOnlyMemory<char>>(values.Length, values);
        }
    }

    /// <summary>
    /// Fetches the value set for <paramref name="dstSlot"/>, creating and registering it if
    /// the slot has not been seen before.
    /// </summary>
    /// <returns>False when the slot already holds the maximum number of distinct values, in
    /// which case nothing more should be added to it; true otherwise.</returns>
    private bool TryGetOpenSet(int dstSlot, out HashSet<Pair> pairSet)
    {
        Contracts.Assert(0 <= dstSlot && dstSlot < _slots);
        if (_slotToValueSet.TryGetValue(dstSlot, out pairSet))
            return pairSet.Count < _maxCount;

        pairSet = new HashSet<Pair>(_comparer);
        _slotToValueSet[dstSlot] = pairSet;
        return true;
    }

    /// <summary>
    /// Passes <paramref name="key"/> through the copier and adds the copy to
    /// <paramref name="pairSet"/>, tagged with its insertion order. A value the set already
    /// contains is left in place.
    /// </summary>
    private void Store(HashSet<Pair> pairSet, in T key)
    {
        T stored = default(T);
        _copier(in key, ref stored);
        pairSet.Add(new Pair(stored, pairSet.Count));
    }

    /// <summary>
    /// Adds the value produced by <paramref name="getter"/> to slot <paramref name="dstSlot"/>.
    /// The getter is invoked only when the slot still has room, in which case
    /// <paramref name="key"/> holds the fetched value on return. The value retained by the
    /// collector goes through the copier, the same as in <see cref="Add(int, T)"/>, so the
    /// caller remains free to reuse <paramref name="key"/>.
    /// </summary>
    public void Add(int dstSlot, ValueGetter<T> getter, ref T key)
    {
        // REVIEW: I only call the getter if I determine I have to, but
        // at the cost of passing along this getter and ref argument (as opposed
        // to just the argument). Is this really appropriate or helpful?
        if (!TryGetOpenSet(dstSlot, out HashSet<Pair> pairSet))
            return;
        getter(ref key);
        Store(pairSet, in key);
    }

    /// <summary>
    /// Adds a copy of <paramref name="key"/> to slot <paramref name="dstSlot"/>, unless the
    /// slot already holds the maximum number of distinct values.
    /// </summary>
    public void Add(int dstSlot, T key)
    {
        if (!TryGetOpenSet(dstSlot, out HashSet<Pair> pairSet))
            return;
        Store(pairSet, in key);
    }

    /// <summary>
    /// Convenience method for those where the inserters work in the hash space, not the
    /// slot space, assuming that hash value of 0 gets no key. A non-zero hash maps to slot
    /// <c>hash - 1</c>.
    /// </summary>
    public void Add(uint hash, ValueGetter<T> getter, ref T key)
    {
        if (hash == 0)
            return;
        Add((int)hash - 1, getter, ref key);
    }

    /// <summary>
    /// Hash-space counterpart of <see cref="Add(int, T)"/>: a hash of 0 is ignored, and any
    /// other hash maps to slot <c>hash - 1</c>.
    /// </summary>
    public void Add(uint hash, T key)
    {
        if (hash == 0)
            return;
        Add((int)hash - 1, key);
    }
}
/// <summary>
/// Simple utility class for saving a <see cref="VBuffer{T}"/> of ReadOnlyMemory
/// as a model, both in a binary and more easily human readable form.
/// </summary>
[BestFriend]
internal static class TextModelHelper
{
    private const string LoaderSignature = "TextSpanBuffer";

    // Sub-model directory name, formatted with the index of the info. Shared by LoadAll and
    // SaveAll so the two always agree on where a vocabulary lives.
    private const string VocabularyDirFormat = "Vocabulary_{0:000}";

    private static VersionInfo GetVersionInfo()
    {
        return new VersionInfo(
            modelSignature: "TEXTSPBF",
            verWrittenCur: 0x00010001, // Initial
            verReadableCur: 0x00010001,
            verWeCanReadBack: 0x00010001,
            loaderSignature: LoaderSignature,
            loaderAssemblyName: typeof(TextModelHelper).Assembly.FullName);
    }

    /// <summary>
    /// Reads a single text vector from <paramref name="ctx"/> into <paramref name="values"/>.
    /// Raises a decode exception through <paramref name="ch"/> when the codec cannot be read,
    /// is not a codec over a vector of text, or the encoded block is not consumed exactly.
    /// </summary>
    private static void Load(IChannel ch, ModelLoadContext ctx, CodecFactory factory, ref VBuffer<ReadOnlyMemory<char>> values)
    {
        Contracts.AssertValue(ch);
        ch.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel(GetVersionInfo());

        // *** Binary format ***
        // Codec parameterization: A codec parameterization that should be a ReadOnlyMemory codec
        // int: n, the number of bytes used to write the values
        // byte[n]: As encoded using the codec

        // Get the codec from the factory, and from the stream. We have to
        // attempt to read the codec from the stream, since codecs can potentially
        // be versioned based on their parameterization.
        IValueCodec codec;
        // A failure here *could* happen if an old version attempts to read a model written
        // by a new version. Being able to detect that is why the codec specification is
        // written to the stream as well.
        if (!factory.TryReadCodec(ctx.Reader.BaseStream, out codec))
            throw ch.ExceptDecode();
        ch.AssertValue(codec);
        if (!(codec.Type is VectorDataViewType vectorType))
            throw ch.ExceptDecode();
        ch.CheckDecode(vectorType.ItemType is TextDataViewType);
        var textCodec = (IValueCodec<VBuffer<ReadOnlyMemory<char>>>)codec;

        var bufferLen = ctx.Reader.ReadInt32();
        ch.CheckDecode(bufferLen >= 0);
        using (var stream = new SubsetStream(ctx.Reader.BaseStream, bufferLen))
        {
            using (var reader = textCodec.OpenReader(stream, 1))
            {
                reader.MoveNext();
                values = default(VBuffer<ReadOnlyMemory<char>>);
                reader.Get(ref values);
            }
            // The codec must have consumed the whole block: no bytes may remain.
            ch.CheckDecode(stream.ReadByte() == -1);
        }
    }

    /// <summary>
    /// Writes <paramref name="values"/> to <paramref name="ctx"/> in the binary format that
    /// <see cref="Load"/> reads, and additionally writes a "Terms.txt" text stream listing
    /// the non-empty entries.
    /// </summary>
    private static void Save(IChannel ch, ModelSaveContext ctx, CodecFactory factory, in VBuffer<ReadOnlyMemory<char>> values)
    {
        Contracts.AssertValue(ch);
        ch.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel();
        ctx.SetVersionInfo(GetVersionInfo());

        // *** Binary format ***
        // Codec parameterization: A codec parameterization that should be a ReadOnlyMemory codec
        // int: n, the number of bytes used to write the values
        // byte[n]: As encoded using the codec

        // Get the codec from the factory
        IValueCodec codec;
        var result = factory.TryGetCodec(new VectorDataViewType(TextDataViewType.Instance), out codec);
        ch.Assert(result);
        VectorDataViewType vectorType = (VectorDataViewType)codec.Type;
        ch.Assert(vectorType.Size == 0);
        ch.Assert(vectorType.ItemType == TextDataViewType.Instance);
        IValueCodec<VBuffer<ReadOnlyMemory<char>>> textCodec = (IValueCodec<VBuffer<ReadOnlyMemory<char>>>)codec;

        factory.WriteCodec(ctx.Writer.BaseStream, codec);
        // Encode into memory first, and then write the encoded bytes as a single byte array.
        using (var mem = new MemoryStream())
        {
            using (var writer = textCodec.OpenWriter(mem))
            {
                writer.Write(in values);
                writer.Commit();
            }
            ctx.Writer.WriteByteArray(mem.ToArray());
        }

        // Make this resemble, more or less, the auxiliary output from the TermTransform.
        // It will differ somewhat due to the vector being possibly sparse. Entries with
        // empty text are not written at all.
        // An "in" parameter cannot be captured by the lambda below, hence the local copy.
        var v = values;
        char[] buffer = null;
        ctx.SaveTextStream("Terms.txt",
            writer =>
            {
                writer.WriteLine("# Number of terms = {0} of length {1}", v.GetValues().Length, v.Length);
                foreach (var pair in v.Items())
                {
                    var text = pair.Value;
                    if (text.IsEmpty)
                        continue;
                    writer.Write("{0}\t", pair.Key);
                    // REVIEW: What about escaping this, *especially* for linebreaks?
                    // Do C# and .NET really have no equivalent to Python's "repr"? :(
                    Utils.EnsureSize(ref buffer, text.Length);
                    text.Span.CopyTo(buffer);
                    writer.WriteLine(buffer, 0, text.Length);
                }
            });
    }

    /// <summary>
    /// Attempts to load one text vector per info index in [0, <paramref name="infoLim"/>),
    /// each from its own sub-model of <paramref name="ctx"/>.
    /// </summary>
    /// <param name="host">The host, used to start a channel and to create the codec factory.</param>
    /// <param name="ctx">The load context holding the sub-models.</param>
    /// <param name="infoLim">The exclusive upper bound on the info index.</param>
    /// <param name="keyValues">The loaded vectors, indexed by info. The array is allocated
    /// only once the first sub-model callback runs, so it is null if none ever runs.</param>
    /// <param name="kvTypes">For each loaded vector, a vector-of-text type whose size is that
    /// vector's length. Allocated together with <paramref name="keyValues"/>, so it is null
    /// under the same condition.</param>
    public static void LoadAll(IHost host, ModelLoadContext ctx, int infoLim, out VBuffer<ReadOnlyMemory<char>>[] keyValues, out VectorDataViewType[] kvTypes)
    {
        Contracts.AssertValue(host);
        host.AssertValue(ctx);

        using (var ch = host.Start("LoadTextValues"))
        {
            // Try to find the key names.
            VBuffer<ReadOnlyMemory<char>>[] keyValuesLocal = null;
            VectorDataViewType[] kvTypesLocal = null;
            CodecFactory factory = null;
            for (int iinfo = 0; iinfo < infoLim; iinfo++)
            {
                ctx.TryProcessSubModel(string.Format(VocabularyDirFormat, iinfo),
                    c =>
                    {
                        // Load the lazily initialized structures, if needed.
                        if (keyValuesLocal == null)
                        {
                            keyValuesLocal = new VBuffer<ReadOnlyMemory<char>>[infoLim];
                            kvTypesLocal = new VectorDataViewType[infoLim];
                            factory = new CodecFactory(host);
                        }
                        Load(ch, c, factory, ref keyValuesLocal[iinfo]);
                        kvTypesLocal[iinfo] = new VectorDataViewType(TextDataViewType.Instance, keyValuesLocal[iinfo].Length);
                    });
            }

            keyValues = keyValuesLocal;
            kvTypes = kvTypesLocal;
        }
    }

    /// <summary>
    /// Saves each vector of non-zero length in <paramref name="keyValues"/> as its own
    /// sub-model of <paramref name="ctx"/>. Does nothing when <paramref name="keyValues"/>
    /// is null.
    /// </summary>
    /// <param name="host">The host, used to start a channel and to create the codec factory.</param>
    /// <param name="ctx">The save context receiving the sub-models.</param>
    /// <param name="infoLim">The exclusive upper bound on the info index; every index below
    /// it is read from <paramref name="keyValues"/>.</param>
    /// <param name="keyValues">The vectors to save, indexed by info; may be null.</param>
    public static void SaveAll(IHost host, ModelSaveContext ctx, int infoLim, VBuffer<ReadOnlyMemory<char>>[] keyValues)
    {
        Contracts.AssertValue(host);
        host.AssertValue(ctx);
        host.AssertValueOrNull(keyValues);

        if (keyValues == null)
            return;

        using (var ch = host.Start("SaveTextValues"))
        {
            // Save the key names as separate submodels.
            CodecFactory factory = new CodecFactory(host);
            for (int iinfo = 0; iinfo < infoLim; iinfo++)
            {
                // Vectors of length zero get no sub-model.
                if (keyValues[iinfo].Length == 0)
                    continue;
                ctx.SaveSubModel(string.Format(VocabularyDirFormat, iinfo),
                    c => Save(ch, c, factory, in keyValues[iinfo]));
            }
        }
    }
}
}
|