// File: DataView\ArrayDataViewBuilder.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    using BitArray = System.Collections.BitArray;
 
    /// <summary>
    /// A builder that composes an in-memory <see cref="IDataView"/> from arrays, one column at a time.
    /// </summary>
    [BestFriend]
    internal sealed class ArrayDataViewBuilder
    {
        private readonly IHost _host;
        private readonly List<Column> _columns;
        private readonly List<string> _names;
        private readonly Dictionary<string, ValueGetter<VBuffer<ReadOnlyMemory<char>>>> _getSlotNames;
        private readonly Dictionary<string, ValueGetter<VBuffer<ReadOnlyMemory<char>>>> _getKeyValues;
 
        // The row count implied by the columns added so far: the length of the first
        // column, or null while no column has been added.
        private int? RowCount => _columns.Count > 0 ? _columns[0].Length : (int?)null;
 
        /// <summary>
        /// Creates an empty builder and registers a host named "ArrayDataViewBuilder"
        /// with the supplied environment.
        /// </summary>
        /// <param name="env">The host environment; checked for null.</param>
        public ArrayDataViewBuilder(IHostEnvironment env)
        {
            Contracts.CheckValue(env, nameof(env));

            // The collections start out empty; columns and their annotation getters
            // are appended by the AddColumn overloads.
            _names = new List<string>();
            _columns = new List<Column>();
            _getKeyValues = new Dictionary<string, ValueGetter<VBuffer<ReadOnlyMemory<char>>>>();
            _getSlotNames = new Dictionary<string, ValueGetter<VBuffer<ReadOnlyMemory<char>>>>();

            _host = env.Register("ArrayDataViewBuilder");
        }
 
        /// <summary>
        /// Validates an array handed to one of the add routines: neither the column name nor
        /// the array may be null, and if any column has already been added, the array must
        /// have the same length as the first of them.
        /// </summary>
        private void CheckLength<T>(string name, T[] values)
        {
            _host.CheckValue(name, nameof(name));
            _host.CheckValue(values, nameof(values));

            // The very first column establishes the length; nothing to compare against yet.
            if (_columns.Count == 0)
                return;

            int expected = _columns[0].Length;
            if (expected != values.Length)
                throw _host.Except("Previous inputs were of length {0}, but new input is of length {1}", expected, values.Length);
        }
 
        /// <summary>
        /// Adds a column backed directly by <paramref name="values"/>, where each output value is
        /// produced by plain assignment from the stored element. Because of that,
        /// <typeparamref name="T"/> should be a type for which assignment cannot compromise the
        /// immutability of the source object: a scalar, string, or <c>ReadOnlyMemory</c> is fine,
        /// whereas a <c>HashSet</c> or <c>VBuffer</c> is not.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="type">The column type; its raw type must be exactly <typeparamref name="T"/>.</param>
        /// <param name="values">The values of the column, one per row.</param>
        public void AddColumn<T>(string name, PrimitiveDataViewType type, params T[] values)
        {
            bool typeMatches = type != null && type.RawType == typeof(T);
            _host.CheckParam(typeMatches, nameof(type));
            CheckLength(name, values);

            var column = new AssignmentColumn<T>(type, values);
            _columns.Add(column);
            _names.Add(name);
        }
 
        /// <summary>
        /// Constructs a new key column from an array where values are copied to output simply
        /// by being assigned. If adding the column fails (for example because a key-values
        /// getter is already registered under <paramref name="name"/>), the builder is left unchanged.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="getKeyValues">The delegate that does a reverse lookup based upon the given key. This is for annotation creation</param>
        /// <param name="keyCount">The count of unique keys specified in values; must be greater than zero.</param>
        /// <param name="values">The values to add to the column. Note that since this is creating a <see cref="KeyDataViewType"/> column, the values will be offset by 1.</param>
        public void AddColumn<T1>(string name, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getKeyValues, ulong keyCount, params T1[] values)
        {
            _host.CheckValue(getKeyValues, nameof(getKeyValues));
            _host.CheckParam(keyCount > 0, nameof(keyCount));
            CheckLength(name, values);

            // Build the column up front: constructing the key type can throw, and at this
            // point no builder state has been touched yet.
            values.GetType().GetElementType().TryGetDataKind(out InternalDataKind kind);
            var column = new AssignmentColumn<T1>(new KeyDataViewType(kind.ToType(), keyCount), values);

            // Register the annotation getter BEFORE appending to the lists. Dictionary.Add throws
            // when the name is already present; doing it first guarantees that such a failure
            // cannot leave _columns one element longer than _names.
            _getKeyValues.Add(name, getKeyValues);
            _columns.Add(column);
            _names.Add(name);
        }
 
        /// <summary>
        /// Creates a column with slot names from arrays. The added column will be re-interpreted as a buffer.
        /// If adding the column fails (for example because a slot-names getter is already registered
        /// under <paramref name="name"/>), the builder is left unchanged.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="getNames">The delegate that produces the slot names annotation.</param>
        /// <param name="itemType">The item type of the vector; its raw type must be exactly <typeparamref name="T"/>.</param>
        /// <param name="values">One array per row.</param>
        public void AddColumn<T>(string name, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getNames, PrimitiveDataViewType itemType, params T[][] values)
        {
            _host.CheckValue(getNames, nameof(getNames));
            _host.CheckParam(itemType != null && itemType.RawType == typeof(T), nameof(itemType));
            CheckLength(name, values);
            var col = new ArrayToVBufferColumn<T>(itemType, values);

            // Register the slot-names getter BEFORE appending to the lists. Dictionary.Add throws
            // when the name is already present; doing it first guarantees that such a failure
            // cannot leave _columns one element longer than _names.
            _getSlotNames.Add(name, getNames);
            _columns.Add(col);
            _names.Add(name);
        }
 
        /// <summary>
        /// Adds a vector column built from one array per row; each array is exposed as a buffer.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="itemType">The item type of the vector; its raw type must be exactly <typeparamref name="T"/>.</param>
        /// <param name="values">One array per row.</param>
        public void AddColumn<T>(string name, PrimitiveDataViewType itemType, params T[][] values)
        {
            bool itemTypeMatches = itemType != null && itemType.RawType == typeof(T);
            _host.CheckParam(itemTypeMatches, nameof(itemType));
            CheckLength(name, values);

            Column column = new ArrayToVBufferColumn<T>(itemType, values);
            _columns.Add(column);
            _names.Add(name);
        }
 
        /// <summary>
        /// Creates a column with slot names from arrays. The added column will be re-interpreted as a buffer and possibly sparsified.
        /// If adding the column fails (for example because a slot-names getter is already registered
        /// under <paramref name="name"/>), the builder is left unchanged.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="getNames">The delegate that produces the slot names annotation.</param>
        /// <param name="itemType">The item type of the vector; its raw type must be exactly <typeparamref name="T"/>.</param>
        /// <param name="combiner">The combiner handed to the buffer builder that assembles each row's buffer.</param>
        /// <param name="values">One array per row.</param>
        public void AddColumn<T>(string name, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getNames, PrimitiveDataViewType itemType, Combiner<T> combiner, params T[][] values)
        {
            _host.CheckValue(getNames, nameof(getNames));
            _host.CheckParam(itemType != null && itemType.RawType == typeof(T), nameof(itemType));
            CheckLength(name, values);
            var col = new ArrayToSparseVBufferColumn<T>(itemType, combiner, values);

            // Register the slot-names getter BEFORE appending to the lists. Dictionary.Add throws
            // when the name is already present; doing it first guarantees that such a failure
            // cannot leave _columns one element longer than _names.
            _getSlotNames.Add(name, getNames);
            _columns.Add(col);
            _names.Add(name);
        }
 
        /// <summary>
        /// Adds a vector column built from one array per row; each array is exposed as a buffer
        /// that is assembled through a buffer builder and may therefore come out sparse.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="itemType">The item type of the vector; its raw type must be exactly <typeparamref name="T"/>.</param>
        /// <param name="combiner">The combiner handed to the buffer builder.</param>
        /// <param name="values">One array per row.</param>
        public void AddColumn<T>(string name, PrimitiveDataViewType itemType, Combiner<T> combiner, params T[][] values)
        {
            bool itemTypeMatches = itemType != null && itemType.RawType == typeof(T);
            _host.CheckParam(itemTypeMatches, nameof(itemType));
            CheckLength(name, values);

            Column column = new ArrayToSparseVBufferColumn<T>(itemType, combiner, values);
            _columns.Add(column);
            _names.Add(name);
        }
 
        /// <summary>
        /// Adds a vector column whose rows are supplied directly as <c>VBuffer{T}</c> values.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="itemType">The item type of the vector; its raw type must be exactly <typeparamref name="T"/>.</param>
        /// <param name="values">One buffer per row.</param>
        public void AddColumn<T>(string name, PrimitiveDataViewType itemType, params VBuffer<T>[] values)
        {
            bool itemTypeMatches = itemType != null && itemType.RawType == typeof(T);
            _host.CheckParam(itemTypeMatches, nameof(itemType));
            CheckLength(name, values);

            Column column = new VBufferColumn<T>(itemType, values);
            _columns.Add(column);
            _names.Add(name);
        }
 
        /// <summary>
        /// Adds a VBuffer{T} valued column with slot names.
        /// If adding the column fails (for example because a slot-names getter is already registered
        /// under <paramref name="name"/>), the builder is left unchanged.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="getNames">The delegate that produces the slot names annotation.</param>
        /// <param name="itemType">The item type of the vector; its raw type must be exactly <typeparamref name="T"/>.</param>
        /// <param name="values">One buffer per row.</param>
        public void AddColumn<T>(string name, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getNames, PrimitiveDataViewType itemType, params VBuffer<T>[] values)
        {
            _host.CheckValue(getNames, nameof(getNames));
            _host.CheckParam(itemType != null && itemType.RawType == typeof(T), nameof(itemType));
            CheckLength(name, values);
            var col = new VBufferColumn<T>(itemType, values);

            // Register the slot-names getter BEFORE appending to the lists. Dictionary.Add throws
            // when the name is already present; doing it first guarantees that such a failure
            // cannot leave _columns one element longer than _names.
            _getSlotNames.Add(name, getNames);
            _columns.Add(col);
            _names.Add(name);
        }
 
        /// <summary>
        /// Adds a text column (<c>ReadOnlyMemory</c> of chars) whose rows come from an array of strings.
        /// </summary>
        /// <param name="name">The name of the column.</param>
        /// <param name="values">One string per row.</param>
        public void AddColumn(string name, params string[] values)
        {
            CheckLength(name, values);

            Column textColumn = new StringToTextColumn(values);
            _columns.Add(textColumn);
            _names.Add(name);
        }
 
        /// <summary>
        /// Constructs a data view from the columns added so far. Note that it is perfectly acceptable
        /// to continue adding columns to the builder, but these additions will not be reflected in the
        /// returned dataview.
        /// </summary>
        /// <param name="rowCount">
        /// Optional explicit row count. When given, it must be non-negative and must agree with the
        /// length of any columns already added. When omitted, at least one column must have been
        /// added, and its length is used.
        /// </param>
        /// <returns>A data view over a snapshot of the current column list.</returns>
        public IDataView GetDataView(int? rowCount = null)
        {
            if (rowCount.HasValue)
            {
                _host.Check(!RowCount.HasValue || RowCount.Value == rowCount.Value, "Specified row count incompatible with existing columns");
                // With no columns there is nothing else to reject a negative count: the DataView
                // constructor only guards it with a debug-only Assert. Validate it here so that
                // release builds cannot produce a view with a negative row count.
                _host.CheckParam(rowCount.Value >= 0, nameof(rowCount), "Row count must be non-negative");
                return new DataView(_host, this, rowCount.Value);
            }
            _host.Check(_columns.Count > 0, "Cannot construct data-view with neither any columns nor a specified row count");
            return new DataView(_host, this, RowCount.Value);
        }
 
        private sealed class DataView : IDataView
        {
            private readonly int _rowCount;
            private readonly Column[] _columns;
            private readonly DataViewSchema _schema;
            private readonly IHost _host;
 
            // The schema assembled in the constructor from the builder's columns.
            public DataViewSchema Schema => _schema;

            // The row count is always known for this in-memory view.
            public long? GetRowCount() => _rowCount;

            // Cursors over this view accept a Random and iterate a permutation when given one.
            public bool CanShuffle => true;
 
            /// <summary>
            /// Snapshots the builder's current columns and builds the schema for them, attaching
            /// slot-name and key-value annotations to every column whose name has a getter
            /// registered in the builder.
            /// </summary>
            /// <param name="env">Environment used to register this view's host.</param>
            /// <param name="builder">The builder whose columns, names and annotation getters are read.</param>
            /// <param name="rowCount">The number of rows the view reports and iterates.</param>
            public DataView(IHostEnvironment env, ArrayDataViewBuilder builder, int rowCount)
            {
                Contracts.AssertValue(env, "env");
                _host = env.Register("ArrayDataView");

                _host.AssertValue(builder);
                _host.Assert(rowCount >= 0);
                _host.Assert(builder._names.Count == builder._columns.Count);

                // Copy the list so that columns added to the builder later do not show up here.
                _columns = builder._columns.ToArray();
                _rowCount = rowCount;

                var schema = new DataViewSchema.Builder();
                for (int col = 0; col < _columns.Length; col++)
                {
                    string colName = builder._names[col];
                    var colType = _columns[col].Type;
                    var annotations = new DataViewSchema.Annotations.Builder();

                    // Annotation getters are looked up by column name.
                    if (builder._getSlotNames.TryGetValue(colName, out var getSlotNames))
                        annotations.AddSlotNames(colType.GetVectorSize(), getSlotNames);

                    if (builder._getKeyValues.TryGetValue(colName, out var getKeyValues))
                        annotations.AddKeyValues(colType.GetKeyCountAsInt32(_host), TextDataViewType.Instance, getKeyValues);

                    schema.AddColumn(colName, colType, annotations.ToAnnotations());
                }

                _schema = schema.ToSchema();
            }
 
            /// <summary>
            /// Creates a cursor over this view in which the columns in
            /// <paramref name="columnsNeeded"/> are active.
            /// </summary>
            /// <param name="columnsNeeded">The columns the caller intends to read.</param>
            /// <param name="rand">Optional randomness source; when non-null the cursor visits rows in a random permutation.</param>
            public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
            {
                var activeColumns = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
                _host.CheckValueOrNull(rand);

                DataViewRowCursor cursor = new Cursor(_host, this, activeColumns, rand);
                return cursor;
            }
 
            /// <summary>
            /// Returns a cursor set over this view. The set always contains exactly one cursor;
            /// <paramref name="n"/> is not consulted.
            /// </summary>
            /// <param name="columnsNeeded">The columns the caller intends to read.</param>
            /// <param name="n">The suggested degree of parallelism (unused here).</param>
            /// <param name="rand">Optional randomness source; when non-null the cursor visits rows in a random permutation.</param>
            public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
            {
                var activeColumns = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
                _host.CheckValueOrNull(rand);

                var single = new Cursor(_host, this, activeColumns, rand);
                return new DataViewRowCursor[] { single };
            }
 
            private sealed class Cursor : RootCursorBase
            {
                private readonly DataView _view;
                private readonly BitArray _active;
                private readonly int[] _indices;
 
                // The cursor exposes the schema of the view it iterates.
                public override DataViewSchema Schema
                {
                    get { return _view.Schema; }
                }

                // REVIEW: Implement cursor set support.
                public override long Batch => 0;
 
                /// <summary>
                /// Creates a cursor over <paramref name="view"/>.
                /// </summary>
                /// <param name="provider">Channel provider passed to the base cursor.</param>
                /// <param name="view">The view to iterate.</param>
                /// <param name="predicate">Maps a column index to whether it is active; null activates every column.</param>
                /// <param name="rand">When non-null, rows are visited in a random permutation drawn from it.</param>
                public Cursor(IChannelProvider provider, DataView view, Func<int, bool> predicate, Random rand)
                    : base(provider)
                {
                    Ch.AssertValue(view);
                    Ch.AssertValueOrNull(rand);
                    Ch.Assert(view.Schema.Count >= 0);

                    _view = view;

                    // Start with every bit set when there is no predicate, otherwise with every
                    // bit clear and then fill in the predicate's answer per column.
                    int columnCount = view.Schema.Count;
                    _active = new BitArray(columnCount, predicate == null);
                    if (predicate != null)
                    {
                        for (int col = 0; col < columnCount; ++col)
                            _active[col] = predicate(col);
                    }

                    // A null permutation means rows are visited in their natural order.
                    _indices = rand == null ? null : Utils.GetRandomPermutation(rand, view._rowCount);
                }
 
                /// <summary>
                /// Returns a getter for the current row's id. The id's first component is the index
                /// of the underlying row (the permuted index when shuffling), the second is zero.
                /// </summary>
                public override ValueGetter<DataViewRowId> GetIdGetter()
                {
                    if (_indices != null)
                    {
                        // Shuffled: identify the row by its index in the underlying arrays.
                        return (ref DataViewRowId id) =>
                        {
                            Ch.Check(IsGood, RowCursorUtils.FetchValueStateError);
                            id = new DataViewRowId((ulong)MappedIndex(), 0);
                        };
                    }

                    // Not shuffled: the cursor position is the row index.
                    return (ref DataViewRowId id) =>
                    {
                        Ch.Check(IsGood, RowCursorUtils.FetchValueStateError);
                        id = new DataViewRowId((ulong)Position, 0);
                    };
                }
 
                /// <summary>
                /// Tells whether <paramref name="column"/> was requested as active when this cursor was created.
                /// </summary>
                public override bool IsColumnActive(DataViewSchema.Column column)
                {
                    int index = column.Index;
                    Ch.Check(index < Schema.Count);
                    return _active[index];
                }
 
                /// <summary>
                /// Returns a delegate that fetches the current row's value of <paramref name="column"/>.
                /// Throws if the column is not active in this cursor, or if
                /// <typeparamref name="TValue"/> is not the type the column produces.
                /// </summary>
                /// <typeparam name="TValue">The column's content type.</typeparam>
                /// <param name="column">The output column whose getter should be returned.</param>
                public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
                {
                    int index = column.Index;
                    Ch.Check(index < Schema.Count);
                    Ch.Check(index < _active.Length && _active[index], "the requested column is not active");

                    Column untyped = _view._columns[index];
                    if (!(untyped is Column<TValue> typed))
                    {
                        throw Ch.Except($"Invalid TValue: '{typeof(TValue)}', " +
                            $"expected type: '{untyped.Type.RawType}'.");
                    }

                    return (ref TValue value) =>
                    {
                        Ch.Check(IsGood, RowCursorUtils.FetchValueStateError);
                        typed.CopyOut(MappedIndex(), ref value);
                    };
                }
 
                // Reports whether a row exists after the current position, i.e. whether
                // Position + 1 is still below the view's row count.
                // NOTE(review): presumably the base cursor advances Position once this returns true — confirm.
                protected override bool MoveNextCore()
                {
                    Ch.Assert(Position < _view._rowCount);
                    long next = Position + 1;
                    return next < _view._rowCount;
                }
 
                // Translates the cursor position into an index into the column arrays: the
                // position itself when not shuffling, otherwise the permuted index.
                private int MappedIndex()
                {
                    Ch.Assert(IsGood);
                    Ch.Assert(0 <= Position && Position < _view._rowCount);

                    int row = (int)Position;
                    return _indices != null ? _indices[row] : row;
                }
            }
        }
 
        #region Column implementations
 
        /// <summary>
        /// Untyped base of all columns: carries the column's type and exposes its row count.
        /// </summary>
        private abstract class Column
        {
            public readonly DataViewType Type;

            public Column(DataViewType type) => Type = type;

            /// <summary>
            /// The number of rows held by this column.
            /// </summary>
            public abstract int Length { get; }
        }
 
        /// <summary>
        /// A column that produces values of type <typeparamref name="TOut"/>, which must be the
        /// raw type of the column's <see cref="DataViewType"/>.
        /// </summary>
        private abstract class Column<TOut> : Column
        {
            public Column(DataViewType type)
                : base(type)
            {
                Contracts.Assert(type.RawType == typeof(TOut));
            }

            /// <summary>
            /// Writes the value of row <paramref name="index"/> into <paramref name="value"/>.
            /// </summary>
            public abstract void CopyOut(int index, ref TOut value);
        }
 
        /// <summary>
        /// A column backed by an array of <typeparamref name="TIn"/> that produces values of
        /// <typeparamref name="TOut"/>; subclasses define how one stored element is turned
        /// into an output value.
        /// </summary>
        private abstract class Column<TIn, TOut> : Column<TOut>
        {
            private readonly TIn[] _values;

            public Column(DataViewType type, TIn[] values)
                : base(type)
            {
                Contracts.AssertValue(values);
                _values = values;
            }

            public override int Length => _values.Length;

            /// <summary>
            /// Fetches the stored element at <paramref name="index"/> and hands it to the
            /// element-level <c>CopyOut</c> helper to produce the output value.
            /// </summary>
            public override void CopyOut(int index, ref TOut value)
            {
                Contracts.Assert(index >= 0 && index < _values.Length);
                CopyOut(in _values[index], ref value);
            }

            /// <summary>
            /// Fills <c>dst</c> from <c>src</c> so that the caller owns <c>dst</c> while this
            /// object keeps ownership of <c>src</c>. How that is achieved depends on the types involved.
            /// </summary>
            protected abstract void CopyOut(in TIn src, ref TOut dst);
        }
 
        /// <summary>
        /// A column whose input and output types coincide, and for which plain assignment does
        /// not compromise ownership of the internal values.
        /// </summary>
        private sealed class AssignmentColumn<T> : Column<T, T>
        {
            public AssignmentColumn(PrimitiveDataViewType type, T[] values)
                : base(type, values)
            {
            }

            // Plain assignment is the whole copy.
            protected override void CopyOut(in T src, ref T dst) => dst = src;
        }
 
        /// <summary>
        /// A convenience column that stores strings and exposes them as text
        /// (<c>ReadOnlyMemory</c> of chars).
        /// </summary>
        private sealed class StringToTextColumn : Column<string, ReadOnlyMemory<char>>
        {
            public StringToTextColumn(string[] values)
                : base(TextDataViewType.Instance, values)
            {
            }

            // Wraps the string's characters in a memory view.
            protected override void CopyOut(in string src, ref ReadOnlyMemory<char> dst)
                => dst = MemoryExtensions.AsMemory(src);
        }
 
        /// <summary>
        /// Base for columns that produce <c>VBuffer</c> values. The vector type's size is inferred
        /// from the data: the common length when every row has the same length, zero otherwise.
        /// </summary>
        private abstract class VectorColumn<TIn, TOut> : Column<TIn, VBuffer<TOut>>
        {
            public VectorColumn(PrimitiveDataViewType itemType, TIn[] values, Func<TIn, int> lengthFunc)
                : base(InferType(itemType, values, lengthFunc), values)
            {
            }

            /// <summary>
            /// Builds the vector type for the column, with its size derived from the input
            /// array via <paramref name="lengthFunc"/>.
            /// </summary>
            private static DataViewType InferType(PrimitiveDataViewType itemType, TIn[] values, Func<TIn, int> lengthFunc)
            {
                Contracts.AssertValue(itemType);
                Contracts.Assert(itemType.RawType == typeof(TOut));

                int size = CommonLength(values, lengthFunc);
                return new VectorDataViewType(itemType, size);
            }

            /// <summary>
            /// Returns the length shared by every element of <paramref name="values"/>, or zero
            /// when the array is null or empty, or when any element's length differs from the first's.
            /// </summary>
            private static int CommonLength(TIn[] values, Func<TIn, int> lengthFunc)
            {
                if (Utils.Size(values) == 0)
                    return 0;

                int first = lengthFunc(values[0]);
                for (int i = 1; i < values.Length; ++i)
                {
                    if (lengthFunc(values[i]) != first)
                        return 0;
                }
                return first;
            }
        }
 
        /// <summary>
        /// A vector column whose rows are stored as <c>VBuffer</c> values.
        /// </summary>
        private sealed class VBufferColumn<T> : VectorColumn<VBuffer<T>, T>
        {
            public VBufferColumn(PrimitiveDataViewType itemType, VBuffer<T>[] values)
                : base(itemType, values, buffer => buffer.Length)
            {
            }

            // Copies the stored buffer into the caller's buffer via VBuffer.CopyTo.
            protected override void CopyOut(in VBuffer<T> src, ref VBuffer<T> dst) => src.CopyTo(ref dst);
        }
 
        /// <summary>
        /// A vector column whose rows are stored as plain arrays and copied out into buffers.
        /// </summary>
        private sealed class ArrayToVBufferColumn<T> : VectorColumn<T[], T>
        {
            public ArrayToVBufferColumn(PrimitiveDataViewType itemType, T[][] values)
                : base(itemType, values, row => Utils.Size(row))
            {
            }

            // Copies the whole array (zero elements for a null array) into the destination buffer.
            protected override void CopyOut(in T[] src, ref VBuffer<T> dst)
            {
                int count = Utils.Size(src);
                VBuffer<T>.Copy(src, 0, ref dst, count);
            }
        }
 
        /// <summary>
        /// A vector column whose rows are stored as plain arrays and copied out through a
        /// <c>BufferBuilder</c>, which decides the representation of the resulting buffer.
        /// </summary>
        private sealed class ArrayToSparseVBufferColumn<T> : VectorColumn<T[], T>
        {
            // NOTE(review): a single builder instance is reused by every CopyOut call on this
            // column, so concurrent calls from several cursors look unsafe — confirm.
            private readonly BufferBuilder<T> _builder;

            public ArrayToSparseVBufferColumn(PrimitiveDataViewType itemType, Combiner<T> combiner, T[][] values)
                : base(itemType, values, row => Utils.Size(row))
            {
                _builder = new BufferBuilder<T>(combiner);
            }

            // Feeds every element of the array into the builder, slot by slot, then lets the
            // builder write the result into the destination buffer.
            protected override void CopyOut(in T[] src, ref VBuffer<T> dst)
            {
                int count = Utils.Size(src);
                _builder.Reset(count, false);

                int slot = 0;
                while (slot < count)
                {
                    _builder.AddFeature(slot, src[slot]);
                    slot++;
                }

                _builder.GetResult(ref dst);
            }
        }
 
        #endregion
    }
}