|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
namespace Microsoft.ML.RunTests
{
public sealed class TestTransposer : TestDataPipeBase
{
public TestTransposer(ITestOutputHelper helper) : base(helper)
{
}
private static T[] NaiveTranspose<T>(IDataView view, int col)
{
var type = view.Schema[col].Type;
int rc = checked((int)DataViewUtils.ComputeRowCount(view));
var vecType = type as VectorDataViewType;
var itemType = vecType?.ItemType ?? type;
Assert.Equal(typeof(T), itemType.RawType);
Assert.NotEqual(0, vecType?.Size);
T[] retval = new T[rc * (vecType?.Size ?? 1)];
using (var cursor = view.GetRowCursor(view.Schema[col]))
{
if (type is VectorDataViewType)
{
var getter = cursor.GetGetter<VBuffer<T>>(cursor.Schema[col]);
VBuffer<T> temp = default;
int offset = 0;
while (cursor.MoveNext())
{
Assert.True(0 <= offset && offset < rc && offset == cursor.Position);
getter(ref temp);
var tempValues = temp.GetValues();
var tempIndices = temp.GetIndices();
for (int i = 0; i < tempValues.Length; ++i)
retval[(temp.IsDense ? i : tempIndices[i]) * rc + offset] = tempValues[i];
offset++;
}
}
else
{
var getter = cursor.GetGetter<T>(cursor.Schema[col]);
while (cursor.MoveNext())
{
Assert.True(0 <= cursor.Position && cursor.Position < rc);
getter(ref retval[(int)cursor.Position]);
}
}
}
return retval;
}
private static void TransposeCheckHelper<T>(IDataView view, int viewCol, ITransposeDataView trans)
{
Assert.NotNull(view);
Assert.NotNull(trans);
int col = viewCol;
VectorDataViewType type = trans.GetSlotType(col);
DataViewType colType = trans.Schema[col].Type;
Assert.Equal(view.Schema[viewCol].Name, trans.Schema[col].Name);
DataViewType expectedType = view.Schema[viewCol].Type;
Assert.Equal(expectedType, colType);
string desc = string.Format("Column {0} named '{1}'", col, trans.Schema[col].Name);
Assert.Equal(DataViewUtils.ComputeRowCount(view), type.Size);
Assert.True(typeof(T) == type.ItemType.RawType, $"{desc} had wrong type for slot cursor");
Assert.True(type.Size > 0, $"{desc} expected to be known sized vector but is not");
int valueCount = (colType as VectorDataViewType)?.Size ?? 1;
Assert.True(0 != valueCount, $"{desc} expected to have fixed size, but does not");
int rc = type.Size;
T[] expectedVals = NaiveTranspose<T>(view, viewCol);
T[] vals = new T[rc * valueCount];
Contracts.Assert(vals.Length == expectedVals.Length);
using (var cursor = trans.GetSlotCursor(col))
{
var getter = cursor.GetGetter<T>();
VBuffer<T> temp = default(VBuffer<T>);
int offset = 0;
while (cursor.MoveNext())
{
Assert.True(offset < vals.Length, $"{desc} slot cursor went further than it should have");
getter(ref temp);
Assert.True(rc == temp.Length, $"{desc} slot cursor yielded vector with unexpected length");
temp.CopyTo(vals, offset);
offset += rc;
}
Assert.True(valueCount == offset / rc, $"{desc} slot cursor yielded fewer than expected values");
}
for (int i = 0; i < vals.Length; ++i)
Assert.Equal(expectedVals[i], vals[i]);
}
private static VBuffer<T>[] GenerateHelper<T>(
int rowCount, Double density, Random rgen, Func<T> generator, int slotCount, params int[] forceDenseSlot)
{
HashSet<int> forceDenseSlotSet = new HashSet<int>(forceDenseSlot);
VBuffer<T>[] vecs = new VBuffer<T>[rowCount];
for (int r = 0; r < vecs.Length; ++r)
{
// Density controls both the prevelence of dense arrays, as well as the sparsity of the sparse arrays.
if (rgen.NextDouble() < density)
{
// Must be dense.
T[] vals = new T[slotCount];
for (int i = 0; i < vals.Length; ++i)
vals[i] = generator();
vecs[r] = new VBuffer<T>(slotCount, vals);
}
else
{
// Must be sparse.
List<int> indices = new List<int>();
for (int i = 0; i < slotCount; ++i)
{
if (forceDenseSlotSet.Contains(i) || rgen.NextDouble() < density)
indices.Add(i);
}
T[] vals = new T[indices.Count];
for (int i = 0; i < vals.Length; ++i)
vals[i] = generator();
vecs[r] = new VBuffer<T>(slotCount, indices.Count, vals, indices.ToArray());
}
}
return vecs;
}
private static T[] GenerateHelper<T>(int rowCount, Double density, Random rgen, Func<T> generator)
{
T[] values = new T[rowCount];
for (int r = 0; r < values.Length; ++r)
{
if (rgen.NextDouble() < density)
values[r] = generator();
}
return values;
}
[Fact]
[TestCategory("Transposer")]
public void TransposerTest()
{
const int rowCount = 1000;
Random rgen = new Random(0);
ArrayDataViewBuilder builder = new ArrayDataViewBuilder(Env);
// A is to check the splitting of a sparse-ish column.
var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (int)rgen.Next(), 50, 5, 10, 15);
dataA[rowCount / 2] = new VBuffer<int>(50, 0, null, null); // Coverage for the null vbuffer case.
builder.AddColumn("A", NumberDataViewType.Int32, dataA);
// B is to check the splitting of a dense-ish column.
builder.AddColumn("B", NumberDataViewType.Double, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49));
// C is to just have some column we do nothing with.
builder.AddColumn("C", NumberDataViewType.Int16, GenerateHelper(rowCount, 0.1, rgen, () => (short)1, 30, 3, 10, 24));
// D is to check some column we don't have to split because it's sufficiently small.
builder.AddColumn("D", NumberDataViewType.Double, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1));
// E is to check a sparse scalar column.
builder.AddColumn("E", NumberDataViewType.UInt32, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue)));
// F is to check a dense-ish scalar column.
builder.AddColumn("F", NumberDataViewType.Int32, GenerateHelper(rowCount, 0.8, rgen, () => rgen.Next()));
IDataView view = builder.GetDataView();
// Do not force save. This will have a mix of passthrough and saved columns. Note that duplicate
// specification of "D" to test that specifying a column twice has no ill effects.
string[] names = { "B", "A", "E", "D", "F", "D" };
using (Transposer trans = Transposer.Create(Env, view, false, names))
{
// Before checking the contents, check the names.
for (int i = 0; i < names.Length; ++i)
{
int index;
Assert.True(trans.Schema.TryGetColumnIndex(names[i], out index), $"Transpose schema couldn't find column '{names[i]}'");
int trueIndex;
bool result = view.Schema.TryGetColumnIndex(names[i], out trueIndex);
Contracts.Assert(result);
Assert.True(trueIndex == index, $"Transpose schema had column '{names[i]}' at unexpected index");
}
// Check the contents
Assert.Null(((ITransposeDataView)trans).GetSlotType(2)); // C check to see that it's not transposable.
TransposeCheckHelper<int>(view, 0, trans); // A check.
TransposeCheckHelper<Double>(view, 1, trans); // B check.
TransposeCheckHelper<Double>(view, 3, trans); // D check.
TransposeCheckHelper<uint>(view, 4, trans); // E check.
TransposeCheckHelper<int>(view, 5, trans); // F check.
}
// Force save. Recheck columns that would have previously been passthrough columns.
// The primary benefit of this check is that we check the binary saving / loading
// functionality of scalars which are otherwise always must necessarily be
// passthrough. Also exercise the select by index functionality while we're at it.
using (Transposer trans = Transposer.Create(Env, view, true, 3, 5, 4))
{
// Check to see that A, B, and C were not transposed somehow.
var itdv = (ITransposeDataView)trans;
Assert.Null(itdv.GetSlotType(0));
Assert.Null(itdv.GetSlotType(1));
Assert.Null(itdv.GetSlotType(2));
TransposeCheckHelper<Double>(view, 3, trans); // D check.
TransposeCheckHelper<uint>(view, 4, trans); // E check.
TransposeCheckHelper<int>(view, 5, trans); // F check.
}
}
[Fact]
[TestCategory("Transposer")]
public void TransposerSaverLoaderTest()
{
const int rowCount = 1000;
Random rgen = new Random(1);
ArrayDataViewBuilder builder = new ArrayDataViewBuilder(Env);
// A is to check the splitting of a sparse-ish column.
var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (int)rgen.Next(), 50, 5, 10, 15);
dataA[rowCount / 2] = new VBuffer<int>(50, 0, null, null); // Coverage for the null vbuffer case.
builder.AddColumn("A", NumberDataViewType.Int32, dataA);
// B is to check the splitting of a dense-ish column.
builder.AddColumn("B", NumberDataViewType.Double, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49));
// C is to just have some column we do nothing with.
builder.AddColumn("C", NumberDataViewType.Int16, GenerateHelper(rowCount, 0.1, rgen, () => (short)1, 30, 3, 10, 24));
// D is to check some column we don't have to split because it's sufficiently small.
builder.AddColumn("D", NumberDataViewType.Double, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1));
// E is to check a sparse scalar column.
builder.AddColumn("E", NumberDataViewType.UInt32, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue)));
// F is to check a dense-ish scalar column.
builder.AddColumn("F", NumberDataViewType.Int32, GenerateHelper(rowCount, 0.8, rgen, () => (int)rgen.Next()));
IDataView view = builder.GetDataView();
IMultiStreamSource src;
using (MemoryStream mem = new MemoryStream())
{
TransposeSaver saver = new TransposeSaver(Env, new TransposeSaver.Arguments());
saver.SaveData(mem, view, Utils.GetIdentityPermutation(view.Schema.Count));
src = new BytesStreamSource(mem.ToArray());
}
TransposeLoader loader = new TransposeLoader(Env, new TransposeLoader.Arguments(), src);
// First check whether this as an IDataView yields the same values.
CheckSameValues(view, loader);
TransposeCheckHelper<int>(view, 0, loader); // A
TransposeCheckHelper<Double>(view, 1, loader); // B
TransposeCheckHelper<short>(view, 2, loader); // C
TransposeCheckHelper<Double>(view, 3, loader); // D
TransposeCheckHelper<uint>(view, 4, loader); // E
TransposeCheckHelper<int>(view, 5, loader); // F
Done();
}
}
}
|