// File: Transforms\NAFilter.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
// REVIEW: As soon as we stop writing sizeof(Float), or when we retire the double builds, we can remove this.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Registers NAFilter, together with its Arguments type, under the SignatureDataTransform signature.
// The trailing strings (NAFilter.ShortName, "MissingValueFilter", "MissingFilter") are the names listed
// for this registration.
[assembly: LoadableClass(NAFilter.Summary, typeof(NAFilter), typeof(NAFilter.Arguments), typeof(SignatureDataTransform),
    NAFilter.FriendlyName, NAFilter.ShortName, "MissingValueFilter", "MissingFilter")]

// Registers NAFilter under the SignatureLoadDataTransform signature; no arguments type is supplied here.
// The names listed are NAFilter.LoaderSignature and the older "MissingFeatureFilter", which is also the
// alternate loader signature declared in NAFilter.GetVersionInfo.
// REVIEW: Make sure that the "MissingFeatureFilter" signature is maintained for backwards compatibility,
// and this is not a bug.
[assembly: LoadableClass(NAFilter.Summary, typeof(NAFilter), null, typeof(SignatureLoadDataTransform),
    NAFilter.FriendlyName, NAFilter.LoaderSignature, "MissingFeatureFilter")]
 
namespace Microsoft.ML.Transforms
{
    /// <include file='doc.xml' path='doc/members/member[@name="NAFilter"]'/>
    [BestFriend]
    internal sealed class NAFilter : FilterBase
    {
        /// <summary>
        /// Default option values, shared by <see cref="Arguments"/> and the convenience constructor.
        /// </summary>
        private static class Defaults
        {
            // Default for the "complement" option: false.
            public const bool Complement = false;
        }
 
        /// <summary>
        /// Options for <see cref="NAFilter"/>: which columns to inspect and whether to invert the filter.
        /// </summary>
        public sealed class Arguments : TransformInputBase
        {
            // Names of the source columns to inspect. Required; the constructor rejects a null or empty array
            // and rejects a column that is listed more than once.
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Column", Name = "Column", ShortName = "col", SortOrder = 1)]
            public string[] Columns;

            // When true the filter is inverted: see the help text below. Defaults to Defaults.Complement.
            [Argument(ArgumentType.Multiple, HelpText = "If true, keep only rows that contain NA values, and filter the rest.")]
            public bool Complement = Defaults.Complement;
        }
 
        /// <summary>
        /// Immutable description of one column the filter inspects: its index in the source schema and its type.
        /// </summary>
        private sealed class ColInfo
        {
            // Index of the column in the source schema.
            public readonly int Index;
            // Type of the column as reported by the source schema at construction time.
            public readonly DataViewType Type;

            /// <summary>
            /// Stores the given source column index and type.
            /// </summary>
            public ColInfo(int index, DataViewType type)
            {
                Index = index;
                Type = type;
            }
        }
 
        // Descriptive strings referenced by the assembly-level LoadableClass registrations for this class.
        public const string Summary = "Filters out rows that contain missing values.";
        public const string FriendlyName = "NA Filter";
        public const string ShortName = "NAFilter";

        // Loader signature recorded in the version info returned by GetVersionInfo and used in the
        // SignatureLoadDataTransform registration.
        public const string LoaderSignature = "MissingValueFilter";
        /// <summary>
        /// Builds the <see cref="VersionInfo"/> used both when writing this transform's model
        /// (<see cref="SaveModel"/>) and when validating a model being read.
        /// </summary>
        /// <returns>A freshly constructed version descriptor; written and readable versions are all 0x00010001.</returns>
        private static VersionInfo GetVersionInfo() =>
            new VersionInfo(
                modelSignature: "MISFETFL",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                // "MissingFeatureFilter" is the older loader name. It is kept as an alternate so that
                // older code can still load this; it can be dropped once that no longer matters.
                loaderSignatureAlt: "MissingFeatureFilter",
                loaderAssemblyName: typeof(NAFilter).Assembly.FullName);
 
