File: StatefulFilterTransform.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Transforms
{
    // REVIEW: the current interface to 'state' object may be inadequate: instead of insisting on
    // parameterless constructor, we could take a delegate that would create the state per cursor.
    /// <summary>
    /// This transform is similar to <see cref="CustomMappingTransformer{TSrc,TDst}"/>, but it allows per-cursor state,
    /// as well as the ability to 'accept' or 'filter out' some rows of the supplied <see cref="IDataView"/>.
    /// The downside is that the provided lambda is eagerly called on every row (not lazily when needed), and
    /// parallel cursors are not allowed.
    /// </summary>
    /// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the input <see cref="IDataView"/>.</typeparam>
    /// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam>
    /// <typeparam name="TState">The type that describes per-cursor state.</typeparam>
    internal sealed class StatefulFilterTransform<TSrc, TDst, TState> : IDataView
        where TSrc : class, new()
        where TDst : class, new()
        where TState : class, new()
    {
        private readonly IHost _host;
        private const string RegistrationNameTemplate = "StatefulFilterTransform<{0}, {1}>";
        private readonly IDataView _source;
        private readonly Func<TSrc, TDst, TState, bool> _filterFunc;
        private readonly Action<TState> _initStateAction;
        private readonly ColumnBindings _bindings;
        private readonly InternalSchemaDefinition _addedSchema;
 
        // Memorized input schema definition. Needed for re-apply.
        private readonly SchemaDefinition _inputSchemaDefinition;
        private readonly TypedCursorable<TSrc> _typedSource;
 
        private static string RegistrationName { get { return string.Format(RegistrationNameTemplate, typeof(TSrc).FullName, typeof(TDst).FullName); } }
 
        /// <summary>
        /// Create a filter transform
        /// </summary>
        /// <param name="env">The host environment</param>
        /// <param name="source">The dataview upon which we construct the transform</param>
        /// <param name="filterFunc">The function by which we transform source to destination columns and decide whether
        /// to keep the row.</param>
        /// <param name="initStateAction">The function that is called once per cursor to initialize state. Can be null.</param>
        /// <param name="inputSchemaDefinition">The schema definition overrides for <typeparamref name="TSrc"/></param>
        /// <param name="outputSchemaDefinition">The schema definition overrides for <typeparamref name="TDst"/></param>
        public StatefulFilterTransform(IHostEnvironment env, IDataView source, Func<TSrc, TDst, TState, bool> filterFunc,
            Action<TState> initStateAction,
            SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
        {
            _host = env.Register(RegistrationName);
            _host.AssertValue(source, "source");
            _host.AssertValue(filterFunc, "filterFunc");
            _host.AssertValueOrNull(initStateAction);
            _host.AssertValueOrNull(inputSchemaDefinition);
            _host.AssertValueOrNull(outputSchemaDefinition);
 
            _source = source;
            _filterFunc = filterFunc;
            _initStateAction = initStateAction;
            _inputSchemaDefinition = inputSchemaDefinition;
            _typedSource = TypedCursorable<TSrc>.Create(_host, Source, false, inputSchemaDefinition);
 
            var outSchema = InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition);
            _addedSchema = outSchema;
            _bindings = new ColumnBindings(Source.Schema, DataViewConstructionUtils.GetSchemaColumns(outSchema));
        }
 
        /// <summary>
        /// The 'reapply' constructor.
        /// </summary>
        private StatefulFilterTransform(IHostEnvironment env, StatefulFilterTransform<TSrc, TDst, TState> transform, IDataView newSource)
        {
            _host.AssertValue(transform);
            _host.AssertValue(newSource);
            _source = newSource;
            _filterFunc = transform._filterFunc;
            _typedSource = TypedCursorable<TSrc>.Create(_host, newSource, false, transform._inputSchemaDefinition);
 
            _addedSchema = transform._addedSchema;
            _bindings = new ColumnBindings(newSource.Schema, DataViewConstructionUtils.GetSchemaColumns(_addedSchema));
        }
 
        public bool CanShuffle => false;
 
        DataViewSchema IDataView.Schema => OutputSchema;
 
        public DataViewSchema OutputSchema => _bindings.Schema;
 
        public long? GetRowCount()
        {
            // REVIEW: currently stateful map is implemented via filter, and this is sub-optimal.
            return null;
        }
 
        public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            _host.CheckValueOrNull(rand);
 
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            var activeInputs = _bindings.GetActiveInput(predicate);
            Func<int, bool> inputPred = c => activeInputs[c];
 
            var input = _typedSource.GetCursor(inputPred, rand == null ? (int?)null : rand.Next());
            return new Cursor(this, input, columnsNeeded);
        }
 
        public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Contracts.CheckParam(n >= 0, nameof(n));
            Contracts.CheckValueOrNull(rand);
 
            // This transform is stateful, its contract is to allocate exactly one state object per cursor and call the filter function
            // on every row in sequence. Therefore, parallel cursoring is not possible.
            return new[] { GetRowCursor(columnsNeeded, rand) };
        }
 
        public IDataView Source => _source;
 
        private sealed class Cursor : RootCursorBase
        {
            private readonly StatefulFilterTransform<TSrc, TDst, TState> _parent;
 
            private readonly RowCursor<TSrc> _input;
            // This is used to serve getters for the columns we produce.
            private readonly DataViewRow _appendedRow;
 
            private readonly TSrc _src;
            private readonly TDst _dst;
            private readonly TState _state;
 
            private bool _disposed;
 
            public override long Batch => _input.Batch;
 
            public Cursor(StatefulFilterTransform<TSrc, TDst, TState> parent, RowCursor<TSrc> input, IEnumerable<DataViewSchema.Column> columnsNeeded)
                : base(parent._host)
            {
                Ch.AssertValue(input);
                Ch.AssertValue(columnsNeeded);
 
                _parent = parent;
                _input = input;
 
                _src = new TSrc();
                _dst = new TDst();
                _state = new TState();
 
                CursorChannelAttribute.TrySetCursorChannel(_parent._host, _src, Ch);
                CursorChannelAttribute.TrySetCursorChannel(_parent._host, _dst, Ch);
                CursorChannelAttribute.TrySetCursorChannel(_parent._host, _state, Ch);
 
                parent._initStateAction?.Invoke(_state);
 
                var appendedDataView = new DataViewConstructionUtils.SingleRowLoopDataView<TDst>(parent._host, _parent._addedSchema);
                appendedDataView.SetCurrentRowObject(_dst);
 
                var columnNames = columnsNeeded.Select(c => c.Name);
                _appendedRow = appendedDataView.GetRowCursor(appendedDataView.Schema.Where(c => !c.IsHidden && columnNames.Contains(c.Name)));
            }
 
            protected override void Dispose(bool disposing)
            {
                if (_disposed)
                    return;
                if (disposing)
                {
                    if (_state is IDisposable disposableState)
                        disposableState.Dispose();
                    if (_src is IDisposable disposableSrc)
                        disposableSrc.Dispose();
                    if (_dst is IDisposable disposableDst)
                        disposableDst.Dispose();
                    _input.Dispose();
                }
                _disposed = true;
                base.Dispose(disposing);
            }
 
            public override ValueGetter<DataViewRowId> GetIdGetter()
            {
                return _input.GetIdGetter();
            }
 
            protected override bool MoveNextCore()
            {
                bool isAccepted = false;
                while (!isAccepted)
                {
                    if (!_input.MoveNext())
                        return false;
                    RunLambda(out isAccepted);
                }
                return true;
            }
 
            private void RunLambda(out bool isRowAccepted)
            {
                _input.FillValues(_src);
                // REVIEW: what if this throws? Maybe swallow the exception?
                isRowAccepted = _parent._filterFunc(_src, _dst, _state);
            }
 
            public override DataViewSchema Schema => _parent._bindings.Schema;
 
            /// <summary>
            /// Returns whether the given column is active in this row.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                Contracts.CheckParam(column.Index < Schema.Count, nameof(column));
                bool isSrc;
                int iCol = _parent._bindings.MapColumnIndex(out isSrc, column.Index);
                if (isSrc)
                    return _input.IsColumnActive(_input.Schema[iCol]);
                return _appendedRow.IsColumnActive(_appendedRow.Schema[iCol]);
            }
 
            /// <summary>
            /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
            /// This throws if the column is not active in this row, or if the type
            /// <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Contracts.CheckParam(column.Index < Schema.Count, nameof(column));
                bool isSrc;
                int iCol = _parent._bindings.MapColumnIndex(out isSrc, column.Index);
                return isSrc ?
                    _input.GetGetter<TValue>(_input.Schema[iCol])
                    : _appendedRow.GetGetter<TValue>(_appendedRow.Schema[iCol]);
            }
        }
    }
}