File: CustomMappingFilter.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Transforms
    /// <summary>
    /// Common base for data views that filter the rows of an input <see cref="IDataView"/>
    /// after reading them as instances of <typeparamref name="TSrc"/>. The schema of the
    /// input is exposed unchanged.
    /// </summary>
    /// <typeparam name="TSrc">The class that input rows are read as.</typeparam>
    internal abstract class CustomMappingFilterBase<TSrc> : IDataView
        where TSrc : class, new()
    {
        /// <summary>The data view whose rows are being filtered.</summary>
        protected readonly IDataView Input;

        /// <summary>Typed accessor over <see cref="Input"/>, created in the constructor.</summary>
        protected readonly TypedCursorable<TSrc> TypedSrc;

        /// <summary>Host registered under the name "CustomFilter"; used for argument checks.</summary>
        protected readonly IHost Host;

        /// <summary>Whether this view supports shuffled cursors; decided by the derived class.</summary>
        public abstract bool CanShuffle { get; }

        /// <summary>The schema of this view, which is exactly the schema of <see cref="Input"/>.</summary>
        public DataViewSchema Schema
        {
            get { return Input.Schema; }
        }

        /// <summary>
        /// Validates the arguments, registers the host and builds the typed accessor over
        /// <paramref name="input"/>.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">The data view to filter; must not be null.</param>
        private protected CustomMappingFilterBase(IHostEnvironment env, IDataView input)
        {
            // env is checked through Contracts because Host does not exist yet.
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register("CustomFilter");
            Host.CheckValue(input, nameof(input));

            Input = input;
            TypedSrc = TypedCursorable<TSrc>.Create(Host, input, false, null);
        }

        /// <summary>Always returns null: this view does not report a row count.</summary>
        public long? GetRowCount()
        {
            return null;
        }

        /// <summary>
        /// Opens a single cursor over the filtered rows.
        /// </summary>
        /// <param name="columnsNeeded">The columns of <see cref="Schema"/> the caller wants active.</param>
        /// <param name="rand">Optional randomness source, forwarded to the input's cursor.</param>
        /// <returns>The cursor produced by <see cref="GetRowCursorCore"/>.</returns>
        public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            var requested = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);

            // Ask the typed accessor which input columns are required for the requested ones.
            Func<int, bool> isRequired = TypedSrc.GetDependencies(requested);
            var requiredColumns = Input.Schema.Where(column => isRequired(column.Index));

            var inputCursor = Input.GetRowCursor(requiredColumns, rand);
            var activeFlags = Utils.BuildArray(Input.Schema.Count, requiredColumns);
            return GetRowCursorCore(inputCursor, activeFlags);
        }

        /// <summary>
        /// Wraps an input cursor in the filter-specific cursor.
        /// </summary>
        /// <param name="input">Cursor over <see cref="Input"/>.</param>
        /// <param name="active">Per-column flags built from the required input columns.</param>
        protected abstract DataViewRowCursor GetRowCursorCore(DataViewRowCursor input, bool[] active);

        /// <summary>
        /// Opens a set of cursors over the filtered rows; the strategy is left to the derived class.
        /// </summary>
        public abstract DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null);
    }
    /// <summary>
    /// Stateless row filter: each input row is read as a <typeparamref name="TSrc"/> and handed
    /// to a user predicate. A row is accepted by the cursor only when the predicate returns
    /// false, i.e. rows for which the predicate is true are the ones filtered out.
    /// </summary>
    /// <typeparam name="TSrc">The class that input rows are read as.</typeparam>
    internal sealed class CustomMappingFilter<TSrc> : CustomMappingFilterBase<TSrc>
        where TSrc : class, new()
    {
        private readonly Func<TSrc, bool> _predicate;

        /// <summary>Shuffling is supported whenever the input supports it.</summary>
        public override bool CanShuffle => Input.CanShuffle;

        /// <summary>
        /// Creates the filter.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">The data view to filter; must not be null.</param>
        /// <param name="predicate">Evaluated per row; must not be null.</param>
        public CustomMappingFilter(IHostEnvironment env, IDataView input, Func<TSrc, bool> predicate)
            : base(env, input)
        {
            Host.CheckValue(predicate, nameof(predicate));
            _predicate = predicate;
        }

        /// <summary>Wraps <paramref name="input"/> in a filtering <see cref="Cursor"/>.</summary>
        protected override DataViewRowCursor GetRowCursorCore(DataViewRowCursor input, bool[] active)
        {
            return new Cursor(this, input, active);
        }

        /// <summary>
        /// Opens a cursor set on the input and wraps every returned cursor in a filtering
        /// <see cref="Cursor"/>, so the result has exactly as many cursors as the input produced.
        /// </summary>
        /// <param name="columnsNeeded">The columns of the schema the caller wants active.</param>
        /// <param name="n">Requested degree of parallelism, forwarded to the input.</param>
        /// <param name="rand">Optional randomness source, forwarded to the input.</param>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
            Func<int, bool> inputPred = TypedSrc.GetDependencies(predicate);

            var inputCols = Input.Schema.Where(x => inputPred(x.Index));
            var inputs = Input.GetRowCursorSet(inputCols, n, rand);
            var active = Utils.BuildArray(Input.Schema.Count, inputCols);

            // One wrapper per input cursor; no additional splitting is performed here,
            // including when the input hands back a single cursor.
            var cursors = new DataViewRowCursor[inputs.Length];
            for (int i = 0; i < inputs.Length; i++)
                cursors[i] = new Cursor(this, inputs[i], active);
            return cursors;
        }

        /// <summary>
        /// Cursor that evaluates the parent's predicate against the current input row.
        /// </summary>
        private sealed class Cursor : LinkedRowFilterCursorBase
        {
            private readonly Func<bool> _accept;

            public Cursor(CustomMappingFilter<TSrc> parent, DataViewRowCursor input, bool[] active)
                : base(parent.Host, input, input.Schema, active)
            {
                IRowReadableAs<TSrc> inputRow = parent.TypedSrc.GetRow(input);

                // A single TSrc instance is reused for every row of this cursor.
                TSrc src = new TSrc();
                long lastServedPosition = -1;

                // Copies the current row into src, at most once per cursor position.
                Action refresh = () =>
                {
                    if (lastServedPosition != input.Position)
                    {
                        inputRow.FillValues(src);
                        lastServedPosition = input.Position;
                    }
                };

                var predicate = parent._predicate;
                _accept = () =>
                {
                    // Without this refresh the predicate would only ever see a blank TSrc.
                    refresh();
                    return !predicate(src);
                };
            }

            /// <summary>Returns true when the predicate is false for the current row.</summary>
            protected override bool Accept()
            {
                return _accept();
            }
        }
    }
    /// <summary>
    /// Stateful row filter: each input row is read as a <typeparamref name="TSrc"/> and handed,
    /// together with a per-cursor <typeparamref name="TState"/> object, to a user predicate.
    /// A row is accepted by the cursor only when the predicate returns false.
    /// </summary>
    /// <typeparam name="TSrc">The class that input rows are read as.</typeparam>
    /// <typeparam name="TState">The class holding state carried from row to row within a cursor.</typeparam>
    internal sealed class StatefulCustomMappingFilter<TSrc, TState> : CustomMappingFilterBase<TSrc>
        where TSrc : class, new()
        where TState : class, new()
    {
        private readonly Func<TSrc, TState, bool> _predicate;
        private readonly Action<TState> _stateInitAction;

        /// <summary>
        /// Never shuffles. NOTE(review): presumably because the state makes the result
        /// depend on row order - confirm.
        /// </summary>
        public override bool CanShuffle => false;

        /// <summary>
        /// Creates the filter.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">The data view to filter; must not be null.</param>
        /// <param name="predicate">Evaluated per row with the cursor's state; must not be null.</param>
        /// <param name="stateInitAction">Initializes a freshly created state object; must not be null.</param>
        public StatefulCustomMappingFilter(IHostEnvironment env, IDataView input, Func<TSrc, TState, bool> predicate, Action<TState> stateInitAction)
            : base(env, input)
        {
            Host.CheckValue(predicate, nameof(predicate));
            Host.CheckValue(stateInitAction, nameof(stateInitAction));
            _predicate = predicate;
            _stateInitAction = stateInitAction;
        }

        /// <summary>Wraps <paramref name="input"/> in a filtering <see cref="Cursor"/>.</summary>
        protected override DataViewRowCursor GetRowCursorCore(DataViewRowCursor input, bool[] active)
        {
            return new Cursor(this, input, active);
        }

        /// <summary>
        /// Always returns a one-element set holding a single sequential cursor; the requested
        /// parallelism <paramref name="n"/> is ignored.
        /// </summary>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            return new[] { GetRowCursor(columnsNeeded, rand) };
        }

        /// <summary>
        /// Cursor that evaluates the parent's predicate against the current input row and
        /// this cursor's own state object.
        /// </summary>
        private sealed class Cursor : LinkedRowFilterCursorBase
        {
            private readonly Func<bool> _accept;

            public Cursor(StatefulCustomMappingFilter<TSrc, TState> parent, DataViewRowCursor input, bool[] active)
                : base(parent.Host, input, input.Schema, active)
            {
                IRowReadableAs<TSrc> inputRow = parent.TypedSrc.GetRow(input);

                // A single TSrc instance is reused for every row of this cursor.
                TSrc src = new TSrc();

                // Each cursor owns its state; initialize it once, before any row is examined.
                TState state = new TState();
                parent._stateInitAction(state);

                long lastServedPosition = -1;

                // Copies the current row into src, at most once per cursor position.
                Action refresh = () =>
                {
                    if (lastServedPosition != input.Position)
                    {
                        inputRow.FillValues(src);
                        lastServedPosition = input.Position;
                    }
                };

                var predicate = parent._predicate;
                _accept = () =>
                {
                    // Without this refresh the predicate would only ever see a blank TSrc.
                    refresh();
                    return !predicate(src, state);
                };
            }

            /// <summary>Returns true when the predicate is false for the current row.</summary>
            protected override bool Accept()
            {
                return _accept();
            }
        }
    }