File: ArrayComparer.cs
Web Access
Project: src\test\Microsoft.Data.Analysis.Tests\Microsoft.Data.Analysis.Tests.csproj (Microsoft.Data.Analysis.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Apache.Arrow;
using Xunit;
 
namespace Microsoft.Data.Analysis.Tests
{
    /// <summary>
    /// Arrow array visitor that asserts (via xUnit) that each visited "actual" array matches the
    /// expected array passed to the constructor. Struct arrays are compared recursively, child by child.
    /// </summary>
    /// <remarks>
    /// List and binary arrays are not supported and throw <see cref="NotImplementedException"/>.
    /// </remarks>
    public class ArrayComparer :
        IArrowArrayVisitor<Int8Array>,
        IArrowArrayVisitor<Int16Array>,
        IArrowArrayVisitor<Int32Array>,
        IArrowArrayVisitor<Int64Array>,
        IArrowArrayVisitor<UInt8Array>,
        IArrowArrayVisitor<UInt16Array>,
        IArrowArrayVisitor<UInt32Array>,
        IArrowArrayVisitor<UInt64Array>,
        IArrowArrayVisitor<FloatArray>,
        IArrowArrayVisitor<DoubleArray>,
        IArrowArrayVisitor<BooleanArray>,
        IArrowArrayVisitor<TimestampArray>,
        IArrowArrayVisitor<Date32Array>,
        IArrowArrayVisitor<Date64Array>,
        IArrowArrayVisitor<ListArray>,
        IArrowArrayVisitor<StringArray>,
        IArrowArrayVisitor<BinaryArray>,
        IArrowArrayVisitor<StructArray>
    {
        private readonly IArrowArray _expectedArray;

        /// <param name="expectedArray">The array every visited array is compared against.</param>
        public ArrayComparer(IArrowArray expectedArray)
        {
            _expectedArray = expectedArray;
        }

        public void Visit(Int8Array array) => CompareArrays(array);
        public void Visit(Int16Array array) => CompareArrays(array);
        public void Visit(Int32Array array) => CompareArrays(array);
        public void Visit(Int64Array array) => CompareArrays(array);
        public void Visit(UInt8Array array) => CompareArrays(array);
        public void Visit(UInt16Array array) => CompareArrays(array);
        public void Visit(UInt32Array array) => CompareArrays(array);
        public void Visit(UInt64Array array) => CompareArrays(array);
        public void Visit(FloatArray array) => CompareArrays(array);
        public void Visit(DoubleArray array) => CompareArrays(array);
        public void Visit(BooleanArray array) => CompareArrays(array);
        public void Visit(TimestampArray array) => CompareArrays(array);
        public void Visit(Date32Array array) => CompareArrays(array);
        public void Visit(Date64Array array) => CompareArrays(array);
        public void Visit(ListArray array) => throw new NotImplementedException();
        public void Visit(StringArray array) => CompareArrays(array);
        public void Visit(BinaryArray array) => throw new NotImplementedException();
        public void Visit(IArrowArray array) => throw new NotImplementedException();

        /// <summary>
        /// Compares the struct array's own length, null count, offset and validity, then recursively
        /// compares each child field array against the expected child at the same index.
        /// </summary>
        public void Visit(StructArray array)
        {
            StructArray expectedArray = Assert.IsAssignableFrom<StructArray>(_expectedArray);

            CompareCommonProperties(expectedArray, array);
            Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length);
            Assert.Equal(expectedArray.Fields.Count, array.Fields.Count);

            for (int i = 0; i < array.Fields.Count; i++)
            {
                array.Fields[i].Accept(new ArrayComparer(expectedArray.Fields[i]));
            }
        }

        /// <summary>
        /// Compares a fixed-width primitive array: common properties, per-element validity, and the
        /// first <c>Length</c> entries of <c>Values</c> (slots under nulls included, as before).
        /// </summary>
        private void CompareArrays<T>(PrimitiveArray<T> actualArray)
            where T : struct, IEquatable<T>
        {
            PrimitiveArray<T> expectedArray = Assert.IsAssignableFrom<PrimitiveArray<T>>(_expectedArray);

            CompareCommonProperties(expectedArray, actualArray);
            Assert.True(expectedArray.Values.Slice(0, expectedArray.Length).SequenceEqual(actualArray.Values.Slice(0, actualArray.Length)));
        }

        /// <summary>
        /// Compares a boolean array: common properties, per-element validity, and the value bytes.
        /// </summary>
        private void CompareArrays(BooleanArray actualArray)
        {
            BooleanArray expectedArray = Assert.IsAssignableFrom<BooleanArray>(_expectedArray);

            CompareCommonProperties(expectedArray, actualArray);

            // Booleans are bit-packed, so compare only the bytes that hold the first Length bits.
            int booleanByteCount = BitUtility.ByteCount(expectedArray.Length);
            Assert.True(expectedArray.Values.Slice(0, booleanByteCount).SequenceEqual(actualArray.Values.Slice(0, booleanByteCount)));
        }

        /// <summary>
        /// Compares a string array: common properties, per-element validity, and the decoded string
        /// of every non-null element.
        /// </summary>
        private void CompareArrays(StringArray actualArray)
        {
            StringArray expectedArray = Assert.IsAssignableFrom<StringArray>(_expectedArray);

            CompareCommonProperties(expectedArray, actualArray);

            // Values is the variable-length data buffer (bytes), not one entry per element, so slicing
            // it by Length neither covers every string nor is guaranteed to be in range. Compare each
            // valid element's string instead; null slots were already matched by the validity check.
            for (int i = 0; i < expectedArray.Length; i++)
            {
                if (expectedArray.IsValid(i))
                {
                    Assert.Equal(expectedArray.GetString(i), actualArray.GetString(i));
                }
            }
        }

        /// <summary>
        /// Asserts equal <c>Length</c>, <c>NullCount</c> and <c>Offset</c>, and that every element has
        /// the same validity in both arrays.
        /// </summary>
        /// <remarks>
        /// Validity is compared per element with <c>IsValid</c> rather than by comparing raw null-bitmap
        /// buffers: when there are no nulls the expected array may carry no bitmap while DataFrame may
        /// populate one with Length set bits, and only the bits for this array's elements are significant.
        /// </remarks>
        private static void CompareCommonProperties(Apache.Arrow.Array expectedArray, Apache.Arrow.Array actualArray)
        {
            Assert.Equal(expectedArray.Length, actualArray.Length);
            Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
            Assert.Equal(expectedArray.Offset, actualArray.Offset);

            for (int i = 0; i < expectedArray.Length; i++)
            {
                Assert.Equal(expectedArray.IsValid(i), actualArray.IsValid(i));
            }
        }
    }
 
    /// <summary>
    /// Structural equality check for Arrow <see cref="Field"/>s used by the Arrow round-trip tests.
    /// </summary>
    internal static class FieldComparer
    {
        /// <summary>
        /// Returns <c>true</c> when both fields are the same instance, or when both are non-null and
        /// agree on name, nullability, data type id and metadata entries.
        /// </summary>
        /// <remarks>
        /// Only <c>DataType.TypeId</c> is compared, not any additional parameters of the data type.
        /// </remarks>
        public static bool Equals(Field f1, Field f2)
        {
            if (ReferenceEquals(f1, f2))
            {
                return true;
            }
            if (f1 is null || f2 is null)
            {
                return false;
            }

            return f1.Name == f2.Name &&
                   f1.IsNullable == f2.IsNullable &&
                   f1.DataType.TypeId == f2.DataType.TypeId &&
                   f1.HasMetadata == f2.HasMetadata &&
                   // Same entry count + every f1 entry present in f2 with an equal value implies the
                   // key sets are identical, so a reverse pass is unnecessary.
                   (!f1.HasMetadata ||
                    (f1.Metadata.Count == f2.Metadata.Count &&
                     f1.Metadata.All(entry => f2.Metadata.TryGetValue(entry.Key, out string value) && entry.Value == value)));
        }
    }
 
    /// <summary>
    /// Structural equality check for Arrow <see cref="Schema"/>s used by the Arrow round-trip tests.
    /// </summary>
    internal static class SchemaComparer
    {
        /// <summary>
        /// Returns <c>true</c> when both schemas are the same instance, or when both are non-null,
        /// have the same fields in the same order (each compared with <c>FieldComparer.Equals</c>)
        /// and the same metadata entries.
        /// </summary>
        /// <remarks>
        /// Fields are compared by position rather than by name lookup: record batch columns are
        /// positional, and a name lookup would treat reordered fields as equal and could not tell
        /// apart fields sharing a name.
        /// </remarks>
        public static bool Equals(Schema s1, Schema s2)
        {
            if (ReferenceEquals(s1, s2))
            {
                return true;
            }
            if (s1 is null || s2 is null || s1.HasMetadata != s2.HasMetadata || s1.FieldsList.Count != s2.FieldsList.Count)
            {
                return false;
            }

            for (int i = 0; i < s1.FieldsList.Count; i++)
            {
                if (!FieldComparer.Equals(s1.FieldsList[i], s2.FieldsList[i]))
                {
                    return false;
                }
            }

            if (!s1.HasMetadata)
            {
                return true;
            }

            // Same entry count + every s1 entry present in s2 with an equal value implies identical key sets.
            return s1.Metadata.Count == s2.Metadata.Count &&
                   s1.Metadata.All(entry => s2.Metadata.TryGetValue(entry.Key, out string value) && entry.Value == value);
        }
    }
 
    /// <summary>
    /// xUnit assertion helpers for checking that two Arrow <see cref="RecordBatch"/>es are equivalent.
    /// </summary>
    public static class RecordBatchComparer
    {
        /// <summary>
        /// Asserts that <paramref name="actualBatch"/> matches <paramref name="expectedBatch"/>: the
        /// schemas are equal per <c>SchemaComparer</c>, row and column counts agree, and every column
        /// matches the expected column at the same index as checked by <c>ArrayComparer</c>.
        /// </summary>
        /// <param name="expectedBatch">The reference batch.</param>
        /// <param name="actualBatch">The batch under test.</param>
        public static void CompareBatches(RecordBatch expectedBatch, RecordBatch actualBatch)
        {
            bool schemasMatch = SchemaComparer.Equals(expectedBatch.Schema, actualBatch.Schema);
            Assert.True(schemasMatch);

            Assert.Equal(expectedBatch.Length, actualBatch.Length);

            int columnCount = expectedBatch.ColumnCount;
            Assert.Equal(columnCount, actualBatch.ColumnCount);

            for (int column = 0; column < columnCount; column++)
            {
                var comparer = new ArrayComparer(expectedBatch.Column(column));
                actualBatch.Column(column).Accept(comparer);
            }
        }
    }
}