        // One entry per inspected column, in the order the columns were specified (or read from the model).
        private readonly ColInfo[] _infos;
        // Maps a source-schema column index to the position of its entry in _infos.
        private readonly Dictionary<int, int> _srcIndexToInfoIndex;
        // When true, Accept keeps the rows that contain a missing value instead of dropping them.
        // NOTE: only the arguments-based constructor assigns this; the model-loading constructor does not,
        // and SaveModel does not write it, so a loaded filter always has the default value (false).
        private readonly bool _complement;
        // Name passed to the base constructor and to env.Register in Create.
        private const string RegistrationName = "MissingValueFilter";
 
        /// <summary>
        /// Convenience constructor: creates an <see cref="NAFilter"/> directly from a list of column names.
        /// It packs its parameters into an <see cref="Arguments"/> instance and forwards to the
        /// arguments-based constructor.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
        /// <param name="complement">If true, keep only rows that contain NA values, and filter the rest.</param>
        /// <param name="columns">Name of the columns. Only these columns will be used to filter rows having 'NA' values.</param>
        public NAFilter(IHostEnvironment env, IDataView input, bool complement = Defaults.Complement, params string[] columns)
            : this(
                env,
                new Arguments
                {
                    Complement = complement,
                    Columns = columns,
                },
                input)
        {
        }
 
        /// <summary>
        /// Initializes a new instance of <see cref="NAFilter"/> from an <see cref="Arguments"/> object.
        /// Every requested column is resolved against the input schema and checked for a type that
        /// can represent missing values (see <see cref="TestType"/>).
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="args">Options: the columns to inspect and whether to complement the filter.</param>
        /// <param name="input">Input <see cref="IDataView"/> whose schema must contain every requested column.</param>
        /// <remarks>
        /// A user-argument exception is raised when no columns are given, when a column is not found in
        /// the input schema, when a column is listed more than once, or when a column's type is not supported.
        /// </remarks>
        public NAFilter(IHostEnvironment env, Arguments args, IDataView input)
            : base(env, RegistrationName, input)
        {
            // These checks run after the base constructor has already received env and input.
            Host.CheckValue(args, nameof(args));
            Host.CheckValue(input, nameof(input));
            Host.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns));
            Host.CheckValue(env, nameof(env));

