File: DataFrame.Arrow.cs
Web Access
Project: src\src\Microsoft.Data.Analysis\Microsoft.Data.Analysis.csproj (Microsoft.Data.Analysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Apache.Arrow;
using Apache.Arrow.Types;
 
namespace Microsoft.Data.Analysis
{
    public partial class DataFrame
    {
        /// <summary>
        /// Converts one Arrow array into <see cref="DataFrameColumn"/>(s) and appends them to <paramref name="ret"/>.
        /// Boolean, primitive numeric and string arrays are turned into columns constructed directly over the Arrow
        /// value/offset/null-bitmap buffers; Date64 and Timestamp arrays are copied element by element into a
        /// <see cref="DateTimeDataFrameColumn"/>; struct arrays are flattened recursively into one column per leaf
        /// member, named "{parentName}_{memberName}".
        /// </summary>
        /// <param name="field">Schema field describing <paramref name="arrowArray"/>.</param>
        /// <param name="arrowArray">The Arrow array holding the data for <paramref name="field"/>.</param>
        /// <param name="ret">The <see cref="DataFrame"/> the resulting column(s) are appended to.</param>
        /// <param name="fieldNamePrefix">Prefix prepended to the field name; used to qualify flattened struct members.</param>
        /// <exception cref="NotImplementedException">The Arrow type of <paramref name="field"/> has no column mapping.</exception>
        private static void AppendDataFrameColumnFromArrowArray(Field field, IArrowArray arrowArray, DataFrame ret, string fieldNamePrefix = "")
        {
            // NOTE(review): arrowArray.Offset is never consulted below, so buffers are read from their start.
            // Sliced arrays (non-zero Offset) are presumably not mapped correctly — confirm before relying on them.
            IArrowType fieldType = field.DataType;
            DataFrameColumn dataFrameColumn = null;
            string fieldName = fieldNamePrefix + field.Name;
            switch (fieldType.TypeId)
            {
                case ArrowTypeId.Boolean:
                    BooleanArray arrowBooleanArray = (BooleanArray)arrowArray;
                    ReadOnlyMemory<byte> valueBuffer = arrowBooleanArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> nullBitMapBuffer = arrowBooleanArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new BooleanDataFrameColumn(fieldName, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.Double:
                    PrimitiveArray<double> arrowDoubleArray = (PrimitiveArray<double>)arrowArray;
                    ReadOnlyMemory<byte> doubleValueBuffer = arrowDoubleArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> doubleNullBitMapBuffer = arrowDoubleArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new DoubleDataFrameColumn(fieldName, doubleValueBuffer, doubleNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.Float:
                    PrimitiveArray<float> arrowFloatArray = (PrimitiveArray<float>)arrowArray;
                    ReadOnlyMemory<byte> floatValueBuffer = arrowFloatArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> floatNullBitMapBuffer = arrowFloatArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new SingleDataFrameColumn(fieldName, floatValueBuffer, floatNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.Int8:
                    PrimitiveArray<sbyte> arrowsbyteArray = (PrimitiveArray<sbyte>)arrowArray;
                    ReadOnlyMemory<byte> sbyteValueBuffer = arrowsbyteArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> sbyteNullBitMapBuffer = arrowsbyteArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new SByteDataFrameColumn(fieldName, sbyteValueBuffer, sbyteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.Int16:
                    PrimitiveArray<short> arrowshortArray = (PrimitiveArray<short>)arrowArray;
                    ReadOnlyMemory<byte> shortValueBuffer = arrowshortArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> shortNullBitMapBuffer = arrowshortArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new Int16DataFrameColumn(fieldName, shortValueBuffer, shortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.Int32:
                    PrimitiveArray<int> arrowIntArray = (PrimitiveArray<int>)arrowArray;
                    ReadOnlyMemory<byte> intValueBuffer = arrowIntArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> intNullBitMapBuffer = arrowIntArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new Int32DataFrameColumn(fieldName, intValueBuffer, intNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.Int64:
                    PrimitiveArray<long> arrowLongArray = (PrimitiveArray<long>)arrowArray;
                    ReadOnlyMemory<byte> longValueBuffer = arrowLongArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> longNullBitMapBuffer = arrowLongArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new Int64DataFrameColumn(fieldName, longValueBuffer, longNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.String:
                    StringArray stringArray = (StringArray)arrowArray;
                    ReadOnlyMemory<byte> dataMemory = stringArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> offsetsMemory = stringArray.ValueOffsetsBuffer.Memory;
                    ReadOnlyMemory<byte> nullMemory = stringArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new ArrowStringDataFrameColumn(fieldName, dataMemory, offsetsMemory, nullMemory, stringArray.Length, stringArray.NullCount);
                    break;
                case ArrowTypeId.UInt8:
                    PrimitiveArray<byte> arrowbyteArray = (PrimitiveArray<byte>)arrowArray;
                    ReadOnlyMemory<byte> byteValueBuffer = arrowbyteArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> byteNullBitMapBuffer = arrowbyteArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new ByteDataFrameColumn(fieldName, byteValueBuffer, byteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.UInt16:
                    PrimitiveArray<ushort> arrowUshortArray = (PrimitiveArray<ushort>)arrowArray;
                    ReadOnlyMemory<byte> ushortValueBuffer = arrowUshortArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> ushortNullBitMapBuffer = arrowUshortArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new UInt16DataFrameColumn(fieldName, ushortValueBuffer, ushortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.UInt32:
                    PrimitiveArray<uint> arrowUintArray = (PrimitiveArray<uint>)arrowArray;
                    ReadOnlyMemory<byte> uintValueBuffer = arrowUintArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> uintNullBitMapBuffer = arrowUintArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new UInt32DataFrameColumn(fieldName, uintValueBuffer, uintNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.UInt64:
                    PrimitiveArray<ulong> arrowUlongArray = (PrimitiveArray<ulong>)arrowArray;
                    ReadOnlyMemory<byte> ulongValueBuffer = arrowUlongArray.ValueBuffer.Memory;
                    ReadOnlyMemory<byte> ulongNullBitMapBuffer = arrowUlongArray.NullBitmapBuffer.Memory;
                    dataFrameColumn = new UInt64DataFrameColumn(fieldName, ulongValueBuffer, ulongNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                    break;
                case ArrowTypeId.Struct:
                    StructArray structArray = (StructArray)arrowArray;
                    StructType structType = (StructType)field.DataType;
                    // Enumerators are IDisposable; dispose them deterministically.
                    using (IEnumerator<Field> fieldsEnumerator = structType.Fields.GetEnumerator())
                    using (IEnumerator<IArrowArray> structArrayEnumerator = structArray.Fields.GetEnumerator())
                    {
                        while (fieldsEnumerator.MoveNext() && structArrayEnumerator.MoveNext())
                        {
                            // Pass the fully qualified name (not just field.Name) so members of nested structs keep every
                            // ancestor's name (e.g. "Outer_Inner_Leaf") and cannot collide with same-named members elsewhere.
                            AppendDataFrameColumnFromArrowArray(fieldsEnumerator.Current, structArrayEnumerator.Current, ret, fieldName + "_");
                        }
                    }
                    // The struct itself produces no column; only its (flattened) members do.
                    break;
                case ArrowTypeId.Date64:
                    {
                        Date64Array arrowDate64Array = (Date64Array)arrowArray;
                        var dataTimeDataFrameColumn = new DateTimeDataFrameColumn(fieldName, arrowDate64Array.Data.Length);
                        for (int i = 0; i < arrowDate64Array.Data.Length; i++)
                        {
                            dataTimeDataFrameColumn[i] = arrowDate64Array.GetDateTime(i);
                        }
                        dataFrameColumn = dataTimeDataFrameColumn;
                    }
                    break;
                case ArrowTypeId.Timestamp:
                    {
                        TimestampArray arrowTimeStampArray = (TimestampArray)arrowArray;
                        var dataTimeDataFrameColumn = new DateTimeDataFrameColumn(fieldName, arrowTimeStampArray.Data.Length);
                        for (int i = 0; i < arrowTimeStampArray.Data.Length; i++)
                        {
                            // DateTimeOffset.DateTime drops the offset; null entries stay null.
                            dataTimeDataFrameColumn[i] = arrowTimeStampArray.GetTimestamp(i)?.DateTime;
                        }
                        dataFrameColumn = dataTimeDataFrameColumn;
                    }
                    break;
                // Arrow types with no DataFrameColumn mapping yet; listed explicitly for visibility and
                // handled by the default branch.
                case ArrowTypeId.Decimal128:
                case ArrowTypeId.Decimal256:
                case ArrowTypeId.Binary:
                case ArrowTypeId.Date32:
                case ArrowTypeId.Dictionary:
                case ArrowTypeId.FixedSizedBinary:
                case ArrowTypeId.HalfFloat:
                case ArrowTypeId.Interval:
                case ArrowTypeId.List:
                case ArrowTypeId.Map:
                case ArrowTypeId.Null:
                case ArrowTypeId.Time32:
                case ArrowTypeId.Time64:
                default:
                    throw new NotImplementedException($"{fieldType.Name}");
            }

            if (dataFrameColumn != null)
            {
                ret.Columns.Insert(ret.Columns.Count, dataFrameColumn);
            }
        }

        /// <summary>
        /// Wraps a <see cref="DataFrame"/> around an Arrow <see cref="RecordBatch"/> without copying data
        /// </summary>
        /// <remarks>
        /// Date64 and Timestamp columns are copied into <see cref="DateTimeDataFrameColumn"/>s, and struct columns are
        /// flattened into one column per member.
        /// </remarks>
        /// <param name="recordBatch">The record batch to wrap; its schema fields are matched to its arrays by position.</param>
        /// <returns><see cref="DataFrame"/></returns>
        /// <exception cref="ArgumentNullException"><paramref name="recordBatch"/> is <see langword="null"/>.</exception>
        /// <exception cref="NotImplementedException">A field has an Arrow type with no column mapping.</exception>
        public static DataFrame FromArrowRecordBatch(RecordBatch recordBatch)
        {
            if (recordBatch is null)
            {
                throw new ArgumentNullException(nameof(recordBatch));
            }

            DataFrame ret = new DataFrame();
            Apache.Arrow.Schema arrowSchema = recordBatch.Schema;
            int fieldIndex = 0;
            foreach (IArrowArray arrowArray in recordBatch.Arrays)
            {
                Field field = arrowSchema.GetFieldByIndex(fieldIndex);
                AppendDataFrameColumnFromArrowArray(field, arrowArray, ret);
                fieldIndex++;
            }
            return ret;
        }

        /// <summary>
        /// Returns an <see cref="IEnumerable{RecordBatch}"/> mostly without copying data
        /// </summary>
        /// <remarks>
        /// All batches share one schema built from the columns' Arrow fields. Each batch is sized so that it does not
        /// exceed the rows remaining nor the length any column reports via <c>GetMaxRecordBatchLength</c> at the current
        /// position. A <see cref="DataFrame"/> with no rows still yields a single, empty batch.
        /// </remarks>
        public IEnumerable<RecordBatch> ToArrowRecordBatches()
        {
            Apache.Arrow.Schema.Builder schemaBuilder = new Apache.Arrow.Schema.Builder();

            int columnCount = Columns.Count;
            for (int i = 0; i < columnCount; i++)
            {
                DataFrameColumn column = Columns[i];
                Field field = column.GetArrowField();
                schemaBuilder.Field(field);
            }

            Schema schema = schemaBuilder.Build();

            const int recordBatchLength = Int32.MaxValue;
            long numberOfRowsProcessed = 0;

            // Sometimes .NET for Spark passes in DataFrames with no rows. In those cases, we just return a RecordBatch with the right Schema and no rows
            do
            {
                // Recompute the batch size every iteration: bounded by the rows still remaining and by what each
                // column can expose starting at the current position.
                int numberOfRowsInThisRecordBatch = (int)Math.Min(recordBatchLength, Rows.Count - numberOfRowsProcessed);
                for (int i = 0; i < columnCount; i++)
                {
                    DataFrameColumn column = Columns[i];
                    numberOfRowsInThisRecordBatch = (int)Math.Min(numberOfRowsInThisRecordBatch, column.GetMaxRecordBatchLength(numberOfRowsProcessed));
                }

                // A fresh list per batch; reusing a single list would make every later batch also contain the
                // arrays of all earlier batches, mismatching the schema.
                List<Apache.Arrow.Array> arrays = new List<Apache.Arrow.Array>(columnCount);
                for (int i = 0; i < columnCount; i++)
                {
                    DataFrameColumn column = Columns[i];
                    arrays.Add(column.ToArrowArray(numberOfRowsProcessed, numberOfRowsInThisRecordBatch));
                }
                numberOfRowsProcessed += numberOfRowsInThisRecordBatch;
                yield return new RecordBatch(schema, arrays, numberOfRowsInThisRecordBatch);
            } while (numberOfRowsProcessed < Rows.Count);
        }

    }
}