// File: Utilities\ColumnCursor.cs
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Extension methods that extract the values of a single column of an <see cref="IDataView"/> as an
    /// <see cref="IEnumerable{T}"/>.
    /// </summary>
    public static class ColumnCursorExtensions
    {
        /// <summary>
        /// Extract all values of one column of the data view in a form of an <see cref="IEnumerable{T}"/>.
        /// </summary>
        /// <typeparam name="T">The type of the values. This must be the raw type of the column, <see cref="string"/>
        /// for a text column, or an array type for a vector column (element type equal to the raw item type, or
        /// <see cref="string"/> for text items).</typeparam>
        /// <param name="data">The data view to get the column from.</param>
        /// <param name="columnName">The name of the column to be extracted.</param>
        /// <returns>A lazily evaluated sequence with one element per row of <paramref name="data"/>.</returns>
        public static IEnumerable<T> GetColumn<T>(this IDataView data, string columnName)
        {
            // Validate before touching data.Schema so that a null data view is reported as a bad argument
            // rather than as a NullReferenceException from the schema lookup.
            Contracts.CheckValue(data, nameof(data));
            return GetColumn<T>(data, data.Schema[columnName]);
        }

        /// <summary>
        /// Extract all values of one column of the data view in a form of an <see cref="IEnumerable{T}"/>.
        /// </summary>
        /// <typeparam name="T">The type of the values. This must be the raw type of the column, <see cref="string"/>
        /// for a text column, or an array type for a vector column (element type equal to the raw item type, or
        /// <see cref="string"/> for text items).</typeparam>
        /// <param name="data">The data view to get the column from.</param>
        /// <param name="column">The column to be extracted. Its index, name and type must all agree with the
        /// column found at the same index in the schema of <paramref name="data"/>.</param>
        /// <returns>A lazily evaluated sequence with one element per row of <paramref name="data"/>.</returns>
        public static IEnumerable<T> GetColumn<T>(this IDataView data, DataViewSchema.Column column)
        {
            Contracts.CheckValue(data, nameof(data));
            Contracts.CheckNonEmpty(column.Name, nameof(column));

            var colIndex = column.Index;
            var colType = column.Type;
            var colName = column.Name;

            // Use column index as the principal address of the specified input column and check if that address in data
            // contains the column indicated. The index is range-checked first: a column that comes from a different
            // (larger) schema must produce the descriptive error below instead of failing inside the schema indexer.
            var schema = data.Schema;
            if (colIndex < 0 || colIndex >= schema.Count
                || schema[colIndex].Name != colName || schema[colIndex].Type != colType)
            {
                throw Contracts.ExceptParam(nameof(column), string.Format("column with name {0}, type {1}, and index {2} cannot be found in {3}",
                    colName, colType, colIndex, nameof(data)));
            }

            // There are two decisions that we make here:
            // - Is the T an array type?
            //     - If yes, we need to map VBuffer to array and densify.
            //     - If no, this is not needed.
            // - Does T (or item type of T if it's an array) equal to the data view type?
            //     - If this is the same type, we can map directly.
            //     - Otherwise, we need a conversion delegate.
            //
            // The helper methods are generic over types that are only known at run time (the raw type of the
            // column), so they are reached through reflection: a delegate with placeholder type arguments is
            // created only to get hold of the MethodInfo, whose generic definition is then instantiated with
            // the actual types.

            if (colType.RawType == typeof(T))
            {
                // Direct mapping is possible.
                return GetColumnDirect<T>(data, colIndex);
            }
            else if (typeof(T) == typeof(string) && colType is TextDataViewType)
            {
                // Special case of ROM<char> to string conversion.
                Delegate convert = (Func<ReadOnlyMemory<char>, string>)((ReadOnlyMemory<char> txt) => txt.ToString());
                Func<IDataView, int, Func<int, T>, IEnumerable<T>> del = GetColumnConvert;
                var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(typeof(T), colType.RawType);
                return (IEnumerable<T>)(meth.Invoke(null, new object[] { data, colIndex, convert }));
            }
            else if (typeof(T).IsArray)
            {
                // Output is an array type, which is only supported for vector columns.
                if (!(colType is VectorDataViewType colVectorType))
                    throw Contracts.ExceptParam(nameof(column), string.Format("Cannot load vector type, {0}, specified in {1} to the user-defined type, {2}.", column.Type, nameof(column), typeof(T)));
                var elementType = typeof(T).GetElementType();
                if (elementType == colVectorType.ItemType.RawType)
                {
                    // Direct mapping of items.
                    Func<IDataView, int, IEnumerable<int[]>> del = GetColumnArrayDirect<int>;
                    var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(elementType);
                    return (IEnumerable<T>)meth.Invoke(null, new object[] { data, colIndex });
                }
                else if (elementType == typeof(string) && colVectorType.ItemType is TextDataViewType)
                {
                    // Conversion of ROM<char> items to string items.
                    Delegate convert = (Func<ReadOnlyMemory<char>, string>)((ReadOnlyMemory<char> txt) => txt.ToString());
                    Func<IDataView, int, Func<int, long>, IEnumerable<long[]>> del = GetColumnArrayConvert;
                    var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(elementType, colVectorType.ItemType.RawType);
                    return (IEnumerable<T>)meth.Invoke(null, new object[] { data, colIndex, convert });
                }
                // Fall through to the failure.
            }

            throw Contracts.ExceptParam(nameof(column), string.Format("Cannot map column (name: {0}, type: {1}) in {2} to the user-defined type, {3}.",
                column.Name, column.Type, nameof(data), typeof(T)));
        }

        /// <summary>
        /// Enumerates the values of the column at index <paramref name="col"/>, reading them with a getter of
        /// type <typeparamref name="T"/> and yielding them unchanged.
        /// </summary>
        /// <remarks>
        /// This is an iterator: the cursor is opened on first enumeration and disposed when enumeration ends.
        /// </remarks>
        private static IEnumerable<T> GetColumnDirect<T>(IDataView data, int col)
        {
            Contracts.AssertValue(data);
            Contracts.Assert(0 <= col && col < data.Schema.Count);

            var column = data.Schema[col];
            using (var cursor = data.GetRowCursor(column))
            {
                var getter = cursor.GetGetter<T>(column);
                // NOTE(review): the same variable is handed to the getter by reference for every row. If T owns
                // buffers that the getter reuses (presumably the case for VBuffer<>), previously yielded values
                // may alias the current row - confirm before materializing such a sequence.
                T curValue = default;
                while (cursor.MoveNext())
                {
                    getter(ref curValue);
                    yield return curValue;
                }
            }
        }

        /// <summary>
        /// Enumerates the values of the column at index <paramref name="col"/>, reading them as
        /// <typeparamref name="TData"/> and yielding the result of applying <paramref name="convert"/> to each.
        /// </summary>
        /// <remarks>
        /// This is an iterator: the cursor is opened on first enumeration and disposed when enumeration ends.
        /// </remarks>
        private static IEnumerable<TOut> GetColumnConvert<TOut, TData>(IDataView data, int col, Func<TData, TOut> convert)
        {
            Contracts.AssertValue(data);
            Contracts.Assert(0 <= col && col < data.Schema.Count);

            var column = data.Schema[col];
            using (var cursor = data.GetRowCursor(column))
            {
                var getter = cursor.GetGetter<TData>(column);
                TData curValue = default;
                while (cursor.MoveNext())
                {
                    getter(ref curValue);
                    yield return convert(curValue);
                }
            }
        }

        /// <summary>
        /// Enumerates the values of the vector column at index <paramref name="col"/>, copying each
        /// <see cref="VBuffer{T}"/> into a newly allocated array of the buffer's length.
        /// </summary>
        /// <remarks>
        /// This is an iterator: the cursor is opened on first enumeration and disposed when enumeration ends.
        /// </remarks>
        private static IEnumerable<T[]> GetColumnArrayDirect<T>(IDataView data, int col)
        {
            Contracts.AssertValue(data);
            Contracts.Assert(0 <= col && col < data.Schema.Count);

            var column = data.Schema[col];
            using (var cursor = data.GetRowCursor(column))
            {
                var getter = cursor.GetGetter<VBuffer<T>>(column);
                VBuffer<T> curValue = default;
                while (cursor.MoveNext())
                {
                    getter(ref curValue);
                    // REVIEW: should we introduce the 'reuse array' logic here?
                    // For now it re-creates the array and densifies.
                    var dst = new T[curValue.Length];
                    curValue.CopyTo(dst);
                    yield return dst;
                }
            }
        }

        /// <summary>
        /// Enumerates the values of the vector column at index <paramref name="col"/>, producing for each row a
        /// newly allocated array of the buffer's length filled with the converted items.
        /// </summary>
        /// <remarks>
        /// This is an iterator: the cursor is opened on first enumeration and disposed when enumeration ends.
        /// </remarks>
        private static IEnumerable<TOut[]> GetColumnArrayConvert<TOut, TData>(IDataView data, int col, Func<TData, TOut> convert)
        {
            Contracts.AssertValue(data);
            Contracts.Assert(0 <= col && col < data.Schema.Count);

            var column = data.Schema[col];
            using (var cursor = data.GetRowCursor(column))
            {
                var getter = cursor.GetGetter<VBuffer<TData>>(column);
                VBuffer<TData> curValue = default;
                while (cursor.MoveNext())
                {
                    getter(ref curValue);
                    // REVIEW: should we introduce the 'reuse array' logic here?
                    // For now it re-creates the array and densifies.
                    var dst = new TOut[curValue.Length];
                    // Only the entries enumerated by Items(all: false) are converted (presumably the explicitly
                    // stored ones - confirm against VBuffer); every other slot keeps default(TOut).
                    foreach (var kvp in curValue.Items(all: false))
                        dst[kvp.Key] = convert(kvp.Value);
                    yield return dst;
                }
            }
        }
    }
}