// File: CollectionsDataViewTest.cs
// Web Access
// Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Reflection;
using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.EntryPoints.Tests
{
    public class CollectionsDataViewTest : BaseTestClass
    {
        /// <summary>
        /// Creates the test class and forwards the xUnit output helper to <see cref="BaseTestClass"/>.
        /// </summary>
        /// <param name="output">Output helper passed straight through to the base constructor.</param>
        public CollectionsDataViewTest(ITestOutputHelper output)
            : base(output)
        {
        }
 
        /// <summary>
        /// Row type with one public field per primitive type (signed and unsigned integers,
        /// floating point, bool and string). Used by RoundTripConversionWithBasicTypes.
        /// </summary>
        public class ConversionSimpleClass
        {
            public int fInt;
            public uint fuInt;
            public short fShort;
            public ushort fuShort;
            public sbyte fsByte;
            public byte fByte;
            public long fLong;
            public ulong fuLong;
            public float fFloat;
            public double fDouble;
            public bool fBool;
            // Defaults to the empty string rather than null.
            public string fString = "";
        }
 
        /// <summary>
        /// Compares two boxed member values, treating a null string as the empty string.
        /// </summary>
        /// <param name="x">First value; may be null.</param>
        /// <param name="y">Second value; may be null.</param>
        /// <param name="type">Declared type of the member the values were read from.</param>
        /// <returns>
        /// True when both values are null (after string normalisation) or when <c>x.Equals(y)</c> is true.
        /// </returns>
        public bool CompareObjectValues(object x, object y, Type type)
        {
            // A string that goes string -> ReadOnlyMemory -> string comes back as "" even if it
            // started out as null, so for string members null and "" are treated as the same value.
            if (type == typeof(string))
            {
                x = x ?? "";
                y = y ?? "";
            }

            if (x == null)
                return y == null;

            return x.Equals(y);
        }
 
        /// <summary>
        /// Compares two instances of <typeparamref name="T"/> member by member using reflection.
        /// Every public instance field is compared, as is every public instance property that has
        /// both a public getter and a public setter.
        /// </summary>
        /// <typeparam name="T">Type whose public fields and properties are inspected.</typeparam>
        /// <param name="x">First instance.</param>
        /// <param name="y">Second instance.</param>
        /// <returns>True when every compared member matches; false at the first mismatch.</returns>
        public bool CompareThroughReflection<T>(T x, T y)
        {
            const BindingFlags publicInstance = BindingFlags.Public | BindingFlags.Instance;

            // Array members are compared element by element, everything else as a single value.
            bool ValuesMatch(object left, object right, Type memberType)
            {
                return memberType.IsArray
                    ? CompareArrayValues(left as Array, right as Array)
                    : CompareObjectValues(left, right, memberType);
            }

            foreach (var field in typeof(T).GetFields(publicInstance))
            {
                var left = field.GetValue(x);
                var right = field.GetValue(y);
                if (!ValuesMatch(left, right, field.FieldType))
                    return false;
            }

            foreach (var property in typeof(T).GetProperties(publicInstance))
            {
                // Properties whose getter or setter is missing or non-public are skipped.
                bool hasPublicAccessors = property.CanRead
                    && property.CanWrite
                    && property.GetGetMethod() != null
                    && property.GetSetMethod() != null;
                if (!hasPublicAccessors)
                    continue;

                var left = property.GetValue(x);
                var right = property.GetValue(y);
                if (!ValuesMatch(left, right, property.PropertyType))
                    return false;
            }

            return true;
        }
 
        /// <summary>
        /// Compares two arrays element by element using <c>CompareObjectValues</c>.
        /// </summary>
        /// <param name="x">First array; may be null.</param>
        /// <param name="y">Second array; may be null.</param>
        /// <returns>
        /// True when both arrays are null, or when they have the same length and every pair of
        /// elements compares equal; false otherwise.
        /// </returns>
        public bool CompareArrayValues(Array x, Array y)
        {
            // If either side is null, the arrays match only when both are null.
            if (x == null || y == null)
                return x == null && y == null;

            if (x.Length != y.Length)
                return false;

            // The element type of the first array decides how elements are compared
            // (for example, null strings are treated as empty strings).
            var elementType = x.GetType().GetElementType();
            int index = 0;
            while (index < x.Length)
            {
                if (!CompareObjectValues(x.GetValue(index), y.GetValue(index), elementType))
                    return false;
                index++;
            }

            return true;
        }
 
        /// <summary>
        /// Loads a list of <see cref="ConversionSimpleClass"/> rows into a data view, reads them back
        /// as the same type, and checks that every row matches the source and that the number of rows
        /// is unchanged.
        /// </summary>
        [Fact]
        public void RoundTripConversionWithBasicTypes()
        {
            var data = new List<ConversionSimpleClass>
            {
                new ConversionSimpleClass()
                {
                    fInt = int.MaxValue - 1,
                    fuInt = uint.MaxValue - 1,
                    fBool = true,
                    fsByte = sbyte.MaxValue - 1,
                    fByte = byte.MaxValue - 1,
                    fDouble = double.MaxValue - 1,
                    fFloat = float.MaxValue - 1,
                    fLong = long.MaxValue - 1,
                    fuLong = ulong.MaxValue - 1,
                    fShort = short.MaxValue - 1,
                    fuShort = ushort.MaxValue - 1,
                    fString = null
                },
                new ConversionSimpleClass()
                {
                    fInt = int.MaxValue,
                    fuInt = uint.MaxValue,
                    fBool = true,
                    fsByte = sbyte.MaxValue,
                    fByte = byte.MaxValue,
                    fDouble = double.MaxValue,
                    fFloat = float.MaxValue,
                    fLong = long.MaxValue,
                    fuLong = ulong.MaxValue,
                    fShort = short.MaxValue,
                    fuShort = ushort.MaxValue,
                    fString = "ooh"
                },
                new ConversionSimpleClass()
                {
                    fInt = int.MinValue + 1,
                    fuInt = uint.MinValue + 1,
                    fBool = false,
                    fsByte = sbyte.MinValue + 1,
                    fByte = byte.MinValue + 1,
                    fDouble = double.MinValue + 1,
                    fFloat = float.MinValue + 1,
                    fLong = long.MinValue + 1,
                    fuLong = ulong.MinValue + 1,
                    fShort = short.MinValue + 1,
                    fuShort = ushort.MinValue + 1,
                    fString = ""
                },
                new ConversionSimpleClass()
            };

            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);

            // The enumerator is disposed when the block exits. Driving the loop from the source list
            // and asserting on each MoveNext catches a row-count mismatch in either direction,
            // including a round-tripped sequence that is exactly one row too long.
            using (var roundTripped = env.Data.CreateEnumerable<ConversionSimpleClass>(dataView, false).GetEnumerator())
            {
                foreach (var expected in data)
                {
                    Assert.True(roundTripped.MoveNext(), "The data view produced fewer rows than the source list.");
                    Assert.True(CompareThroughReflection(roundTripped.Current, expected));
                }
                Assert.False(roundTripped.MoveNext(), "The data view produced more rows than the source list.");
            }
        }
 
        /// <summary>
        /// Row type with one public field per signed integer type. ConversionExceptionsBehavior sets
        /// each field in turn to its type's MinValue.
        /// </summary>
        public class ConversionNotSupportedMinValueClass
        {
            public int fInt;
            public long fLong;
            public short fShort;
            public sbyte fSByte;
        }
 
        /// <summary>
        /// For each public field of <see cref="ConversionNotSupportedMinValueClass"/>, builds a single
        /// row with that field set to its type's MinValue, loads it into a data view and advances an
        /// enumerator over the result.
        /// </summary>
        [Fact]
        public void ConversionExceptionsBehavior()
        {
            var env = new MLContext(1);
            var data = new ConversionNotSupportedMinValueClass[1];
            foreach (var field in typeof(ConversionNotSupportedMinValueClass).GetFields())
            {
                // A fresh instance each iteration, so only the current field holds MinValue.
                data[0] = new ConversionNotSupportedMinValueClass();
                FieldInfo fi;
                // Look up the public MinValue field on the member's own type (int, long, short, sbyte).
                if ((fi = field.FieldType.GetField("MinValue")) != null)
                {
                    field.SetValue(data[0], fi.GetValue(null));
                }
                var dataView = env.Data.LoadFromEnumerable(data);
                // NOTE(review): this enumerator is never disposed.
                var enumerator = env.Data.CreateEnumerable<ConversionNotSupportedMinValueClass>(dataView, false).GetEnumerator();
                try
                {
                    enumerator.MoveNext();
                    // NOTE(review): the intent appears to be "MoveNext must throw". However, the failure
                    // raised by Assert.True(false) is itself caught by the bare catch below, so this
                    // test passes whether or not MoveNext throws. Confirm the expected library
                    // behaviour before tightening this (e.g. with Assert.ThrowsAny).
                    Assert.True(false);
                }
                catch
                {
                }
            }
        }
 
        /// <summary>
        /// Row type exposing nullable signed-integer properties, each stored in a private
        /// nullable field.
        /// </summary>
        public class ConversionLossMinValueClassProperties
        {
            private int? _fInt;
            private long? _fLong;
            private short? _fShort;
            private sbyte? _fsByte;

            public int? IntProp { get => _fInt; set => _fInt = value; }
            public short? ShortProp { get => _fShort; set => _fShort = value; }
            public sbyte? SByteProp { get => _fsByte; set => _fsByte = value; }
            public long? LongProp { get => _fLong; set => _fLong = value; }
        }
 
        /// <summary>
        /// Row type that mixes public constants with public instance fields. Used by
        /// ClassWithConstFieldsConversion, which populates only the instance fields.
        /// </summary>
        public class ClassWithConstField
        {
            public const string ConstString = "N";
            public string fString;
            public const int ConstInt = 100;
            public int fInt;
        }
 
        /// <summary>
        /// Round-trips <see cref="ClassWithConstField"/> rows through a data view and checks that every
        /// row matches the source and that the number of rows is unchanged.
        /// </summary>
        [Fact]
        public void ClassWithConstFieldsConversion()
        {
            var data = new List<ClassWithConstField>()
            {
                new ClassWithConstField(){ fInt=1, fString ="lala" },
                new ClassWithConstField(){ fInt=-1, fString ="" },
            };

            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);

            // The enumerator is disposed when the block exits. Asserting on each MoveNext catches a
            // row-count mismatch in either direction, including exactly one extra round-tripped row.
            using (var roundTripped = env.Data.CreateEnumerable<ClassWithConstField>(dataView, false).GetEnumerator())
            {
                foreach (var expected in data)
                {
                    Assert.True(roundTripped.MoveNext(), "The data view produced fewer rows than the source list.");
                    Assert.True(CompareThroughReflection(roundTripped.Current, expected));
                }
                Assert.False(roundTripped.MoveNext(), "The data view produced more rows than the source list.");
            }
        }
 
 
        /// <summary>
        /// Row type combining a public string field with a public int property that is stored in a
        /// private field.
        /// </summary>
        public class ClassWithMixOfFieldsAndProperties
        {
            public string fString;
            private int _fInt;

            public int IntProp
            {
                get => _fInt;
                set => _fInt = value;
            }
        }
 
        /// <summary>
        /// Round-trips <see cref="ClassWithMixOfFieldsAndProperties"/> rows through a data view and
        /// checks that every row matches the source and that the number of rows is unchanged.
        /// </summary>
        [Fact]
        public void ClassWithMixOfFieldsAndPropertiesConversion()
        {
            var data = new List<ClassWithMixOfFieldsAndProperties>()
            {
                new ClassWithMixOfFieldsAndProperties(){ IntProp=1, fString ="lala" },
                new ClassWithMixOfFieldsAndProperties(){ IntProp=-1, fString ="" },
            };

            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);

            // The enumerator is disposed when the block exits. Asserting on each MoveNext catches a
            // row-count mismatch in either direction, including exactly one extra round-tripped row.
            using (var roundTripped = env.Data.CreateEnumerable<ClassWithMixOfFieldsAndProperties>(dataView, false).GetEnumerator())
            {
                foreach (var expected in data)
                {
                    Assert.True(roundTripped.MoveNext(), "The data view produced fewer rows than the source list.");
                    Assert.True(CompareThroughReflection(roundTripped.Current, expected));
                }
                Assert.False(roundTripped.MoveNext(), "The data view produced more rows than the source list.");
            }
        }
 
        /// <summary>
        /// Abstract base row type with one ordinary property, one abstract property and one
        /// virtual property.
        /// </summary>
        public abstract class BaseClassWithInheritedProperties
        {
            private string _fString;
            private byte _fByte;

            public string StringProp { get => _fString; set => _fString = value; }

            // Must be implemented by derived classes.
            public abstract long LongProp { get; set; }

            // May be overridden; the default implementation stores the value in a private field.
            public virtual byte ByteProp { get => _fByte; set => _fByte = value; }
        }
 
 
        /// <summary>
        /// Row type mixing static fields, private fields and properties with various accessor
        /// visibilities. Only <c>StringProp</c> has both a public getter and a public setter and
        /// no [NoColumn] attribute.
        /// </summary>
        public class ClassWithPrivateFieldsAndProperties
        {
            // Every construction bumps both static counters and sets the private field to 100.
            // ClassWithPrivateFieldsAndPropertiesConversion reads that value back through
            // UnusedPropertyWithPrivateSetter.
            public ClassWithPrivateFieldsAndProperties() { seq++; _unusedStaticField++; _unusedPrivateField1 = 100; }
            public static int seq;
            public static int _unusedStaticField;
            private int _unusedPrivateField1;
            private string _fString;

            [NoColumn]
            // This property can be used as a source for a DataView, but not when casting from a DataView to a collection.
            private int UnusedReadOnlyProperty { get { return _unusedPrivateField1; } }

            // This property is ignored because it is private.
            private int UnusedPrivateProperty { get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } }

            [NoColumn]
            // This property can be used as a source for a DataView, but not when casting from a DataView to a collection.
            public int UnusedPropertyWithPrivateSetter { get { return _unusedPrivateField1; } private set { _unusedPrivateField1 = value; } }

            [NoColumn]
            // This property can be used as a receptacle for a DataView, but not as a source for a DataView.
            public int UnusedPropertyWithPrivateGetter { private get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } }

            public string StringProp { get { return _fString; } set { _fString = value; } }
        }
 
        /// <summary>
        /// Round-trips <see cref="ClassWithPrivateFieldsAndProperties"/> rows through a data view.
        /// Checks that every row matches the source, that each round-tripped object still reports the
        /// constructor-assigned value 100 through <c>UnusedPropertyWithPrivateSetter</c>, and that the
        /// number of rows is unchanged.
        /// </summary>
        [Fact]
        public void ClassWithPrivateFieldsAndPropertiesConversion()
        {
            var data = new List<ClassWithPrivateFieldsAndProperties>()
            {
                new ClassWithPrivateFieldsAndProperties(){ StringProp ="lala" },
                new ClassWithPrivateFieldsAndProperties(){ StringProp ="baba" }
            };

            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);

            // The enumerator is disposed when the block exits. Asserting on each MoveNext catches a
            // row-count mismatch in either direction, including exactly one extra round-tripped row.
            using (var roundTripped = env.Data.CreateEnumerable<ClassWithPrivateFieldsAndProperties>(dataView, false).GetEnumerator())
            {
                foreach (var expected in data)
                {
                    Assert.True(roundTripped.MoveNext(), "The data view produced fewer rows than the source list.");
                    var actual = roundTripped.Current;
                    Assert.True(CompareThroughReflection(actual, expected));
                    // Assert.Equal reports the actual value on failure, unlike Assert.True(x == 100).
                    Assert.Equal(100, actual.UnusedPropertyWithPrivateSetter);
                }
                Assert.False(roundTripped.MoveNext(), "The data view produced more rows than the source list.");
            }
        }
 
        /// <summary>
        /// Concrete row type deriving from <see cref="BaseClassWithInheritedProperties"/>. It adds an
        /// int auto-property and overrides the inherited abstract and virtual properties, storing
        /// their values in its own private fields.
        /// </summary>
        public class ClassWithInheritedProperties : BaseClassWithInheritedProperties
        {
            private long _fLong;
            private byte _fByte2;

            public int IntProp { get; set; }

            public override long LongProp
            {
                get { return _fLong; }
                set { _fLong = value; }
            }

            public override byte ByteProp
            {
                get { return _fByte2; }
                set { _fByte2 = value; }
            }
        }
 
        /// <summary>
        /// Round-trips <see cref="ClassWithInheritedProperties"/> rows (own, inherited and overridden
        /// properties) through a data view and checks that every row matches the source and that the
        /// number of rows is unchanged.
        /// </summary>
        [Fact]
        public void ClassWithInheritedPropertiesConversion()
        {
            var data = new List<ClassWithInheritedProperties>()
            {
                new ClassWithInheritedProperties(){ IntProp=1, StringProp ="lala", LongProp=17, ByteProp=3 },
                new ClassWithInheritedProperties(){ IntProp=-1, StringProp ="", LongProp=2, ByteProp=4 },
            };

            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);

            // The enumerator is disposed when the block exits. Asserting on each MoveNext catches a
            // row-count mismatch in either direction, including exactly one extra round-tripped row.
            using (var roundTripped = env.Data.CreateEnumerable<ClassWithInheritedProperties>(dataView, false).GetEnumerator())
            {
                foreach (var expected in data)
                {
                    Assert.True(roundTripped.MoveNext(), "The data view produced fewer rows than the source list.");
                    Assert.True(CompareThroughReflection(roundTripped.Current, expected));
                }
                Assert.False(roundTripped.MoveNext(), "The data view produced more rows than the source list.");
            }
        }
 
        /// <summary>
        /// Row type with one public array field per primitive element type. Used by
        /// RoundTripConversionWithArrays. All fields are null unless explicitly assigned.
        /// </summary>
        public class ClassWithArrays
        {
            public string[] fString;
            public int[] fInt;
            public uint[] fuInt;
            public short[] fShort;
            public ushort[] fuShort;
            public sbyte[] fsByte;
            public byte[] fByte;
            public long[] fLong;
            public ulong[] fuLong;
            public float[] fFloat;
            public double[] fDouble;
            public bool[] fBool;
        }
 
        /// <summary>
        /// Round-trips <see cref="ClassWithArrays"/> rows through a data view and checks that every
        /// row matches the source and that the number of rows is unchanged. The rows cover fully
        /// populated arrays, a partially populated row, and a row whose array fields are all null.
        /// </summary>
        [Fact]
        public void RoundTripConversionWithArrays()
        {
            var data = new List<ClassWithArrays>
            {
                new ClassWithArrays()
                {
                    fInt = new int[3] { 0, 1, 2 },
                    fFloat = new float[3] { -0.99f, 0f, 0.99f },
                    fString = new string[2] { "hola", "lola" },
                    // Populated to mirror BoolProp in RoundTripConversionWithArrayPropertiess, so the
                    // bool[] field is exercised too.
                    fBool = new bool[2] { true, false },
                    fByte = new byte[3] { 0, 124, 255 },
                    fDouble = new double[3] { -1, 0, 1 },
                    fLong = new long[] { 0, 1, 2 },
                    fsByte = new sbyte[3] { -127, 127, 0 },
                    fShort = new short[3] { 0, 1225, 32767 },
                    fuInt = new uint[2] { 0, uint.MaxValue },
                    fuLong = new ulong[2] { ulong.MaxValue, 0 },
                    fuShort = new ushort[2] { 0, ushort.MaxValue }
                },
                new ClassWithArrays() { fInt = new int[3] { -2, 1, 0 }, fFloat = new float[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "", null } },
                new ClassWithArrays()
            };

            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);

            // The enumerator is disposed when the block exits. Asserting on each MoveNext catches a
            // row-count mismatch in either direction, including exactly one extra round-tripped row.
            using (var roundTripped = env.Data.CreateEnumerable<ClassWithArrays>(dataView, false).GetEnumerator())
            {
                foreach (var expected in data)
                {
                    Assert.True(roundTripped.MoveNext(), "The data view produced fewer rows than the source list.");
                    Assert.True(CompareThroughReflection(roundTripped.Current, expected));
                }
                Assert.False(roundTripped.MoveNext(), "The data view produced more rows than the source list.");
            }
        }
        /// <summary>
        /// Row type with one public array auto-property per primitive element type. Used by
        /// RoundTripConversionWithArrayPropertiess. All properties are null unless explicitly assigned.
        /// </summary>
        public class ClassWithArrayProperties
        {
            public string[] StringProp { get; set; }
            public int[] IntProp { get; set; }
            public uint[] UIntProp { get; set; }
            public short[] ShortProp { get; set; }
            public ushort[] UShortProp { get; set; }
            public sbyte[] SByteProp { get; set; }
            public byte[] ByteProp { get; set; }
            public long[] LongProp { get; set; }
            public ulong[] ULongProp { get; set; }
            public float[] FloatProp { get; set; }
            // NOTE(review): "Dobule" looks like a typo for "Double"; the test that uses this class
            // refers to the property by this exact name, so a rename must change both together.
            public double[] DobuleProp { get; set; }
            public bool[] BoolProp { get; set; }
        }
 
        /// <summary>
        /// Round-trips <see cref="ClassWithArrayProperties"/> rows through a data view and checks that
        /// every row matches the source and that the number of rows is unchanged. The rows cover fully
        /// populated arrays, a partially populated row, and a row whose array properties are all null.
        /// </summary>
        [Fact]
        public void RoundTripConversionWithArrayPropertiess()
        {
            var data = new List<ClassWithArrayProperties>
            {
                new ClassWithArrayProperties()
                {
                    IntProp = new int[3] { 0, 1, 2 },
                    FloatProp = new float[3] { -0.99f, 0f, 0.99f },
                    StringProp = new string[2] { "hola", "lola" },
                    BoolProp = new bool[2] { true, false },
                    ByteProp = new byte[3] { 0, 124, 255 },
                    DobuleProp = new double[3] { -1, 0, 1 },
                    LongProp = new long[] { 0, 1, 2 },
                    SByteProp = new sbyte[3] { -127, 127, 0 },
                    ShortProp = new short[3] { 0, 1225, 32767 },
                    UIntProp = new uint[2] { 0, uint.MaxValue },
                    ULongProp = new ulong[2] { ulong.MaxValue, 0 },
                    UShortProp = new ushort[2] { 0, ushort.MaxValue }
                },
                new ClassWithArrayProperties() { IntProp = new int[3] { -2, 1, 0 }, FloatProp = new float[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "", null } },
                new ClassWithArrayProperties()
            };

            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);

            // The enumerator is disposed when the block exits. Asserting on each MoveNext catches a
            // row-count mismatch in either direction, including exactly one extra round-tripped row.
            using (var roundTripped = env.Data.CreateEnumerable<ClassWithArrayProperties>(dataView, false).GetEnumerator())
            {
                foreach (var expected in data)
                {
                    Assert.True(roundTripped.MoveNext(), "The data view produced fewer rows than the source list.");
                    Assert.True(CompareThroughReflection(roundTripped.Current, expected));
                }
                Assert.False(roundTripped.MoveNext(), "The data view produced more rows than the source list.");
            }
        }
 
        /// <summary>
        /// Source row type with getter-only properties: exposes the day and hour of a timestamp
        /// captured when the instance is created.
        /// </summary>
        private sealed class ClassWithGetter
        {
            // Captured once per instance, so Day and Hour stay stable for the object's lifetime.
            private DateTime _dateTime = DateTime.Now;

            public float Day
            {
                get { return _dateTime.Day; }
            }

            public int Hour
            {
                get { return _dateTime.Hour; }
            }
        }
 
        /// <summary>
        /// Destination row type whose <c>Day</c> and <c>Hour</c> properties have public setters but
        /// private getters. The [NoColumn] read-only properties expose the stored values so that
        /// tests can inspect them.
        /// </summary>
        private sealed class ClassWithSetter
        {
            public float Day { private get; set; }
            public int Hour { private get; set; }

            [NoColumn]
            public float GetDay
            {
                get { return Day; }
            }

            [NoColumn]
            public int GetHour
            {
                get { return Hour; }
            }
        }
 
        /// <summary>
        /// Loads rows from <see cref="ClassWithGetter"/> (getter-only properties) into a data view and
        /// reads them back as <see cref="ClassWithSetter"/> (private getters, public setters). Checks
        /// that Day and Hour survive the round trip for every row and that the number of rows is
        /// unchanged.
        /// </summary>
        [Fact]
        public void PrivateGetSetProperties()
        {
            var data = new List<ClassWithGetter>()
            {
                new ClassWithGetter(),
                new ClassWithGetter(),
                new ClassWithGetter()
            };

            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);

            // The enumerator is disposed when the block exits. The row count is asserted explicitly,
            // because without it the test would pass even if the enumerable produced no rows at all.
            using (var roundTripped = env.Data.CreateEnumerable<ClassWithSetter>(dataView, false).GetEnumerator())
            {
                foreach (var expected in data)
                {
                    Assert.True(roundTripped.MoveNext(), "The data view produced fewer rows than the source list.");
                    var actual = roundTripped.Current;
                    // Separate assertions so that a failure reports which value differed.
                    Assert.Equal(expected.Day, actual.GetDay);
                    Assert.Equal(expected.Hour, actual.GetHour);
                }
                Assert.False(roundTripped.MoveNext(), "The data view produced more rows than the source list.");
            }
        }
    }
}