|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Data.Analysis.Tests;
using Microsoft.ML;
using Microsoft.ML.Data;
using Xunit;
using Microsoft.ML.Trainers;
using Xunit.Abstractions;
using Microsoft.ML.TestFramework;
namespace Microsoft.Data.Analysis.Tests
{
public partial class DataFrameIDataViewTests : BaseTestClass
{
public DataFrameIDataViewTests(ITestOutputHelper output) : base(output, true)
{
}
[Fact]
public void TestIDataView()
{
IDataView dataView = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false);
DataDebuggerPreview preview = dataView.Preview();
Assert.Equal(10, preview.RowView.Length);
Assert.Equal(17, preview.ColumnView.Length);
Assert.Equal("Byte", preview.ColumnView[0].Column.Name);
Assert.Equal((byte)0, preview.ColumnView[0].Values[0]);
Assert.Equal((byte)1, preview.ColumnView[0].Values[1]);
Assert.Equal("Decimal", preview.ColumnView[1].Column.Name);
Assert.Equal((double)0, preview.ColumnView[1].Values[0]);
Assert.Equal((double)1, preview.ColumnView[1].Values[1]);
Assert.Equal("Double", preview.ColumnView[2].Column.Name);
Assert.Equal((double)0, preview.ColumnView[2].Values[0]);
Assert.Equal((double)1, preview.ColumnView[2].Values[1]);
Assert.Equal("Float", preview.ColumnView[3].Column.Name);
Assert.Equal((float)0, preview.ColumnView[3].Values[0]);
Assert.Equal((float)1, preview.ColumnView[3].Values[1]);
Assert.Equal("Int", preview.ColumnView[4].Column.Name);
Assert.Equal((int)0, preview.ColumnView[4].Values[0]);
Assert.Equal((int)1, preview.ColumnView[4].Values[1]);
Assert.Equal("Long", preview.ColumnView[5].Column.Name);
Assert.Equal((long)0, preview.ColumnView[5].Values[0]);
Assert.Equal((long)1, preview.ColumnView[5].Values[1]);
Assert.Equal("Sbyte", preview.ColumnView[6].Column.Name);
Assert.Equal((sbyte)0, preview.ColumnView[6].Values[0]);
Assert.Equal((sbyte)1, preview.ColumnView[6].Values[1]);
Assert.Equal("Short", preview.ColumnView[7].Column.Name);
Assert.Equal((short)0, preview.ColumnView[7].Values[0]);
Assert.Equal((short)1, preview.ColumnView[7].Values[1]);
Assert.Equal("Uint", preview.ColumnView[8].Column.Name);
Assert.Equal((uint)0, preview.ColumnView[8].Values[0]);
Assert.Equal((uint)1, preview.ColumnView[8].Values[1]);
Assert.Equal("Ulong", preview.ColumnView[9].Column.Name);
Assert.Equal((ulong)0, preview.ColumnView[9].Values[0]);
Assert.Equal((ulong)1, preview.ColumnView[9].Values[1]);
Assert.Equal("Ushort", preview.ColumnView[10].Column.Name);
Assert.Equal((ushort)0, preview.ColumnView[10].Values[0]);
Assert.Equal((ushort)1, preview.ColumnView[10].Values[1]);
Assert.Equal("String", preview.ColumnView[11].Column.Name);
Assert.Equal("0".ToString(), preview.ColumnView[11].Values[0].ToString());
Assert.Equal("1".ToString(), preview.ColumnView[11].Values[1].ToString());
Assert.Equal("Char", preview.ColumnView[12].Column.Name);
Assert.Equal((ushort)65, preview.ColumnView[12].Values[0]);
Assert.Equal((ushort)66, preview.ColumnView[12].Values[1]);
Assert.Equal("DateTime", preview.ColumnView[13].Column.Name);
Assert.Equal(new DateTime(2021, 06, 04), preview.ColumnView[13].Values[0]);
Assert.Equal(new DateTime(2021, 06, 05), preview.ColumnView[13].Values[1]);
Assert.Equal("Bool", preview.ColumnView[14].Column.Name);
Assert.Equal(true, preview.ColumnView[14].Values[0]);
Assert.Equal(false, preview.ColumnView[14].Values[1]);
Assert.Equal("ArrowString", preview.ColumnView[15].Column.Name);
Assert.Equal("foo".ToString(), preview.ColumnView[15].Values[0].ToString());
Assert.Equal("foo".ToString(), preview.ColumnView[15].Values[1].ToString());
Assert.Equal("VBuffer", preview.ColumnView[16].Column.Name);
Assert.Equal("Dense vector of size 5", preview.ColumnView[16].Values[0].ToString());
Assert.Equal("Dense vector of size 5", preview.ColumnView[16].Values[1].ToString());
}
[Fact]
public void TestIDataViewSchemaInvalidate()
{
DataFrame df = DataFrameTests.MakeDataFrameWithAllMutableColumnTypes(10, withNulls: false);
IDataView dataView = df;
DataViewSchema schema = dataView.Schema;
Assert.Equal(15, schema.Count);
df.Columns.Remove("Bool");
schema = dataView.Schema;
Assert.Equal(14, schema.Count);
DataFrameColumn boolColumn = new PrimitiveDataFrameColumn<bool>("Bool", Enumerable.Range(0, (int)df.Rows.Count).Select(x => x % 2 == 1));
df.Columns.Insert(0, boolColumn);
schema = dataView.Schema;
Assert.Equal(15, schema.Count);
Assert.Equal("Bool", schema[0].Name);
DataFrameColumn boolClone = boolColumn.Clone();
boolClone.SetName("BoolClone");
df.Columns[1] = boolClone;
schema = dataView.Schema;
Assert.Equal("BoolClone", schema[1].Name);
}
[Fact]
public void TestIDataViewWithNulls()
{
int length = 10;
IDataView dataView = DataFrameTests.MakeDataFrameWithAllColumnTypes(length, withNulls: true);
DataDebuggerPreview preview = dataView.Preview();
Assert.Equal(length, preview.RowView.Length);
Assert.Equal(17, preview.ColumnView.Length);
Assert.Equal("Byte", preview.ColumnView[0].Column.Name);
Assert.Equal((byte)0, preview.ColumnView[0].Values[0]);
Assert.Equal((byte)1, preview.ColumnView[0].Values[1]);
Assert.Equal((byte)4, preview.ColumnView[0].Values[4]);
Assert.Equal((byte)0, preview.ColumnView[0].Values[5]); // null row
Assert.Equal((byte)6, preview.ColumnView[0].Values[6]);
Assert.Equal("Decimal", preview.ColumnView[1].Column.Name);
Assert.Equal((double)0, preview.ColumnView[1].Values[0]);
Assert.Equal((double)1, preview.ColumnView[1].Values[1]);
Assert.Equal((double)4, preview.ColumnView[1].Values[4]);
Assert.Equal(double.NaN, preview.ColumnView[1].Values[5]); // null row
Assert.Equal((double)6, preview.ColumnView[1].Values[6]);
Assert.Equal("Double", preview.ColumnView[2].Column.Name);
Assert.Equal((double)0, preview.ColumnView[2].Values[0]);
Assert.Equal((double)1, preview.ColumnView[2].Values[1]);
Assert.Equal((double)4, preview.ColumnView[2].Values[4]);
Assert.Equal(double.NaN, preview.ColumnView[2].Values[5]); // null row
Assert.Equal((double)6, preview.ColumnView[2].Values[6]);
Assert.Equal("Float", preview.ColumnView[3].Column.Name);
Assert.Equal((float)0, preview.ColumnView[3].Values[0]);
Assert.Equal((float)1, preview.ColumnView[3].Values[1]);
Assert.Equal((float)4, preview.ColumnView[3].Values[4]);
Assert.Equal(float.NaN, preview.ColumnView[3].Values[5]); // null row
Assert.Equal((float)6, preview.ColumnView[3].Values[6]);
Assert.Equal("Int", preview.ColumnView[4].Column.Name);
Assert.Equal((int)0, preview.ColumnView[4].Values[0]);
Assert.Equal((int)1, preview.ColumnView[4].Values[1]);
Assert.Equal((int)4, preview.ColumnView[4].Values[4]);
Assert.Equal((int)0, preview.ColumnView[4].Values[5]); // null row
Assert.Equal((int)6, preview.ColumnView[4].Values[6]);
Assert.Equal("Long", preview.ColumnView[5].Column.Name);
Assert.Equal((long)0, preview.ColumnView[5].Values[0]);
Assert.Equal((long)1, preview.ColumnView[5].Values[1]);
Assert.Equal((long)4, preview.ColumnView[5].Values[4]);
Assert.Equal((long)0, preview.ColumnView[5].Values[5]); // null row
Assert.Equal((long)6, preview.ColumnView[5].Values[6]);
Assert.Equal("Sbyte", preview.ColumnView[6].Column.Name);
Assert.Equal((sbyte)0, preview.ColumnView[6].Values[0]);
Assert.Equal((sbyte)1, preview.ColumnView[6].Values[1]);
Assert.Equal((sbyte)4, preview.ColumnView[6].Values[4]);
Assert.Equal((sbyte)0, preview.ColumnView[6].Values[5]); // null row
Assert.Equal((sbyte)6, preview.ColumnView[6].Values[6]);
Assert.Equal("Short", preview.ColumnView[7].Column.Name);
Assert.Equal((short)0, preview.ColumnView[7].Values[0]);
Assert.Equal((short)1, preview.ColumnView[7].Values[1]);
Assert.Equal((short)4, preview.ColumnView[7].Values[4]);
Assert.Equal((short)0, preview.ColumnView[7].Values[5]); // null row
Assert.Equal((short)6, preview.ColumnView[7].Values[6]);
Assert.Equal("Uint", preview.ColumnView[8].Column.Name);
Assert.Equal((uint)0, preview.ColumnView[8].Values[0]);
Assert.Equal((uint)1, preview.ColumnView[8].Values[1]);
Assert.Equal((uint)4, preview.ColumnView[8].Values[4]);
Assert.Equal((uint)0, preview.ColumnView[8].Values[5]); // null row
Assert.Equal((uint)6, preview.ColumnView[8].Values[6]);
Assert.Equal("Ulong", preview.ColumnView[9].Column.Name);
Assert.Equal((ulong)0, preview.ColumnView[9].Values[0]);
Assert.Equal((ulong)1, preview.ColumnView[9].Values[1]);
Assert.Equal((ulong)4, preview.ColumnView[9].Values[4]);
Assert.Equal((ulong)0, preview.ColumnView[9].Values[5]); // null row
Assert.Equal((ulong)6, preview.ColumnView[9].Values[6]);
Assert.Equal("Ushort", preview.ColumnView[10].Column.Name);
Assert.Equal((ushort)0, preview.ColumnView[10].Values[0]);
Assert.Equal((ushort)1, preview.ColumnView[10].Values[1]);
Assert.Equal((ushort)4, preview.ColumnView[10].Values[4]);
Assert.Equal((ushort)0, preview.ColumnView[10].Values[5]); // null row
Assert.Equal((ushort)6, preview.ColumnView[10].Values[6]);
Assert.Equal("String", preview.ColumnView[11].Column.Name);
Assert.Equal("0", preview.ColumnView[11].Values[0].ToString());
Assert.Equal("1", preview.ColumnView[11].Values[1].ToString());
Assert.Equal("4", preview.ColumnView[11].Values[4].ToString());
Assert.Equal("", preview.ColumnView[11].Values[5].ToString()); // null row
Assert.Equal("6", preview.ColumnView[11].Values[6].ToString());
Assert.Equal("Char", preview.ColumnView[12].Column.Name);
Assert.Equal((ushort)65, preview.ColumnView[12].Values[0]);
Assert.Equal((ushort)66, preview.ColumnView[12].Values[1]);
Assert.Equal((ushort)69, preview.ColumnView[12].Values[4]);
Assert.Equal((ushort)0, preview.ColumnView[12].Values[5]); // null row
Assert.Equal((ushort)71, preview.ColumnView[12].Values[6]);
Assert.Equal("DateTime", preview.ColumnView[13].Column.Name);
Assert.Equal(new DateTime(2021, 06, 04), preview.ColumnView[13].Values[0]);
Assert.Equal(new DateTime(2021, 06, 05), preview.ColumnView[13].Values[1]);
Assert.Equal(new DateTime(2021, 06, 08), preview.ColumnView[13].Values[4]);
Assert.Equal(new DateTime(), preview.ColumnView[13].Values[5]); // null row
Assert.Equal(new DateTime(2021, 06, 10), preview.ColumnView[13].Values[6]);
Assert.Equal("Bool", preview.ColumnView[14].Column.Name);
Assert.Equal(true, preview.ColumnView[14].Values[0]);
Assert.Equal(false, preview.ColumnView[14].Values[1]);
Assert.Equal(true, preview.ColumnView[14].Values[4]);
Assert.Equal(false, preview.ColumnView[14].Values[5]); // null row
Assert.Equal(true, preview.ColumnView[14].Values[6]);
Assert.Equal("ArrowString", preview.ColumnView[15].Column.Name);
Assert.Equal("foo", preview.ColumnView[15].Values[0].ToString());
Assert.Equal("foo", preview.ColumnView[15].Values[1].ToString());
Assert.Equal("foo", preview.ColumnView[15].Values[4].ToString());
Assert.Equal("", preview.ColumnView[15].Values[5].ToString()); // null row
Assert.Equal("foo", preview.ColumnView[15].Values[6].ToString());
Assert.Equal("VBuffer", preview.ColumnView[16].Column.Name);
Assert.True(preview.ColumnView[16].Values[0] is VBuffer<int>);
Assert.True(preview.ColumnView[16].Values[6] is VBuffer<int>);
}
[Fact]
public void TestDataFrameFromIDataView()
{
DataFrame df = DataFrameTests.MakeDataFrameWithAllMutableAndArrowColumnTypes(10, withNulls: false);
df.Columns.Remove("Char"); // Because chars are returned as uint16 by IDataView, so end up comparing CharDataFrameColumn to UInt16DataFrameColumn and fail asserts
IDataView dfAsIDataView = df;
DataFrame newDf = dfAsIDataView.ToDataFrame();
Assert.Equal(dfAsIDataView.GetRowCount(), newDf.Rows.Count);
Assert.Equal(dfAsIDataView.Schema.Count, newDf.Columns.Count);
for (int i = 0; i < df.Columns.Count; i++)
{
Assert.True(df.Columns[i].ElementwiseEquals(newDf.Columns[i]).All());
}
}
[Fact]
public void TestDataFrameFromIDataView_SelectColumns()
{
DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false);
IDataView dfAsIDataView = df;
DataFrame newDf = dfAsIDataView.ToDataFrame("Int", "Double");
Assert.Equal(dfAsIDataView.GetRowCount(), newDf.Rows.Count);
Assert.Equal(2, newDf.Columns.Count);
Assert.True(df.Columns["Int"].ElementwiseEquals(newDf.Columns["Int"]).All());
Assert.True(df.Columns["Double"].ElementwiseEquals(newDf.Columns["Double"]).All());
}
[Theory]
[InlineData(10, 5)]
[InlineData(110, 100)]
[InlineData(110, -1)]
public void TestDataFrameFromIDataView_SelectRows(int dataFrameSize, int rowSize)
{
DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(dataFrameSize, withNulls: false);
df.Columns.Remove("Char"); // Because chars are returned as uint16 by DataViewSchema, so end up comparing CharDataFrameColumn to UInt16DataFrameColumn and fail asserts
df.Columns.Remove("Decimal"); // Because decimal is returned as double by DataViewSchema, so end up comparing DecimalDataFrameColumn to DoubleDataFrameColumn and fail asserts
IDataView dfAsIDataView = df;
DataFrame newDf;
if (rowSize == 100)
{
// Test default
newDf = dfAsIDataView.ToDataFrame();
}
else
{
newDf = dfAsIDataView.ToDataFrame(rowSize);
}
if (rowSize == -1)
{
rowSize = dataFrameSize;
}
Assert.Equal(rowSize, newDf.Rows.Count);
Assert.Equal(df.Columns.Count, newDf.Columns.Count);
for (int i = 0; i < newDf.Columns.Count; i++)
{
Assert.Equal(rowSize, newDf.Columns[i].Length);
Assert.Equal(df.Columns[i].Name, newDf.Columns[i].Name);
}
Assert.Equal(dfAsIDataView.Schema.Count, newDf.Columns.Count);
for (int c = 0; c < df.Columns.Count; c++)
{
for (int r = 0; r < rowSize; r++)
{
Assert.Equal(df.Columns[c][r], newDf.Columns[c][r]);
}
}
}
[Fact]
public void TestDataFrameFromIDataView_SelectColumnsAndRows()
{
DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false);
IDataView dfAsIDataView = df;
DataFrame newDf = dfAsIDataView.ToDataFrame(5, "Int", "Double");
Assert.Equal(5, newDf.Rows.Count);
for (int i = 0; i < newDf.Columns.Count; i++)
{
Assert.Equal(5, newDf.Columns[i].Length);
}
Assert.Equal(2, newDf.Columns.Count);
for (int r = 0; r < 5; r++)
{
Assert.Equal(df.Columns["Int"][r], newDf.Columns["Int"][r]);
Assert.Equal(df.Columns["Double"][r], newDf.Columns["Double"][r]);
}
}
private class InputData
{
public string Name { get; set; }
public bool FilterNext { get; set; }
public float Value { get; set; }
}
private IDataView GetASampleIDataView()
{
var mlContext = new MLContext();
// Get a small dataset as an IEnumerable.
var enumerableOfData = new[]
{
new InputData() { Name = "Joey", FilterNext = false, Value = 1.0f },
new InputData() { Name = "Chandler", FilterNext = false , Value = 2.0f},
new InputData() { Name = "Ross", FilterNext = false , Value = 3.0f},
new InputData() { Name = "Monica", FilterNext = true , Value = 4.0f},
new InputData() { Name = "Rachel", FilterNext = true , Value = 5.0f},
new InputData() { Name = "Phoebe", FilterNext = false , Value = 6.0f},
};
IDataView data = mlContext.Data.LoadFromEnumerable(enumerableOfData);
return data;
}
private void VerifyDataFrameColumnAndDataViewColumnValues<T>(string columnName, IDataView data, DataFrame df, int maxRows = -1)
{
int cc = 0;
var nameDataViewColumn = data.GetColumn<T>(columnName);
foreach (var value in nameDataViewColumn)
{
if (maxRows != -1 && cc >= maxRows)
{
return;
}
Assert.Equal(value, df.Columns[columnName][cc++]);
}
}
[Fact]
public void TestDataFrameFromIDataView_MLData()
{
IDataView data = GetASampleIDataView();
DataFrame df = data.ToDataFrame();
Assert.Equal(6, df.Rows.Count);
Assert.Equal(3, df.Columns.Count);
foreach (var column in df.Columns)
{
Assert.Equal(6, column.Length);
}
VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df);
VerifyDataFrameColumnAndDataViewColumnValues<bool>("FilterNext", data, df);
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df);
}
[Fact]
public void TestDataFrameFromIDataView_MLData_SelectColumns()
{
IDataView data = GetASampleIDataView();
DataFrame df = data.ToDataFrame("Name", "Value");
Assert.Equal(6, df.Rows.Count);
Assert.Equal(2, df.Columns.Count);
foreach (var column in df.Columns)
{
Assert.Equal(6, column.Length);
}
VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df);
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df);
}
[Theory]
[InlineData(3)]
[InlineData(0)]
public void TestDataFrameFromIDataView_MLData_SelectRows(int maxRows)
{
IDataView data = GetASampleIDataView();
DataFrame df = data.ToDataFrame(maxRows);
Assert.Equal(maxRows, df.Rows.Count);
Assert.Equal(3, df.Columns.Count);
foreach (var column in df.Columns)
{
Assert.Equal(maxRows, column.Length);
}
VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df, maxRows);
VerifyDataFrameColumnAndDataViewColumnValues<bool>("FilterNext", data, df, maxRows);
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df, maxRows);
}
[Fact]
public void TestDataFrameFromIDataView_MLData_SelectColumnsAndRows()
{
IDataView data = GetASampleIDataView();
DataFrame df = data.ToDataFrame(3, "Name", "Value");
Assert.Equal(3, df.Rows.Count);
Assert.Equal(2, df.Columns.Count);
foreach (var column in df.Columns)
{
Assert.Equal(3, column.Length);
}
VerifyDataFrameColumnAndDataViewColumnValues<string>("Name", data, df, 3);
VerifyDataFrameColumnAndDataViewColumnValues<float>("Value", data, df, 3);
}
[Fact]
public void TestDataFrameFromIDataView_VBufferType()
{
var mlContext = new MLContext();
var inputData = new[]
{
new {
boolFeature = new bool[] {false, false},
byteFeatures = new byte[] {0, 0},
doubleFeatures = new double[] {0, 0},
floatFeatures = new float[] {0, 0},
intFeatures = new int[] {0, 0},
longFeatures = new long[] {0, 0},
sbyteFeatures = new sbyte[] {0, 0},
shortFeatures = new short[] {0, 0},
ushortFeatures = new ushort[] {0, 0},
uintFeatures = new uint[] {0, 0},
ulongFeatures = new ulong[] {0, 0},
stringFeatures = new string[]{ "A", "B"},
},
new {
boolFeature = new bool[] {false, false},
byteFeatures = new byte[] {0, 0},
doubleFeatures = new double[] {0, 0},
floatFeatures = new float[] {1, 1},
intFeatures = new int[] {0, 0},
longFeatures = new long[] {0, 0},
sbyteFeatures = new sbyte[] {0, 0},
shortFeatures = new short[] {0, 0},
ushortFeatures = new ushort[] {0, 0},
uintFeatures = new uint[] {0, 0},
ulongFeatures = new ulong[] {0, 0},
stringFeatures = new string[]{ "A", "B"},
}
};
var data = mlContext.Data.LoadFromEnumerable(inputData);
var df = data.ToDataFrame();
Assert.Equal(12, df.Columns.Count);
Assert.Equal(2, df.Rows.Count);
}
}
}
|