// File: DataView\BatchDataViewMapperBase.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data.DataView
{
    internal abstract class BatchDataViewMapperBase<TInput, TBatch> : IDataView
    {
        /// <summary>
        /// Always <see langword="false"/>: this view never reports itself as shufflable.
        /// </summary>
        public bool CanShuffle => false;
 
        /// <summary>
        /// The output schema of this view, taken from <see cref="SchemaBindings"/>.
        /// </summary>
        public DataViewSchema Schema => SchemaBindings.AsSchema;
 
        // The input data view this mapper reads its rows from.
        private readonly IDataView _source;
        // Host registered from the environment passed to the constructor; used for argument checks.
        protected readonly IHost Host;
 
        /// <summary>
        /// Registers a host with <paramref name="env"/> and stores the input view to map over.
        /// </summary>
        /// <param name="env">Environment used to register the host; must not be null.</param>
        /// <param name="registrationName">Name passed to <c>env.Register</c> when creating the host.</param>
        /// <param name="input">The source data view; must not be null.</param>
        protected BatchDataViewMapperBase(IHostEnvironment env, string registrationName, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register(registrationName);
            // Fail fast with a parameter-named exception rather than a late NullReferenceException
            // the first time GetRowCount or GetRowCursor dereferences the source.
            Host.CheckValue(input, nameof(input));
            _source = input;
        }
 
        /// <summary>
        /// Returns the row count reported by the source view, unchanged.
        /// </summary>
        /// <returns>Whatever the source view's <c>GetRowCount</c> returns, possibly null.</returns>
        public long? GetRowCount()
        {
            return _source.GetRowCount();
        }
 
        /// <summary>
        /// Opens a cursor over this view with the requested columns active.
        /// </summary>
        /// <param name="columnsNeeded">Columns of <see cref="Schema"/> the caller wants active; must not be null.</param>
        /// <param name="rand">Validated but never forwarded: the source is always read without a random generator.</param>
        /// <returns>
        /// A wrapper around a plain source cursor when no new column is requested; otherwise a batching
        /// cursor built on two source cursors opened over the same set of source columns.
        /// </returns>
        public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Host.CheckValue(columnsNeeded, nameof(columnsNeeded));
            Host.CheckValueOrNull(rand);

            var isNeeded = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, SchemaBindings.AsSchema);

            if (SchemaBindings.AnyNewColumnsActive(isNeeded))
            {
                var activeColumns = SchemaBindings.GetActive(isNeeded);
                Contracts.Assert(activeColumns.Length == SchemaBindings.ColumnCount);

                // REVIEW: The input cursor and the lookahead cursor could use different predicates. The lookahead
                // cursor only reads the input column, so it needs just that column activated, while the other
                // cursor serves the source columns and needs the rest of them.
                var isSourceNeeded = GetSchemaBindingDependencies(isNeeded);
                var sourceColumns = _source.Schema.Where(column => isSourceNeeded(column.Index));

                // Open the main cursor first, then the lookahead cursor, both over the same columns.
                var mainCursor = _source.GetRowCursor(sourceColumns);
                var lookAheadCursor = _source.GetRowCursor(sourceColumns);
                return new Cursor(this, mainCursor, lookAheadCursor, activeColumns);
            }

            // None of the output columns is selected, so the batching cursor is not constructed.
            // Because random access cannot be supported due to the inherently stratified nature,
            // the base data is not allowed to shuffle either, even if it supports shuffling:
            // null is passed instead of rand.
            var activeSource = SchemaBindings.GetActiveInput(isNeeded);
            var passThroughColumns = _source.Schema.Where(column => activeSource[column.Index]);
            var passThroughCursor = _source.GetRowCursor(passThroughColumns, null);
            return new BindingsWrappedRowCursor(Host, passThroughCursor, SchemaBindings);
        }
 
        /// <summary>
        /// Returns a cursor set that always contains exactly one cursor, obtained from
        /// <see cref="GetRowCursor"/>. The requested degree of parallelism <paramref name="n"/> is not used.
        /// </summary>
        /// <param name="columnsNeeded">Columns the caller wants active; forwarded to <see cref="GetRowCursor"/>.</param>
        /// <param name="n">Ignored.</param>
        /// <param name="rand">Forwarded to <see cref="GetRowCursor"/>.</param>
        /// <returns>A single-element array holding the cursor.</returns>
        public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            var single = GetRowCursor(columnsNeeded, rand);
            return new DataViewRowCursor[] { single };
        }
 
        /// <summary>
        /// Column bindings of this view. They supply the output schema and are used to decide which
        /// columns are active and to map an output column index to a source or getter index.
        /// </summary>
        protected abstract ColumnBindingsBase SchemaBindings { get; }

        /// <summary>
        /// Creates the batch object. The nested cursor calls this once, in its constructor, with its
        /// main input cursor, and then passes the same instance to <see cref="CreateGetters"/>,
        /// <see cref="ProcessExample"/> and <see cref="ProcessBatch"/>.
        /// </summary>
        protected abstract TBatch CreateBatch(DataViewRowCursor input);

        /// <summary>
        /// Called once per batch, after every row of that batch visited by the look-ahead cursor has
        /// been passed to <see cref="ProcessExample"/>.
        /// </summary>
        protected abstract void ProcessBatch(TBatch currentBatch);

        /// <summary>
        /// Called for each row the look-ahead cursor visits in a batch, with the value read through
        /// the getter returned by <see cref="GetLookAheadGetter"/>.
        /// </summary>
        protected abstract void ProcessExample(TBatch currentBatch, TInput currentInput);

        /// <summary>
        /// Returns a delegate evaluated against the look-ahead cursor. The cursor stops feeding rows to
        /// <see cref="ProcessExample"/> once it returns true, and asserts it is true before starting
        /// the next batch.
        /// </summary>
        protected abstract Func<bool> GetLastInBatchDelegate(DataViewRowCursor lookAheadCursor);

        /// <summary>
        /// Returns a delegate evaluated after each advance of the main cursor; when it returns true the
        /// look-ahead pass over the next batch is run.
        /// NOTE(review): despite the parameter name, the nested cursor passes its main input cursor
        /// here, not the look-ahead cursor.
        /// </summary>
        protected abstract Func<bool> GetIsNewBatchDelegate(DataViewRowCursor lookAheadCursor);

        /// <summary>
        /// Returns the getter used to read a <typeparamref name="TInput"/> value from the current row
        /// of the look-ahead cursor.
        /// </summary>
        protected abstract ValueGetter<TInput> GetLookAheadGetter(DataViewRowCursor lookAheadCursor);

        /// <summary>
        /// Creates the getters for the non-source columns. The cursor indexes the returned array with
        /// the index that <c>SchemaBindings.MapColumnIndex</c> yields for a non-source column and
        /// expects each used entry to be a <c>ValueGetter&lt;TValue&gt;</c> of the requested type.
        /// </summary>
        protected abstract Delegate[] CreateGetters(DataViewRowCursor input, TBatch currentBatch, bool[] active);

        /// <summary>
        /// Converts the predicate built from the requested columns into a predicate over source column
        /// indices; the result selects which source columns the underlying cursors are opened with.
        /// </summary>
        protected abstract Func<int, bool> GetSchemaBindingDependencies(Func<int, bool> predicate);
 
        /// <summary>
        /// Cursor that drives two cursors over the same source columns. <c>_input</c> serves the
        /// current row, while <c>_lookAheadCursor</c> runs ahead so that a whole batch is handed to
        /// the parent (<c>ProcessExample</c> per row, then <c>ProcessBatch</c>) before the first row
        /// of that batch is exposed. The cursor owns both source cursors and disposes them.
        /// </summary>
        private sealed class Cursor : RootCursorBase
        {
            private readonly BatchDataViewMapperBase<TInput, TBatch> _parent;
            private readonly DataViewRowCursor _lookAheadCursor;
            private readonly DataViewRowCursor _input;

            // Per output column: whether the column was requested.
            private readonly bool[] _active;
            // Getters for the non-source columns, as created by the parent.
            private readonly Delegate[] _getters;

            private readonly TBatch _currentBatch;
            private readonly Func<bool> _lastInBatchInLookAheadCursorDel;
            private readonly Func<bool> _firstInBatchInInputCursorDel;
            private readonly ValueGetter<TInput> _inputGetterInLookAheadCursor;
            private TInput _currentInput;
            private bool _disposed;

            /// <summary>This cursor is never part of a cursor set, so the batch id is always 0.</summary>
            public override long Batch => 0;

            public override DataViewSchema Schema => _parent.Schema;

            /// <summary>
            /// Creates the cursor and asks the parent for the batch object, the output getters and the
            /// batch-boundary delegates.
            /// </summary>
            /// <param name="parent">The mapper that owns this cursor.</param>
            /// <param name="input">Source cursor whose position defines the current row.</param>
            /// <param name="lookAheadCursor">Second source cursor used to scan a batch in advance.</param>
            /// <param name="active">Per output column flags telling which columns are active.</param>
            public Cursor(BatchDataViewMapperBase<TInput, TBatch> parent, DataViewRowCursor input, DataViewRowCursor lookAheadCursor, bool[] active)
                : base(parent.Host)
            {
                _parent = parent;
                _input = input;
                _lookAheadCursor = lookAheadCursor;
                _active = active;

                _currentBatch = _parent.CreateBatch(_input);

                _getters = _parent.CreateGetters(_input, _currentBatch, _active);

                _lastInBatchInLookAheadCursorDel = _parent.GetLastInBatchDelegate(_lookAheadCursor);
                // The new-batch test is bound to the main input cursor, not the look-ahead one.
                _firstInBatchInInputCursorDel = _parent.GetIsNewBatchDelegate(_input);
                _inputGetterInLookAheadCursor = _parent.GetLookAheadGetter(_lookAheadCursor);
            }

            /// <summary>
            /// Returns the getter for an active column: source columns are served by the input cursor,
            /// new columns by the getters the parent created.
            /// </summary>
            /// <exception cref="ArgumentException">The column is not active.</exception>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Ch.CheckParam(IsColumnActive(column), nameof(column), "requested column is not active");

                var col = _parent.SchemaBindings.MapColumnIndex(out bool isSrc, column.Index);
                if (isSrc)
                {
                    Ch.AssertValue(_input);
                    return _input.GetGetter<TValue>(_input.Schema[col]);
                }

                Ch.AssertValue(_getters);
                var getter = _getters[col];
                Ch.Assert(getter != null);
                var fn = getter as ValueGetter<TValue>;
                if (fn == null)
                    throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}', " +
                            $"expected type: '{getter.GetType().GetGenericArguments().First()}'.");
                return fn;
            }

            /// <summary>
            /// Returns an id getter that reports the current <c>Position</c> as the low part of the row id.
            /// </summary>
            public override ValueGetter<DataViewRowId> GetIdGetter()
            {
                return
                    (ref DataViewRowId val) =>
                    {
                        Ch.Check(IsGood, "Cannot call ID getter in current state");
                        val = new DataViewRowId((ulong)Position, 0);
                    };
            }

            /// <summary>Tells whether the given output column was requested when the cursor was opened.</summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                Ch.Check(column.Index < _parent.SchemaBindings.AsSchema.Count);
                return _active[column.Index];
            }

            /// <summary>
            /// Advances the input cursor by one row. When that row starts a new batch, the look-ahead
            /// cursor is first run across the whole batch so the parent can process it.
            /// </summary>
            protected override bool MoveNextCore()
            {
                if (!_input.MoveNext())
                    return false;
                if (!_firstInBatchInInputCursorDel())
                    return true;

                // If we are here, this means that _input.MoveNext() has gotten us to the beginning of the next batch,
                // so now we need to look ahead at the entire next batch in the _lookAheadCursor.
                // The _lookAheadCursor's position should be on the last row of the previous batch (or -1).
                Ch.Assert(_lastInBatchInLookAheadCursorDel());

                var good = _lookAheadCursor.MoveNext();
                // The two cursors should have the same number of elements, so if _input.MoveNext() returned true,
                // then it must return true here too.
                Ch.Assert(good);

                do
                {
                    _inputGetterInLookAheadCursor(ref _currentInput);
                    _parent.ProcessExample(_currentBatch, _currentInput);
                } while (!_lastInBatchInLookAheadCursorDel() && _lookAheadCursor.MoveNext());

                _parent.ProcessBatch(_currentBatch);
                return true;
            }

            /// <summary>
            /// Releases the two source cursors this cursor owns. Without this override they were never
            /// disposed when the cursor itself was disposed.
            /// </summary>
            /// <param name="disposing">True when called from <c>Dispose()</c> rather than a finalizer.</param>
            protected override void Dispose(bool disposing)
            {
                if (_disposed)
                    return;
                if (disposing)
                {
                    _input.Dispose();
                    _lookAheadCursor.Dispose();
                }
                _disposed = true;
                base.Dispose(disposing);
            }
        }
    }
}