File: CustomMappingFilter.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// Base class for <see cref="IDataView"/> filters that read each input row into an instance of
    /// <typeparamref name="TSrc"/> and let a derived cursor decide whether to keep the row.
    /// The output schema is exactly the input schema; only rows are removed.
    /// </summary>
    /// <typeparam name="TSrc">Class the input row is read into before the filter logic runs.</typeparam>
    internal abstract class CustomMappingFilterBase<TSrc> : IDataView
        where TSrc : class, new()
    {
        /// <summary>The data view being filtered.</summary>
        protected readonly IDataView Input;

        /// <summary>Typed view over <see cref="Input"/> used to read rows as <typeparamref name="TSrc"/> objects.</summary>
        protected readonly TypedCursorable<TSrc> TypedSrc;

        /// <summary>Host registered under the name "CustomFilter"; used for argument checks and assertions.</summary>
        protected readonly IHost Host;

        public abstract bool CanShuffle { get; }

        /// <summary>Filtering never adds or removes columns, so the schema is the input's schema.</summary>
        public DataViewSchema Schema => Input.Schema;

        private protected CustomMappingFilterBase(IHostEnvironment env, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register("CustomFilter");
            Host.CheckValue(input, nameof(input));

            Input = input;
            TypedSrc = TypedCursorable<TSrc>.Create(Host, input, false, null);
        }

        /// <summary>
        /// Always null: the number of surviving rows is not known without iterating the data.
        /// </summary>
        public long? GetRowCount() => null;

        /// <summary>
        /// Creates a single filtering cursor over the input.
        /// </summary>
        /// <param name="columnsNeeded">Output columns the caller intends to read.</param>
        /// <param name="rand">Optional random source, forwarded to the input cursor.</param>
        public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Host.AssertValueOrNull(rand);
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);

            // Presumably widens the requested set with the input columns TSrc reads, since the filter
            // needs them even when the caller does not request them — TODO confirm against TypedCursorable.
            Func<int, bool> inputPred = TypedSrc.GetDependencies(predicate);

            // Materialize once: the column list is consumed both by the input cursor and by BuildArray,
            // and a deferred query would re-run inputPred over the whole schema on each enumeration.
            var inputCols = Input.Schema.Where(x => inputPred(x.Index)).ToArray();

            // NOTE(review): rand is forwarded even when CanShuffle is false (e.g. the stateful filter);
            // confirm that callers never pass rand to non-shuffleable views.
            var input = Input.GetRowCursor(inputCols, rand);
            return GetRowCursorCore(input, Utils.BuildArray(Input.Schema.Count, inputCols));
        }

        /// <summary>
        /// Wraps an input cursor in a filtering cursor.
        /// </summary>
        /// <param name="input">Cursor over <see cref="Input"/>, restricted to the needed input columns.</param>
        /// <param name="active">Per-input-column flags built from the needed columns (presumably true = active).</param>
        protected abstract DataViewRowCursor GetRowCursorCore(DataViewRowCursor input, bool[] active);

        public abstract DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null);
    }
 
    /// <summary>
    /// Custom filter that drops every row for which a user-supplied predicate returns true.
    /// The filter itself keeps no information between rows.
    /// </summary>
    /// <typeparam name="TSrc">Class the input row is read into before the predicate is applied.</typeparam>
    internal sealed class CustomMappingFilter<TSrc> : CustomMappingFilterBase<TSrc>
        where TSrc : class, new()
    {
        // Returns true for rows that should be removed; Cursor.Accept negates it.
        private readonly Func<TSrc, bool> _predicate;

        // No cross-row state is held by this filter, so it can shuffle whenever the input can.
        public override bool CanShuffle => Input.CanShuffle;

        public CustomMappingFilter(IHostEnvironment env, IDataView input, Func<TSrc, bool> predicate)
            : base(env, input)
        {
            Host.CheckValue(predicate, nameof(predicate));
            _predicate = predicate;
        }

        protected override DataViewRowCursor GetRowCursorCore(DataViewRowCursor input, bool[] active)
            => new Cursor(this, input, active);

        /// <summary>
        /// Creates one filtering cursor per input cursor returned by the input's own cursor set.
        /// </summary>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.AssertValueOrNull(rand);
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);

            Func<int, bool> inputPred = TypedSrc.GetDependencies(predicate);

            // Materialize once: the column list is consumed both by GetRowCursorSet and by BuildArray.
            var inputCols = Input.Schema.Where(x => inputPred(x.Index)).ToArray();
            var inputs = Input.GetRowCursorSet(inputCols, n, rand);
            Host.AssertNonEmpty(inputs);
            var active = Utils.BuildArray(Input.Schema.Count, inputCols);

            // Wrap each input cursor independently; every Cursor owns its own TSrc buffer.
            var cursors = new DataViewRowCursor[inputs.Length];
            for (int i = 0; i < inputs.Length; i++)
                cursors[i] = new Cursor(this, inputs[i], active);
            return cursors;
        }

        /// <summary>
        /// Cursor that skips input rows for which the parent's predicate returns true.
        /// </summary>
        private sealed class Cursor : LinkedRowFilterCursorBase
        {
            private readonly DataViewRowCursor _input;
            private readonly IRowReadableAs<TSrc> _inputRow;
            private readonly Func<TSrc, bool> _predicate;

            // Reusable buffer that the current input row is read into.
            private readonly TSrc _src;

            // Input position whose values are currently in _src; -1 means nothing has been read yet.
            private long _lastServedPosition;

            public Cursor(CustomMappingFilter<TSrc> parent, DataViewRowCursor input, bool[] active)
                : base(parent.Host, input, input.Schema, active)
            {
                _input = input;
                _inputRow = parent.TypedSrc.GetRow(input);
                _src = new TSrc();
                _lastServedPosition = -1;
                _predicate = parent._predicate;
            }

            protected override bool Accept()
            {
                // Re-read the row only when the input has moved since the last fill.
                if (_lastServedPosition != _input.Position)
                {
                    _inputRow.FillValues(_src);
                    _lastServedPosition = _input.Position;
                }

                // The row is kept exactly when the user predicate returns false.
                return !_predicate(_src);
            }
        }
    }
 
    /// <summary>
    /// Custom filter whose predicate also receives a <typeparamref name="TState"/> object.
    /// One state object is created per cursor and initialized by a user-supplied action.
    /// Rows for which the predicate returns true are dropped.
    /// </summary>
    /// <typeparam name="TSrc">Class the input row is read into before the predicate is applied.</typeparam>
    /// <typeparam name="TState">Per-cursor state object passed to every predicate call.</typeparam>
    internal sealed class StatefulCustomMappingFilter<TSrc, TState> : CustomMappingFilterBase<TSrc>
        where TSrc : class, new()
        where TState : class, new()
    {
        private readonly Func<TSrc, TState, bool> _predicate;
        private readonly Action<TState> _stateInitAction;

        // The predicate sees a state object that persists across rows, so row order matters.
        public override bool CanShuffle => false;

        public StatefulCustomMappingFilter(IHostEnvironment env, IDataView input, Func<TSrc, TState, bool> predicate, Action<TState> stateInitAction)
            : base(env, input)
        {
            Host.CheckValue(predicate, nameof(predicate));
            Host.CheckValue(stateInitAction, nameof(stateInitAction));

            _stateInitAction = stateInitAction;
            _predicate = predicate;
        }

        protected override DataViewRowCursor GetRowCursorCore(DataViewRowCursor input, bool[] active)
            => new Cursor(this, input, active);

        /// <summary>
        /// Always returns a single cursor, whatever <paramref name="n"/> is.
        /// Each cursor gets its own freshly initialized state, so splitting would partition that state.
        /// </summary>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            // NOTE(review): rand still reaches the input via GetRowCursor even though CanShuffle is false.
            DataViewRowCursor single = GetRowCursor(columnsNeeded, rand);
            return new DataViewRowCursor[] { single };
        }

        /// <summary>
        /// Cursor that feeds each input row plus its own state object to the parent's predicate.
        /// It skips rows for which the predicate returns true.
        /// </summary>
        private sealed class Cursor : LinkedRowFilterCursorBase
        {
            private readonly DataViewRowCursor _input;
            private readonly IRowReadableAs<TSrc> _inputRow;
            private readonly Func<TSrc, TState, bool> _predicate;

            // Reusable buffer that the current input row is read into.
            private readonly TSrc _src;

            // State owned by this cursor; initialized once, then passed to every predicate call.
            private readonly TState _state;

            // Input position whose values are currently in _src; -1 means nothing has been read yet.
            private long _lastServedPosition;

            public Cursor(StatefulCustomMappingFilter<TSrc, TState> parent, DataViewRowCursor input, bool[] active)
                : base(parent.Host, input, input.Schema, active)
            {
                Contracts.AssertValue(parent);

                _input = input;
                _inputRow = parent.TypedSrc.GetRow(input);
                _src = new TSrc();
                _state = new TState();
                parent._stateInitAction(_state);
                _lastServedPosition = -1;
                _predicate = parent._predicate;
            }

            protected override bool Accept()
            {
                // Read the row afresh only when the input has advanced since the last fill.
                if (_lastServedPosition != _input.Position)
                {
                    _inputRow.FillValues(_src);
                    _lastServedPosition = _input.Position;
                }

                // The row is kept exactly when the user predicate returns false.
                return !_predicate(_src, _state);
            }
        }
    }
}