File: DataView\DataViewConstructionUtils.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// A helper class to create data views based on the user-provided types.
    /// </summary>
    [BestFriend]
    internal static class DataViewConstructionUtils
    {
        private static readonly FuncStaticMethodInfo1<string, DataViewSchema.Annotations, AnnotationInfo> _getAnnotationInfoMethodInfo
            = new FuncStaticMethodInfo1<string, DataViewSchema.Annotations, AnnotationInfo>(GetAnnotationInfo<int>);
 
        public static IDataView CreateFromList<TRow>(IHostEnvironment env, IList<TRow> data,
            SchemaDefinition schemaDefinition = null)
            where TRow : class
        {
            Contracts.AssertValue(env);
            env.AssertValue(data);
            env.AssertValueOrNull(schemaDefinition);
            var internalSchemaDefn = schemaDefinition == null
                ? InternalSchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Read)
                : InternalSchemaDefinition.Create(typeof(TRow), schemaDefinition);
            return new ListDataView<TRow>(env, data, internalSchemaDefn);
        }
 
        public static StreamingDataView<TRow> CreateFromEnumerable<TRow>(IHostEnvironment env, IEnumerable<TRow> data,
            SchemaDefinition schemaDefinition = null)
            where TRow : class
        {
            Contracts.AssertValue(env);
            env.AssertValue(data);
            env.AssertValueOrNull(schemaDefinition);
            var internalSchemaDefn = schemaDefinition == null
                ? InternalSchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Read)
                : InternalSchemaDefinition.Create(typeof(TRow), schemaDefinition);
            return new StreamingDataView<TRow>(env, data, internalSchemaDefn);
        }
 
        public static StreamingDataView<TRow> CreateFromEnumerable<TRow>(IHostEnvironment env, IEnumerable<TRow> data,
            DataViewSchema schema)
            where TRow : class
        {
            Contracts.AssertValue(env);
            env.AssertValue(data);
            env.AssertValueOrNull(schema);
            schema = schema ?? new DataViewSchema.Builder().ToSchema();
            return new StreamingDataView<TRow>(env, data, GetInternalSchemaDefinition<TRow>(env, schema));
        }
 
        internal static SchemaDefinition GetSchemaDefinition<TRow>(IHostEnvironment env, DataViewSchema schema)
        {
            Contracts.AssertValue(env);
            env.AssertValue(schema);
 
            var schemaDefinition = SchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Read);
            foreach (var col in schema)
            {
                var name = col.Name;
                var schemaDefinitionCol = schemaDefinition.FirstOrDefault(c => c.ColumnName == name);
                if (schemaDefinitionCol == null)
                    throw env.Except($"Type should contain a member named {name}");
 
                //Always use column type from model as this type can be more specific.
                //This can be corner case:
                //For example, we can load an model whose schema contains Vector<Single, 38>
                //and define this field in input class as float[] without specific array length.
                schemaDefinitionCol.ColumnType = col.Type;
 
                var annotations = col.Annotations;
                if (annotations != null)
                {
                    foreach (var annotation in annotations.Schema)
                    {
                        var info = Utils.MarshalInvoke(_getAnnotationInfoMethodInfo, annotation.Type.RawType, annotation.Name, annotations);
                        schemaDefinitionCol.AddAnnotation(annotation.Name, info);
                    }
                }
            }
            return schemaDefinition;
        }
 
        private static InternalSchemaDefinition GetInternalSchemaDefinition<TRow>(IHostEnvironment env, DataViewSchema schema)
            where TRow : class
        {
            Contracts.AssertValue(env);
            env.AssertValue(schema);
            return InternalSchemaDefinition.Create(typeof(TRow), GetSchemaDefinition<TRow>(env, schema));
        }
 
        private static AnnotationInfo GetAnnotationInfo<T>(string kind, DataViewSchema.Annotations annotations)
        {
            T value = default;
            annotations.GetValue(kind, ref value);
            return new AnnotationInfo<T>(kind, value, annotations.Schema[kind].Type);
        }
 
        public static InputRow<TRow> CreateInputRow<TRow>(IHostEnvironment env, SchemaDefinition schemaDefinition = null)
            where TRow : class
        {
            Contracts.AssertValue(env);
            env.AssertValueOrNull(schemaDefinition);
            var internalSchemaDefn = schemaDefinition == null
                ? InternalSchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Read)
                : InternalSchemaDefinition.Create(typeof(TRow), schemaDefinition);
 
            return new InputRow<TRow>(env, internalSchemaDefn);
        }
 
        public static IDataView LoadPipeWithPredictor(IHostEnvironment env, Stream modelStream, IDataView view)
        {
            // Load transforms.
            var pipe = env.LoadTransforms(modelStream, view);
 
            // Load predictor (if present) and apply default scorer.
            // REVIEW: distinguish the case of predictor / no predictor?
            var predictor = env.LoadPredictorOrNull(modelStream);
            if (predictor != null)
            {
                var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, modelStream);
                pipe = roles != null
                    ? env.CreateDefaultScorer(new RoleMappedData(pipe, roles, opt: true), predictor)
                    : env.CreateDefaultScorer(new RoleMappedData(pipe, label: null, "Features"), predictor);
            }
            return pipe;
        }
 
        public sealed class InputRow<TRow> : InputRowBase<TRow>
            where TRow : class
        {
            private TRow _value;
 
            private long _position;
            public override long Position => _position;
 
            public InputRow(IHostEnvironment env, InternalSchemaDefinition schemaDef)
                : base(env, SchemaExtensions.MakeSchema(GetSchemaColumns(schemaDef)), schemaDef, MakePeeks(schemaDef), c => true)
            {
                _position = -1;
            }
 
            private static Delegate[] MakePeeks(InternalSchemaDefinition schemaDef)
            {
                var peeks = new Delegate[schemaDef.Columns.Length];
                for (var i = 0; i < peeks.Length; i++)
                {
                    var currentColumn = schemaDef.Columns[i];
                    peeks[i] = currentColumn.IsComputed
                        ? currentColumn.Generator
                        : ApiUtils.GeneratePeek<InputRow<TRow>, TRow>(currentColumn);
                }
                return peeks;
            }
 
            public void ExtractValues(TRow row)
            {
                Host.CheckValue(row, nameof(row));
                _value = row;
                _position++;
            }
 
            public override ValueGetter<DataViewRowId> GetIdGetter()
            {
                return IdGetter;
            }
 
            private void IdGetter(ref DataViewRowId val) => val = new DataViewRowId((ulong)Position, 0);
 
            protected override TRow GetCurrentRowObject()
            {
                Host.Check(Position >= 0, RowCursorUtils.FetchValueStateError);
                return _value;
            }
        }
 
        /// <summary>
        /// A row that consumes items of type <typeparamref name="TRow"/>, and provides an <see cref="DataViewRow"/>. This
        /// is in contrast to <see cref="IRowReadableAs{TRow}"/> which consumes a data view row and publishes them as the output type.
        /// </summary>
        /// <typeparam name="TRow">The input data type.</typeparam>
        public abstract class InputRowBase<TRow> : DataViewRow
            where TRow : class
        {
            private static readonly FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, Delegate> _createDirectArrayGetterDelegateMethodInfo
                = FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, Delegate>.Create(target => target.CreateDirectArrayGetterDelegate<int>);
 
            private static readonly FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, Delegate> _createDirectVBufferGetterDelegateMethodInfo
                = FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, Delegate>.Create(target => target.CreateDirectVBufferGetterDelegate<int>);
 
            private static readonly FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, Delegate> _createDirectGetterDelegateMethodInfo
                = FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, Delegate>.Create(target => target.CreateDirectGetterDelegate<int>);
 
            private static readonly FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, DataViewType, Delegate> _createKeyGetterDelegateMethodInfo
                = FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, DataViewType, Delegate>.Create(target => target.CreateKeyGetterDelegate<int>);
 
            private readonly int _colCount;
            private readonly Delegate[] _getters;
            protected readonly IHost Host;
 
            public override long Batch => 0;
 
            public override DataViewSchema Schema { get; }
 
            public InputRowBase(IHostEnvironment env, DataViewSchema schema, InternalSchemaDefinition schemaDef, Delegate[] peeks, Func<int, bool> predicate)
            {
                Contracts.AssertValue(env);
                Host = env.Register("Row");
                Host.AssertValue(schema);
                Host.AssertValue(schemaDef);
                Host.AssertValue(peeks);
                Host.AssertValue(predicate);
                Host.Assert(schema.Count == schemaDef.Columns.Length);
                Host.Assert(schema.Count == peeks.Length);
 
                _colCount = schema.Count;
                Schema = schema;
                _getters = new Delegate[_colCount];
                for (int c = 0; c < _colCount; c++)
                    _getters[c] = predicate(c) ? CreateGetter(schema[c].Type, schemaDef.Columns[c], peeks[c]) : null;
            }
 
            //private Delegate CreateGetter(SchemaProxy schema, int index, Delegate peek)
            private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Column column, Delegate peek)
            {
                var outputType = column.OutputType;
                var genericType = outputType;
                FuncInstanceMethodInfo1<InputRowBase<TRow>, Delegate, Delegate> del;
 
                if (outputType.IsArray)
                {
                    VectorDataViewType vectorType = colType as VectorDataViewType;
                    Host.Assert(vectorType != null);
 
                    // String[] -> ReadOnlyMemory<char>
                    if (outputType.GetElementType() == typeof(string))
                    {
                        Host.Assert(vectorType.ItemType is TextDataViewType);
                        return CreateConvertingArrayGetterDelegate<string, ReadOnlyMemory<char>>(peek, x => x != null ? x.AsMemory() : ReadOnlyMemory<char>.Empty);
                    }
 
                    // T[] -> VBuffer<T>
                    if (outputType.GetElementType().IsGenericType && outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>))
                        Host.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == vectorType.ItemType.RawType);
                    else
                        Host.Assert(outputType.GetElementType() == vectorType.ItemType.RawType);
                    del = _createDirectArrayGetterDelegateMethodInfo;
                    genericType = outputType.GetElementType();
                }
                else if (colType is VectorDataViewType vectorType)
                {
                    // VBuffer<T> -> VBuffer<T>
                    // REVIEW: Do we care about accommodating VBuffer<string> -> ReadOnlyMemory<char>?
                    Host.Assert(outputType.IsGenericType);
                    Host.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>));
                    Host.Assert(outputType.GetGenericArguments()[0] == vectorType.ItemType.RawType);
                    del = _createDirectVBufferGetterDelegateMethodInfo;
                    genericType = vectorType.ItemType.RawType;
                }
                else if (colType is PrimitiveDataViewType)
                {
                    if (outputType == typeof(string))
                    {
                        // String -> ReadOnlyMemory<char>
                        Host.Assert(colType is TextDataViewType);
                        return CreateConvertingGetterDelegate<String, ReadOnlyMemory<char>>(peek, x => x != null ? x.AsMemory() : ReadOnlyMemory<char>.Empty);
                    }
 
                    // T -> T
                    if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>))
                        Host.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType));
                    else
                        Host.Assert(colType.RawType == outputType);
 
                    if (!(colType is KeyDataViewType keyType))
                        del = _createDirectGetterDelegateMethodInfo;
                    else
                    {
                        var keyRawType = colType.RawType;
                        return Utils.MarshalInvoke(_createKeyGetterDelegateMethodInfo, this, keyRawType, peek, colType);
                    }
                }
                else if (DataViewTypeManager.Knows(colType))
                {
                    del = _createDirectGetterDelegateMethodInfo;
                }
                else
                {
                    // REVIEW: Is this even possible?
                    throw Host.ExceptNotSupp("Type '{0}' is not yet supported.", outputType.FullName);
                }
                return Utils.MarshalInvoke(del, this, genericType, peek);
            }
 
            // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower
            // than the 'direct' getter. We don't have good indication of this to the user, and the selection
            // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats).
            private Delegate CreateConvertingArrayGetterDelegate<TSrc, TDst>(Delegate peekDel, Func<TSrc, TDst> convert)
            {
                var peek = peekDel as Peek<TRow, TSrc[]>;
                Host.AssertValue(peek);
                TSrc[] buf = default;
                return (ValueGetter<VBuffer<TDst>>)((ref VBuffer<TDst> dst) =>
                {
                    peek(GetCurrentRowObject(), Position, ref buf);
                    var n = Utils.Size(buf);
                    var dstEditor = VBufferEditor.Create(ref dst, n);
                    for (int i = 0; i < n; i++)
                        dstEditor.Values[i] = convert(buf[i]);
                    dst = dstEditor.Commit();
                });
            }
 
            private Delegate CreateConvertingGetterDelegate<TSrc, TDst>(Delegate peekDel, Func<TSrc, TDst> convert)
            {
                var peek = peekDel as Peek<TRow, TSrc>;
                Host.AssertValue(peek);
                TSrc buf = default;
                return (ValueGetter<TDst>)((ref TDst dst) =>
                {
                    peek(GetCurrentRowObject(), Position, ref buf);
                    dst = convert(buf);
                });
            }
 
            private Delegate CreateDirectArrayGetterDelegate<TDst>(Delegate peekDel)
            {
                var peek = peekDel as Peek<TRow, TDst[]>;
                Host.AssertValue(peek);
                TDst[] buf = null;
                return (ValueGetter<VBuffer<TDst>>)((ref VBuffer<TDst> dst) =>
                {
                    peek(GetCurrentRowObject(), Position, ref buf);
                    var n = Utils.Size(buf);
                    var dstEditor = VBufferEditor.Create(ref dst, n);
                    if (buf != null)
                        buf.AsSpan(0, n).CopyTo(dstEditor.Values);
                    dst = dstEditor.Commit();
                });
            }
 
            private Delegate CreateDirectVBufferGetterDelegate<TDst>(Delegate peekDel)
            {
                var peek = peekDel as Peek<TRow, VBuffer<TDst>>;
                Host.AssertValue(peek);
                VBuffer<TDst> buf = default;
                return (ValueGetter<VBuffer<TDst>>)((ref VBuffer<TDst> dst) =>
                {
                    // The peek for a VBuffer is just a simple assignment, so there is
                    // no copy going on in the peek, so we must do that as a second
                    // step to the destination.
                    peek(GetCurrentRowObject(), Position, ref buf);
                    buf.CopyTo(ref dst);
                });
            }
 
            private Delegate CreateDirectGetterDelegate<TDst>(Delegate peekDel)
            {
                var peek = peekDel as Peek<TRow, TDst>;
                Host.AssertValue(peek);
                return (ValueGetter<TDst>)((ref TDst dst) =>
                    peek(GetCurrentRowObject(), Position, ref dst));
            }
 
            private Delegate CreateKeyGetterDelegate<TDst>(Delegate peekDel, DataViewType colType)
            {
                // Make sure the function is dealing with key.
                KeyDataViewType keyType = colType as KeyDataViewType;
                Host.Check(keyType != null);
                // Following equations work only with unsigned integers.
                Host.Check(typeof(TDst) == typeof(ulong) || typeof(TDst) == typeof(uint) ||
                    typeof(TDst) == typeof(byte) || typeof(TDst) == typeof(bool));
 
                // Convert delegate function to a function which can fetch the underlying value.
                var peek = peekDel as Peek<TRow, TDst>;
                Host.AssertValue(peek);
 
                TDst rawKeyValue = default;
                ulong key = 0; // the raw key value as ulong
                ulong max = keyType.Count - 1;
                ulong result = 0; // the result as ulong
                ValueGetter<TDst> getter = (ref TDst dst) =>
                {
                    peek(GetCurrentRowObject(), Position, ref rawKeyValue);
                    key = (ulong)Convert.ChangeType(rawKeyValue, typeof(ulong));
                    if (key <= max)
                        result = key + 1;
                    else
                        result = 0;
                    dst = (TDst)Convert.ChangeType(result, typeof(TDst));
                };
                return getter;
            }
 
            protected abstract TRow GetCurrentRowObject();
 
            /// <summary>
            /// Returns whether the given column is active in this row.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                CheckColumnInRange(column.Index);
                return _getters[column.Index] != null;
            }
 
            private void CheckColumnInRange(int columnIndex)
            {
                if (columnIndex < 0 || columnIndex >= _colCount)
                    throw Host.Except("Column index must be between 0 and {0}", _colCount);
            }
 
            /// <summary>
            /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
            /// This throws if the column is not active in this row, or if the type
            /// <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Host.CheckParam(column.Index <= _getters.Length && IsColumnActive(column), nameof(column), "requested column not active");
 
                var getter = _getters[column.Index];
                Contracts.AssertValue(getter);
                var fn = getter as ValueGetter<TValue>;
                if (fn == null)
                    throw Host.Except($"Invalid TValue in GetGetter for column #{column}: '{typeof(TValue)}', " +
                        $"expected type: '{getter.GetType().GetGenericArguments().First()}'.");
                return fn;
            }
        }
 
        /// <summary>
        /// The base class for the data view over items of user-defined type.
        /// </summary>
        /// <typeparam name="TRow">The user-defined data type.</typeparam>
        public abstract class DataViewBase<TRow> : IDataView
            where TRow : class
        {
            protected readonly IHost Host;
 
            private readonly DataViewSchema _schema;
            private readonly InternalSchemaDefinition _schemaDefn;
 
            // The array of generated methods that extract the fields of the current row object.
            private readonly Delegate[] _peeks;
 
            public abstract bool CanShuffle { get; }
 
            public DataViewSchema Schema => _schema;
 
            protected DataViewBase(IHostEnvironment env, string name, InternalSchemaDefinition schemaDefn)
            {
                Contracts.AssertValue(env);
                env.AssertNonWhiteSpace(name);
                Host = env.Register(name);
                Host.AssertValue(schemaDefn);
 
                _schemaDefn = schemaDefn;
                _schema = SchemaExtensions.MakeSchema(GetSchemaColumns(schemaDefn));
                int n = schemaDefn.Columns.Length;
                _peeks = new Delegate[n];
                for (var i = 0; i < n; i++)
                {
                    var currentColumn = schemaDefn.Columns[i];
                    _peeks[i] = currentColumn.IsComputed
                        ? currentColumn.Generator
                        : ApiUtils.GeneratePeek<DataViewBase<TRow>, TRow>(currentColumn);
                }
            }
 
            public abstract long? GetRowCount();
 
            public abstract DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null);
 
            public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
            {
                return new[] { GetRowCursor(columnsNeeded, rand) };
            }
 
            public sealed class WrappedCursor : DataViewRowCursor
            {
                private readonly DataViewCursorBase _toWrap;
 
                public WrappedCursor(DataViewCursorBase toWrap) => _toWrap = toWrap;
 
                public override long Position => _toWrap.Position;
                public override long Batch => _toWrap.Batch;
                public override DataViewSchema Schema => _toWrap.Schema;
 
                protected override void Dispose(bool disposing)
                {
                    if (disposing)
                        _toWrap.Dispose();
                }
 
                /// <summary>
                /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
                /// This throws if the column is not active in this row, or if the type
                /// <typeparamref name="TValue"/> differs from this column's type.
                /// </summary>
                /// <typeparam name="TValue"> is the column's content type.</typeparam>
                /// <param name="column"> is the output column whose getter should be returned.</param>
                public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
                    => _toWrap.GetGetter<TValue>(column);
 
                public override ValueGetter<DataViewRowId> GetIdGetter() => _toWrap.GetIdGetter();
 
                /// <summary>
                /// Returns whether the given column is active in this row.
                /// </summary>
                public override bool IsColumnActive(DataViewSchema.Column column) => _toWrap.IsColumnActive(column);
                public override bool MoveNext() => _toWrap.MoveNext();
            }
 
            public abstract class DataViewCursorBase : InputRowBase<TRow>
            {
                // There is no real concept of multiple inheritance and for various reasons it was better to
                // descend from the row class as opposed to wrapping it, so much of this class is regrettably
                // copied from RootCursorBase.
 
                protected readonly DataViewBase<TRow> DataView;
                protected readonly IChannel Ch;
                private long _position;
                private bool _disposed;
 
                /// <summary>
                /// Zero-based position of the cursor.
                /// </summary>
                public override long Position => _position;
 
                protected DataViewCursorBase(IHostEnvironment env, DataViewBase<TRow> dataView,
                    Func<int, bool> predicate)
                    : base(env, dataView.Schema, dataView._schemaDefn, dataView._peeks, predicate)
                {
                    Contracts.AssertValue(env);
                    Ch = env.Start("Cursor");
                    Ch.AssertValue(dataView);
                    Ch.AssertValue(predicate);
 
                    DataView = dataView;
                    _position = -1;
                }
 
                /// <summary>
                /// Convenience property for checking whether the cursor is in a good state where values
                /// can be retrieved, that is, whenever <see cref="Position"/> is non-negative.
                /// </summary>
                protected bool IsGood => Position >= 0;
 
                protected sealed override void Dispose(bool disposing)
                {
                    if (_disposed)
                        return;
                    if (disposing)
                    {
                        Ch.Dispose();
                        _position = -1;
                    }
                    _disposed = true;
                    base.Dispose(disposing);
                }
 
                public bool MoveNext()
                {
                    if (_disposed)
                        return false;
 
                    if (MoveNextCore())
                    {
                        _position++;
                        return true;
                    }
 
                    Dispose();
                    return false;
                }
 
                /// <summary>
                /// Core implementation of <see cref="MoveNext"/>, called if no prior call to this method
                /// has returned <see langword="false"/>.
                /// </summary>
                protected abstract bool MoveNextCore();
            }
        }
 
        /// <summary>
        /// An in-memory data view based on the IList of data.
        /// Supports shuffling.
        /// </summary>
        private sealed class ListDataView<TRow> : DataViewBase<TRow>
            where TRow : class
        {
            private readonly IList<TRow> _data;
 
            public ListDataView(IHostEnvironment env, IList<TRow> data, InternalSchemaDefinition schemaDefn)
                : base(env, "ListDataView", schemaDefn)
            {
                Host.CheckValue(data, nameof(data));
                _data = data;
            }
 
            public override bool CanShuffle => true;
 
            public override long? GetRowCount()
            {
                return _data.Count;
            }
 
            public override DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
            {
                var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
                return new WrappedCursor(new Cursor(Host, "ListDataView", this, predicate, rand));
            }
 
            private sealed class Cursor : DataViewCursorBase
            {
                private readonly int[] _permutation;
                private readonly IList<TRow> _data;
 
                private int Index
                {
                    get { return _permutation == null ? (int)Position : _permutation[(int)Position]; }
                }
 
                public Cursor(IHostEnvironment env, string name, ListDataView<TRow> dataView,
                    Func<int, bool> predicate, Random rand)
                    : base(env, dataView, predicate)
                {
                    Ch.AssertValueOrNull(rand);
                    _data = dataView._data;
                    if (rand != null)
                        _permutation = Utils.GetRandomPermutation(rand, dataView._data.Count);
                }
 
                public override ValueGetter<DataViewRowId> GetIdGetter()
                {
                    if (_permutation == null)
                    {
                        return
                            (ref DataViewRowId val) =>
                            {
                                Ch.Check(IsGood, RowCursorUtils.FetchValueStateError);
                                val = new DataViewRowId((ulong)Position, 0);
                            };
                    }
                    else
                    {
                        return
                            (ref DataViewRowId val) =>
                            {
                                Ch.Check(IsGood, RowCursorUtils.FetchValueStateError);
                                val = new DataViewRowId((ulong)Index, 0);
                            };
                    }
                }
 
                protected override TRow GetCurrentRowObject()
                {
                    Ch.Check(0 <= Position && Position < _data.Count, RowCursorUtils.FetchValueStateError);
                    return _data[Index];
                }
 
                protected override bool MoveNextCore()
                {
                    Ch.Assert(Position < _data.Count);
                    return Position + 1 < _data.Count;
                }
            }
        }
 
        /// <summary>
        /// An in-memory data view based on the IEnumerable of data.
        /// Doesn't support shuffling.
        /// </summary>
        internal sealed class StreamingDataView<TRow> : DataViewBase<TRow>
            where TRow : class
        {
            private readonly IEnumerable<TRow> _data;
 
            public StreamingDataView(IHostEnvironment env, IEnumerable<TRow> data, InternalSchemaDefinition schemaDefn)
                : base(env, "StreamingDataView", schemaDefn)
            {
                Contracts.CheckValue(data, nameof(data));
                _data = data;
            }
 
            public override bool CanShuffle => false;
 
            public override long? GetRowCount()
                => (_data as ICollection<TRow>)?.Count;
 
            public override DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
            {
                var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
                return new WrappedCursor(new Cursor(Host, this, predicate));
            }
 
            private sealed class Cursor : DataViewCursorBase
            {
                private readonly IEnumerator<TRow> _enumerator;
                private TRow _currentRow;
 
                public Cursor(IHostEnvironment env, StreamingDataView<TRow> dataView, Func<int, bool> predicate)
                    : base(env, dataView, predicate)
                {
                    _enumerator = dataView._data.GetEnumerator();
                    _currentRow = null;
                }
 
                public override ValueGetter<DataViewRowId> GetIdGetter()
                {
                    return
                        (ref DataViewRowId val) =>
                        {
                            Ch.Check(IsGood, RowCursorUtils.FetchValueStateError);
                            val = new DataViewRowId((ulong)Position, 0);
                        };
                }
 
                protected override TRow GetCurrentRowObject()
                {
                    return _currentRow;
                }
 
                protected override bool MoveNextCore()
                {
                    var result = _enumerator.MoveNext();
                    _currentRow = result ? _enumerator.Current : null;
                    if (result && _currentRow == null)
                        throw Ch.Except("Encountered null when iterating over data, this is not supported.");
                    return result;
                }
            }
        }
 
        /// <summary>
        /// This represents the 'infinite data view' over one (mutable) user-defined object.
        /// The 'current row' object can be updated at any time, this will affect all the
        /// newly created cursors, but not the ones already existing.
        /// </summary>
        public sealed class SingleRowLoopDataView<TRow> : DataViewBase<TRow>
            where TRow : class
        {
            private TRow _current;
 
            public SingleRowLoopDataView(IHostEnvironment env, InternalSchemaDefinition schemaDefn)
                : base(env, "SingleRowLoopDataView", schemaDefn)
            {
            }
 
            public override bool CanShuffle => false;
 
            public override long? GetRowCount() => null;
 
            public void SetCurrentRowObject(TRow value)
            {
                Host.AssertValue(value);
                _current = value;
            }
 
            public override DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
            {
                Contracts.Assert(_current != null, "The current object must be set prior to cursoring");
                var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
                return new WrappedCursor(new Cursor(Host, this, predicate));
            }
 
            private sealed class Cursor : DataViewCursorBase
            {
                private readonly TRow _currentRow;
 
                public Cursor(IHostEnvironment env, SingleRowLoopDataView<TRow> dataView, Func<int, bool> predicate)
                    : base(env, dataView, predicate)
                {
                    _currentRow = dataView._current;
                }
 
                public override ValueGetter<DataViewRowId> GetIdGetter()
                {
                    return
                        (ref DataViewRowId val) =>
                        {
                            Ch.Check(IsGood, RowCursorUtils.FetchValueStateError);
                            val = new DataViewRowId((ulong)Position, 0);
                        };
                }
 
                protected override TRow GetCurrentRowObject() => _currentRow;
 
                protected override bool MoveNextCore() => true;
            }
        }
 
        [BestFriend]
        internal static DataViewSchema.DetachedColumn[] GetSchemaColumns(InternalSchemaDefinition schemaDefn)
        {
            Contracts.AssertValue(schemaDefn);
            var columns = new DataViewSchema.DetachedColumn[schemaDefn.Columns.Length];
            for (int i = 0; i < columns.Length; i++)
            {
                var col = schemaDefn.Columns[i];
                var meta = new DataViewSchema.Annotations.Builder();
                foreach (var kvp in col.Annotations)
                    meta.Add(kvp.Value.Kind, kvp.Value.AnnotationType, kvp.Value.GetGetterDelegate());
                columns[i] = new DataViewSchema.DetachedColumn(col.ColumnName, col.ColumnType, meta.ToAnnotations());
            }
 
            return columns;
        }
    }
 
    /// <summary>
    /// A single instance of annotation information, associated with a column.
    /// </summary>
    internal abstract partial class AnnotationInfo
    {
        /// <summary>
        /// The type of the annotation.
        /// </summary>
        public DataViewType AnnotationType;
        /// <summary>
        /// The string identifier of the annotation. Some identifiers have special meaning,
        /// like "SlotNames", but any other identifiers can be used.
        /// </summary>
        public readonly string Kind;
 
        public abstract ValueGetter<TDst> GetGetter<TDst>();
 
        internal abstract Delegate GetGetterDelegate();
 
        private protected AnnotationInfo(string kind, DataViewType annotationType)
        {
            Contracts.AssertValueOrNull(annotationType);
            Contracts.AssertNonEmpty(kind);
            AnnotationType = annotationType;
            Kind = kind;
        }
    }
 
    /// <summary>
    /// Strongly-typed version of <see cref="AnnotationInfo"/>, that contains the actual value of the annotation.
    /// </summary>
    /// <typeparam name="T">Type of the annotation value.</typeparam>
    internal sealed class AnnotationInfo<T> : AnnotationInfo
    {
        private static readonly FuncInstanceMethodInfo1<AnnotationInfo<T>, Delegate> _getArrayGetterMethodInfo
            = FuncInstanceMethodInfo1<AnnotationInfo<T>, Delegate>.Create(target => target.GetArrayGetter<int>);
 
        private static readonly FuncInstanceMethodInfo1<AnnotationInfo<T>, Delegate> _getGetterCoreMethodInfo
            = FuncInstanceMethodInfo1<AnnotationInfo<T>, Delegate>.Create(target => target.GetGetterCore<int>);
 
        private static readonly FuncInstanceMethodInfo1<AnnotationInfo<T>, Delegate> _getVBufferGetterMethodInfo
            = FuncInstanceMethodInfo1<AnnotationInfo<T>, Delegate>.Create(target => target.GetVBufferGetter<int>);
 
        public readonly T Value;
 
        /// <summary>
        /// Constructor for annotation of value type T.
        /// </summary>
        /// <param name="kind">The string identifier of the annotation. Some identifiers have special meaning,
        /// like "SlotNames", but any other identifiers can be used.</param>
        /// <param name="value">Annotation value.</param>
        /// <param name="annotationType">Type of the annotation.</param>
        public AnnotationInfo(string kind, T value, DataViewType annotationType = null)
            : base(kind, annotationType)
        {
            Contracts.Assert(value != null);
            bool isVector;
            Type itemType;
            InternalSchemaDefinition.GetVectorAndItemType("annotation value", typeof(T), null, out isVector, out itemType);
 
            if (annotationType == null)
            {
                // Infer a type as best we can.
                var primitiveItemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType);
                annotationType = isVector ? new VectorDataViewType(primitiveItemType) : (DataViewType)primitiveItemType;
            }
            else
            {
                // Make sure that the types are compatible with the declared type, including whether it is a vector type.
                VectorDataViewType annotationVectorType = annotationType as VectorDataViewType;
                bool annotationIsVector = annotationVectorType != null;
                if (isVector != annotationIsVector)
                {
                    throw Contracts.Except("Value inputted is supposed to be {0}, but type of Annotationinfo is {1}",
                        isVector ? "vector" : "scalar", annotationIsVector ? "vector" : "scalar");
                }
 
                DataViewType annotationItemType = annotationVectorType?.ItemType ?? annotationType;
                if (itemType != annotationItemType.RawType)
                {
                    throw Contracts.Except(
                        "Value inputted is supposed to have Type {0}, but type of Annotationinfo has {1}",
                        itemType.ToString(), annotationItemType.RawType.ToString());
                }
            }
            AnnotationType = annotationType;
            Value = value;
        }
 
        public override ValueGetter<TDst> GetGetter<TDst>()
        {
            var typeT = typeof(T);
            if (typeT.IsArray)
            {
                Contracts.Assert(AnnotationType is VectorDataViewType);
                Contracts.Check(typeof(TDst).IsGenericType && typeof(TDst).GetGenericTypeDefinition() == typeof(VBuffer<>));
                var itemType = typeT.GetElementType();
                var dstItemType = typeof(TDst).GetGenericArguments()[0];
 
                // String[] -> VBuffer<ReadOnlyMemory<char>>
                if (itemType == typeof(string))
                {
                    Contracts.Check(dstItemType == typeof(ReadOnlyMemory<char>));
 
                    ValueGetter<VBuffer<ReadOnlyMemory<char>>> method = GetStringArray;
                    return method as ValueGetter<TDst>;
                }
 
                // T[] -> VBuffer<T>
                Contracts.Check(itemType == dstItemType);
 
                return Utils.MarshalInvoke(_getArrayGetterMethodInfo, this, dstItemType) as ValueGetter<TDst>;
            }
            if (AnnotationType is VectorDataViewType annotationVectorType)
            {
                // VBuffer<T> -> VBuffer<T>
                // REVIEW: Do we care about accommodating VBuffer<string> -> VBuffer<ReadOnlyMemory<char>>?
 
                Contracts.Assert(typeT.IsGenericType);
                Contracts.Check(typeof(TDst).IsGenericType);
                Contracts.Assert(typeT.GetGenericTypeDefinition() == typeof(VBuffer<>));
                Contracts.Check(typeof(TDst).GetGenericTypeDefinition() == typeof(VBuffer<>));
                var dstItemType = typeof(TDst).GetGenericArguments()[0];
                var itemType = typeT.GetGenericArguments()[0];
                Contracts.Assert(itemType == annotationVectorType.ItemType.RawType);
                Contracts.Check(itemType == dstItemType);
 
                return Utils.MarshalInvoke(_getVBufferGetterMethodInfo, this, annotationVectorType.ItemType.RawType) as ValueGetter<TDst>;
            }
            if (AnnotationType is PrimitiveDataViewType)
            {
                if (typeT == typeof(string))
                {
                    // String -> ReadOnlyMemory<char>
                    Contracts.Assert(AnnotationType is TextDataViewType);
                    ValueGetter<ReadOnlyMemory<char>> m = GetString;
                    return m as ValueGetter<TDst>;
                }
                // T -> T
                Contracts.Assert(AnnotationType.RawType == typeT);
                return GetDirectValue;
            }
            throw Contracts.ExceptNotImpl("Type '{0}' is not yet supported.", typeT.FullName);
        }
 
        // We want to use MarshalInvoke instead of adding custom Reflection logic for calling GetGetter<TDst>
        private Delegate GetGetterCore<TDst>()
        {
            return GetGetter<TDst>();
        }
 
        internal override Delegate GetGetterDelegate()
        {
            return Utils.MarshalInvoke(_getGetterCoreMethodInfo, this, AnnotationType.RawType);
        }
 
        private void GetStringArray(ref VBuffer<ReadOnlyMemory<char>> dst)
        {
            var value = (string[])(object)Value;
            var n = Utils.Size(value);
            var dstEditor = VBufferEditor.Create(ref dst, n);
 
            for (int i = 0; i < n; i++)
                dstEditor.Values[i] = value[i].AsMemory();
 
            dst = dstEditor.Commit();
        }
 
        private ValueGetter<VBuffer<TDst>> GetArrayGetter<TDst>()
        {
            var value = (TDst[])(object)Value;
            var n = Utils.Size(value);
            return (ref VBuffer<TDst> dst) =>
            {
                var dstEditor = VBufferEditor.Create(ref dst, n);
                if (value != null)
                    value.AsSpan(0, n).CopyTo(dstEditor.Values);
                dst = dstEditor.Commit();
            };
        }
 
        private ValueGetter<VBuffer<TDst>> GetVBufferGetter<TDst>()
        {
            var castValue = (VBuffer<TDst>)(object)Value;
            return (ref VBuffer<TDst> dst) => castValue.CopyTo(ref dst);
        }
 
        private void GetString(ref ReadOnlyMemory<char> dst)
        {
            dst = ((string)(object)Value).AsMemory();
        }
 
        private void GetDirectValue<TDst>(ref TDst dst)
        {
            dst = (TDst)(object)Value;
        }
    }
}