            _infos = new ColInfo[args.Columns.Length];
            _srcIndexToInfoIndex = new Dictionary<int, int>(_infos.Length);
            _complement = args.Complement;
            var schema = Source.Schema;
            for (int i = 0; i < _infos.Length; i++)
            {
                string src = args.Columns[i];
                if (!schema.TryGetColumnIndex(src, out int index))
                    throw Host.ExceptUserArg(nameof(args.Columns), "Source column '{0}' not found", src);
                if (_srcIndexToInfoIndex.ContainsKey(index))
                    throw Host.ExceptUserArg(nameof(args.Columns), "Source column '{0}' specified multiple times", src);

                var type = schema[index].Type;
                if (!TestType(type))
                {
                    // Pass the values as format arguments rather than pre-interpolating them into the
                    // format string: a column name or type text containing '{' or '}' would otherwise be
                    // re-parsed as a format placeholder.
                    throw Host.ExceptUserArg(nameof(args.Columns),
                        "Column '{0}' has type {1} which does not support missing values, so we cannot filter on them", src, type);
                }

                _infos[i] = new ColInfo(index, type);
                _srcIndexToInfoIndex.Add(index, i);
            }
        }
 
        /// <summary>
        /// Initializes a new instance of <see cref="NAFilter"/> from a saved model.
        /// Reads the column names from <paramref name="ctx"/> and re-binds them against the schema of
        /// <paramref name="input"/>.
        /// </summary>
        /// <param name="host">Host passed to the base constructor.</param>
        /// <param name="ctx">Model load context positioned at this transform's model.</param>
        /// <param name="input">Input <see cref="IDataView"/> to filter.</param>
        /// <remarks>
        /// The complement flag is not part of the binary format read here, so it is left at its default (false).
        /// </remarks>
        public NAFilter(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of columns
            // int[]: ids of column names
            int floatSize = ctx.Reader.ReadInt32();
            Host.CheckDecode(floatSize == sizeof(float) || floatSize == sizeof(double));
            int columnCount = ctx.Reader.ReadInt32();
            Host.CheckDecode(columnCount > 0);

            _infos = new ColInfo[columnCount];
            _srcIndexToInfoIndex = new Dictionary<int, int>(columnCount);

            var inputSchema = Source.Schema;
            for (int slot = 0; slot < columnCount; slot++)
            {
                string name = ctx.LoadNonEmptyString();

                // The saved column must still exist in the new input, and must not repeat.
                if (!inputSchema.TryGetColumnIndex(name, out int srcIndex))
                    throw Host.ExceptSchemaMismatch(nameof(schema), "source", name);
                if (_srcIndexToInfoIndex.ContainsKey(srcIndex))
                    throw Host.Except("Source column '{0}' specified multiple times", name);

                // Its type in the new input must still be one that TestType accepts.
                var columnType = inputSchema[srcIndex].Type;
                if (!TestType(columnType))
                    throw Host.ExceptSchemaMismatch(nameof(schema), "source", name, "scalar or vector of float, double or KeyType", columnType.ToString());

                _infos[slot] = new ColInfo(srcIndex, columnType);
                _srcIndexToInfoIndex.Add(srcIndex, slot);
            }
        }
 
        /// <summary>
        /// Factory used to load an <see cref="NAFilter"/> from a saved model.
        /// Registers a host under <c>RegistrationName</c>, validates the arguments and the model
        /// version, then constructs the filter inside a "Loading Model" host operation.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="ctx">Model load context positioned at this transform's model.</param>
        /// <param name="input">Input <see cref="IDataView"/> to filter.</param>
        /// <returns>The loaded filter.</returns>
        public static NAFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));

            var host = env.Register(RegistrationName);
            host.CheckValue(ctx, nameof(ctx));
            host.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            return host.Apply("Loading Model", ch => new NAFilter(host, ctx, input));
        }
 
        /// <summary>
        /// Writes this transform to <paramref name="ctx"/>: a float-size marker, the number of inspected
        /// columns, and each column's name as looked up in the source schema.
        /// </summary>
        /// <param name="ctx">Model save context positioned at this transform's model.</param>
        /// <remarks>The complement flag is not written.</remarks>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of columns
            // int[]: ids of column names
            var writer = ctx.Writer;
            writer.Write(sizeof(float));

            int columnCount = _infos.Length;
            Host.Assert(columnCount > 0);
            writer.Write(columnCount);

            for (int slot = 0; slot < columnCount; slot++)
            {
                string name = Source.Schema[_infos[slot].Index].Name;
                ctx.SaveNonEmptyString(name);
            }
        }
 
        /// <summary>
        /// Tells whether the filter can inspect a column of the given type.
        /// For a vector type the item type is examined; otherwise the type itself.
        /// </summary>
        /// <param name="type">Column type to test.</param>
        /// <returns>
        /// True when the examined type is <see cref="NumberDataViewType.Single"/>,
        /// <see cref="NumberDataViewType.Double"/> or a <see cref="KeyDataViewType"/>; false otherwise.
        /// </returns>
        private static bool TestType(DataViewType type)
        {
            Contracts.AssertValue(type);

            var examined = (type as VectorDataViewType)?.ItemType ?? type;
            return examined == NumberDataViewType.Single
                || examined == NumberDataViewType.Double
                || examined is KeyDataViewType;
        }
 
        /// <summary>
        /// Reports whether this transform prefers parallel cursors for the requested columns.
        /// </summary>
        /// <param name="predicate">Column-activity predicate; only asserted to be non-null here.</param>
        /// <returns>Always null, meaning this transform expresses no preference.</returns>
        protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
        {
            Host.AssertValue(predicate);
            // This transform has no preference.
            return null;
        }
 
        /// <summary>
        /// Creates a single filtering cursor over the source.
        /// The source cursor is opened with the requested columns plus every column the filter inspects.
        /// </summary>
        /// <param name="columnsNeeded">Output columns the caller wants to read.</param>
        /// <param name="rand">Optional random source, forwarded to the source data view.</param>
        /// <returns>A cursor that yields only the rows accepted by the filter.</returns>
        protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Host.AssertValueOrNull(rand);

            var requested = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            Func<int, bool> neededFromSource = GetActive(requested, out bool[] active);

            var sourceCursor = Source.GetRowCursor(
                Source.Schema.Where(column => neededFromSource(column.Index)),
                rand);
            return new Cursor(this, sourceCursor, active);
        }
 
        /// <summary>
        /// Creates a set of filtering cursors, one per cursor returned by the source for the same request.
        /// The source cursors are opened with the requested columns plus every column the filter inspects.
        /// </summary>
        /// <param name="columnsNeeded">Output columns the caller wants to read.</param>
        /// <param name="n">Suggested number of cursors, forwarded to the source data view.</param>
        /// <param name="rand">Optional random source, forwarded to the source data view.</param>
        /// <returns>An array with one filtering cursor per source cursor, in the same order.</returns>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.CheckValueOrNull(rand);

            var requested = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            Func<int, bool> neededFromSource = GetActive(requested, out bool[] active);

            var sourceCursors = Source.GetRowCursorSet(
                Source.Schema.Where(column => neededFromSource(column.Index)),
                n,
                rand);
            Host.AssertNonEmpty(sourceCursors);

            // Wrap each source cursor one-to-one; the same 'active' array is shared by all wrappers.
            int count = sourceCursors.Length;
            var wrapped = new DataViewRowCursor[count];
            for (int k = 0; k < count; k++)
                wrapped[k] = new Cursor(this, sourceCursors[k], active);
            return wrapped;
        }
 
        /// <summary>
        /// Computes which columns are active for the caller and which must be requested from the source.
        /// </summary>
        /// <param name="predicate">Tells, per column index, whether the caller wants that column.</param>
        /// <param name="active">
        /// On return, one flag per source column holding the value of <paramref name="predicate"/> for it.
        /// </param>
        /// <returns>
        /// A predicate over source column indices that is true for every column the caller wants and for
        /// every column the filter inspects.
        /// </returns>
        private Func<int, bool> GetActive(Func<int, bool> predicate, out bool[] active)
        {
            Host.AssertValue(predicate);

            active = new bool[Source.Schema.Count];
            var neededFromSource = new bool[Source.Schema.Count];
            for (int i = 0; i < active.Length; i++)
            {
                bool wanted = predicate(i);
                active[i] = wanted;
                neededFromSource[i] = wanted;
            }

            // The filter must read its own columns even when the caller did not ask for them.
            foreach (var info in _infos)
                neededFromSource[info.Index] = true;

            return col => neededFromSource[col];
        }
 
        private sealed class Cursor : LinkedRowFilterCursorBase
        {
            private abstract class Value
            {
                // Method handles for CreateOne<T> and CreateVec<T>, handed to Utils.MarshalInvoke by Create
                // together with the actual raw type of the column. The <int> type argument written here is
                // presumably only a placeholder needed to name the generic method — TODO confirm.
                private static readonly FuncStaticMethodInfo1<Cursor, ColInfo, Value> _createOneMethodInfo
                    = new FuncStaticMethodInfo1<Cursor, ColInfo, Value>(CreateOne<int>);

                private static readonly FuncStaticMethodInfo1<Cursor, ColInfo, Value> _createVecMethodInfo
                    = new FuncStaticMethodInfo1<Cursor, ColInfo, Value>(CreateVec<int>);

                // The cursor that owns this value; the typed getters check its IsGood before returning data.
                protected readonly Cursor Cursor;

                /// <summary>
                /// Stores the owning cursor.
                /// </summary>
                protected Value(Cursor cursor)
                {
                    Contracts.AssertValue(cursor);
                    Cursor = cursor;
                }

                /// <summary>
                /// Re-reads the column value for the current row.
                /// In the typed implementation this returns true when the missing-value predicate does not fire.
                /// </summary>
                public abstract bool Refresh();

                /// <summary>
                /// Returns the getter delegate that exposes the value captured by the last <see cref="Refresh"/>.
                /// </summary>
                public abstract Delegate GetGetter();
 
                /// <summary>
                /// Creates the <see cref="Value"/> for one inspected column, choosing the vector or scalar
                /// implementation from the column type and instantiating it over the column's raw item type.
                /// </summary>
                /// <param name="cursor">The cursor that will own the value.</param>
                /// <param name="info">The inspected column.</param>
                public static Value Create(Cursor cursor, ColInfo info)
                {
                    Contracts.AssertValue(cursor);
                    Contracts.AssertValue(info);

                    // Vector columns are instantiated over their item type, scalar columns over the type itself.
                    if (info.Type is VectorDataViewType vecType)
                        return Utils.MarshalInvoke(_createVecMethodInfo, vecType.ItemType.RawType, cursor, info);

                    return Utils.MarshalInvoke(_createOneMethodInfo, info.Type.RawType, cursor, info);
                }
 
                /// <summary>
                /// Builds the value wrapper for a scalar column whose raw type is <typeparamref name="T"/>.
                /// Pairs the input cursor's getter for the column with the "is NA" predicate for its type.
                /// </summary>
                private static ValueOne<T> CreateOne<T>(Cursor cursor, ColInfo info)
                {
                    Contracts.AssertValue(cursor);
                    Contracts.AssertValue(info);
                    Contracts.Assert(!(info.Type is VectorDataViewType));
                    Contracts.Assert(info.Type.RawType == typeof(T));

                    ValueGetter<T> sourceGetter = cursor.Input.GetGetter<T>(cursor.Input.Schema[info.Index]);
                    InPredicate<T> isMissing = Data.Conversion.Conversions.DefaultInstance.GetIsNAPredicate<T>(info.Type);
                    return new ValueOne<T>(cursor, sourceGetter, isMissing);
                }
 
                /// <summary>
                /// Builds the value wrapper for a vector column whose item raw type is <typeparamref name="T"/>.
                /// Pairs the input cursor's getter for the column with the "has missing" predicate for its type.
                /// </summary>
                private static ValueVec<T> CreateVec<T>(Cursor cursor, ColInfo info)
                {
                    Contracts.AssertValue(cursor);
                    Contracts.AssertValue(info);
                    Contracts.Assert(info.Type is VectorDataViewType);
                    Contracts.Assert(info.Type.RawType == typeof(VBuffer<T>));

                    ValueGetter<VBuffer<T>> sourceGetter = cursor.Input.GetGetter<VBuffer<T>>(cursor.Input.Schema[info.Index]);
                    InPredicate<VBuffer<T>> hasMissing =
                        Data.Conversion.Conversions.DefaultInstance.GetHasMissingPredicate<T>((VectorDataViewType)info.Type);
                    return new ValueVec<T>(cursor, sourceGetter, hasMissing);
                }
 
                /// <summary>
                /// Typed base for the value wrappers: holds the source getter, the missing-value predicate,
                /// and the most recently fetched value.
                /// </summary>
                private abstract class TypedValue<T> : Value
                {
                    private readonly ValueGetter<T> _getSrc;
                    private readonly InPredicate<T> _hasBad;

                    // Value fetched by the last call to Refresh; the derived getters hand this out.
                    public T Src;

                    protected TypedValue(Cursor cursor, ValueGetter<T> getSrc, InPredicate<T> hasBad)
                        : base(cursor)
                    {
                        Contracts.AssertValue(getSrc);
                        Contracts.AssertValue(hasBad);

                        _getSrc = getSrc;
                        _hasBad = hasBad;
                    }

                    /// <summary>
                    /// Fetches the current source value into <see cref="Src"/>.
                    /// </summary>
                    /// <returns>True when the missing-value predicate does not fire for the fetched value.</returns>
                    public override bool Refresh()
                    {
                        _getSrc(ref Src);
                        bool containsMissing = _hasBad(in Src);
                        return !containsMissing;
                    }
                }
 
                /// <summary>
                /// Value wrapper for a scalar column: its getter copies out the value cached by Refresh.
                /// </summary>
                private sealed class ValueOne<T> : TypedValue<T>
                {
                    // Created once so that every GetGetter call returns the same delegate instance.
                    private readonly ValueGetter<T> _getter;

                    public ValueOne(Cursor cursor, ValueGetter<T> getSrc, InPredicate<T> hasBad)
                        : base(cursor, getSrc, hasBad)
                    {
                        _getter = GetValue;
                    }

                    /// <summary>
                    /// Assigns the cached value to <paramref name="dst"/>; checks that the cursor is in a good state.
                    /// </summary>
                    public void GetValue(ref T dst)
                    {
                        Contracts.Check(Cursor.IsGood);
                        dst = Src;
                    }

                    public override Delegate GetGetter() => _getter;
                }
 
                /// <summary>
                /// Value wrapper for a vector column: its getter copies the buffer cached by Refresh
                /// into the caller's buffer.
                /// </summary>
                private sealed class ValueVec<T> : TypedValue<VBuffer<T>>
                {
                    // Created once so that every GetGetter call returns the same delegate instance.
                    private readonly ValueGetter<VBuffer<T>> _getter;

                    public ValueVec(Cursor cursor, ValueGetter<VBuffer<T>> getSrc, InPredicate<VBuffer<T>> hasBad)
                        : base(cursor, getSrc, hasBad)
                    {
                        _getter = GetValue;
                    }

                    /// <summary>
                    /// Copies the cached buffer into <paramref name="dst"/>; checks that the cursor is in a good state.
                    /// </summary>
                    public void GetValue(ref VBuffer<T> dst)
                    {
                        Contracts.Check(Cursor.IsGood);
                        Src.CopyTo(ref dst);
                    }

                    public override Delegate GetGetter() => _getter;
                }
            }
 
            // The filter this cursor belongs to; supplies the inspected columns and the complement flag.
            private readonly NAFilter _parent;
            // One value wrapper per inspected column, parallel to _parent._infos.
            private readonly Value[] _values;

            /// <summary>
            /// Creates a filtering cursor over <paramref name="input"/> and builds a value wrapper for
            /// every column the parent filter inspects.
            /// </summary>
            /// <param name="parent">The owning filter.</param>
            /// <param name="input">The source cursor to filter.</param>
            /// <param name="active">Per-column activity flags passed to the base cursor.</param>
            public Cursor(NAFilter parent, DataViewRowCursor input, bool[] active)
                : base(parent.Host, input, parent.OutputSchema, active)
            {
                _parent = parent;

                var infos = _parent._infos;
                _values = new Value[infos.Length];
                for (int slot = 0; slot < infos.Length; slot++)
                    _values[slot] = Value.Create(this, infos[slot]);
            }
 
            /// <summary>
            /// Returns a value getter delegate to fetch the value of the given column from the row.
            /// For a column the filter inspects, the getter serves the value already fetched while filtering;
            /// for any other column the request is forwarded to the input cursor.
            /// This throws if the column is not active in this row, or if the type
            /// <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Ch.Check(IsColumnActive(column));

                return TryGetColumnValueGetter(column.Index, out ValueGetter<TValue> cached)
                    ? cached
                    : Input.GetGetter<TValue>(column);
            }
 
            /// <summary>
            /// Looks up the getter for a column the filter inspects.
            /// Returns false, with <paramref name="fn"/> set to null, when the column is not one of them.
            /// Throws when the column is inspected but <typeparamref name="TValue"/> does not match the
            /// type of its getter.
            /// </summary>
            /// <param name="col">Column index in the schema.</param>
            /// <param name="fn">On success, the typed getter for the column.</param>
            private bool TryGetColumnValueGetter<TValue>(int col, out ValueGetter<TValue> fn)
            {
                Ch.Assert(IsColumnActive(Schema[col]));

                if (_parent._srcIndexToInfoIndex.TryGetValue(col, out int slot))
                {
                    Delegate untyped = _values[slot].GetGetter();
                    fn = untyped as ValueGetter<TValue>;
                    if (fn != null)
                        return true;

                    throw Ch.Except($"Invalid TValue: '{typeof(TValue)}', " +
                            $"expected type: '{untyped.GetType().GetGenericArguments().First()}'.");
                }

                fn = null;
                return false;
            }
 
            /// <summary>
            /// Decides whether the current input row is kept.
            /// Values are refreshed in column order, stopping at the first one that reports a missing value.
            /// </summary>
            /// <returns>
            /// Without complement: true only when no inspected column has a missing value.
            /// With complement: true only when some inspected column has one.
            /// </returns>
            protected override bool Accept()
            {
                bool keepRowsWithMissing = _parent._complement;
                foreach (var value in _values)
                {
                    if (!value.Refresh())
                        return keepRowsWithMissing;
                }
                return !keepRowsWithMissing;
            }
        }
    }
}