// File: Transforms\SlotsDroppingTransformer.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), typeof(SlotsDroppingTransformer.Options), typeof(SignatureDataTransform),
    SlotsDroppingTransformer.FriendlyName, SlotsDroppingTransformer.LoaderSignature, "DropSlots")]
 
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), null, typeof(SignatureLoadDataTransform),
    SlotsDroppingTransformer.FriendlyName, SlotsDroppingTransformer.LoaderSignature)]
 
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(SlotsDroppingTransformer), null, typeof(SignatureLoadModel),
    SlotsDroppingTransformer.FriendlyName, SlotsDroppingTransformer.LoaderSignature)]
 
[assembly: LoadableClass(typeof(IRowMapper), typeof(SlotsDroppingTransformer), null, typeof(SignatureLoadRowMapper),
   SlotsDroppingTransformer.FriendlyName, SlotsDroppingTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// Transform to drop slots from columns. If the column is scalar, the only slot that can be dropped is slot 0.
    /// If all the slots are to be dropped, a vector valued column will be changed to a vector of length 1 (a scalar column will retain its type) and
    /// the value will be the default value.
    /// </summary>
    [BestFriend]
    internal sealed class SlotsDroppingTransformer : OneToOneTransformerBase
    {
        /// <summary>
        /// Arguments consumed by the <c>SignatureDataTransform</c> factory: the columns whose slots
        /// are to be dropped, each with its own set of slot ranges.
        /// </summary>
        internal sealed class Options
        {
            // At least one column is required; the argument may be repeated (one entry per column).
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to drop the slots for",
                Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;
        }
 
        /// <summary>
        /// One column specification: the output/source names handled by the <see cref="OneToOneColumn"/>
        /// base class plus the slot ranges to drop from the source column.
        /// The short textual form is <c>name:src:slots</c> or <c>src:slots</c>, where <c>slots</c> is a
        /// comma separated list of ranges understood by <see cref="Range.Parse"/>.
        /// </summary>
        [BestFriend]
        internal sealed class Column : OneToOneColumn
        {
            [Argument(ArgumentType.Multiple, HelpText = "Source slot index range(s) of the column to drop")]
            public Range[] Slots;

            /// <summary>
            /// Parses the short textual form of a column specification.
            /// </summary>
            /// <param name="str">The text to parse; must not be null or white space.</param>
            /// <returns>The parsed column, or <see langword="null"/> if <paramref name="str"/> is not well formed.</returns>
            internal static Column Parse(string str)
            {
                Contracts.CheckNonWhiteSpace(str, nameof(str));

                var res = new Column();
                if (res.TryParse(str))
                    return res;
                return null;
            }

            /// <summary>
            /// Splits <paramref name="str"/> at its last ':' into the name part (parsed by the base class)
            /// and the slots part (parsed by <see cref="TryParseSlots"/>).
            /// </summary>
            private protected override bool TryParse(string str)
            {
                Contracts.AssertNonEmpty(str);

                // Allow name:src:slots and src:slots
                int ich = str.LastIndexOf(':');
                // The ':' must have text on both sides.
                if (ich <= 0 || ich >= str.Length - 1)
                    return false;
                if (!base.TryParse(str.Substring(0, ich)))
                    return false;
                return TryParseSlots(str.Substring(ich + 1));
            }

            /// <summary>
            /// Parses a comma separated list of ranges into <see cref="Slots"/>.
            /// </summary>
            /// <returns><see langword="false"/> if the list is empty or any entry is not a valid range.</returns>
            private bool TryParseSlots(string str)
            {
                Contracts.AssertValue(str);
                if (str.Length == 0)
                    return false;

                var strs = str.Split(',');
                Slots = new Range[strs.Length];
                for (int i = 0; i < strs.Length; i++)
                {
                    // An empty entry (e.g. "1,,3" or a trailing comma) is malformed input. Reject it here:
                    // Range.Parse validates its argument with CheckNonWhiteSpace, which would presumably
                    // throw rather than let this Try-method report failure.
                    if (string.IsNullOrWhiteSpace(strs[i]))
                        return false;
                    if ((Slots[i] = Range.Parse(strs[i])) == null)
                        return false;
                }
                return true;
            }

            /// <summary>
            /// Appends the short textual form of this column to <paramref name="sb"/>.
            /// </summary>
            /// <returns>
            /// <see langword="false"/> if the column cannot be written in short form. When a range fails to
            /// unparse, <paramref name="sb"/> is restored to the length it had on entry.
            /// </returns>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.CheckValue(sb, nameof(sb));

                // Without slots the short form would end in ':' which TryParse rejects, so it cannot be
                // round-tripped (and a null array could not be enumerated at all).
                if (Slots == null || Slots.Length == 0)
                    return false;

                int ich = sb.Length;

                if (!TryUnparseCore(sb))
                    return false;

                sb.Append(':');
                string pre = "";
                foreach (var src in Slots)
                {
                    sb.Append(pre);
                    if (!src.TryUnparse(sb))
                    {
                        // Roll back everything appended by this call.
                        sb.Length = ich;
                        return false;
                    }
                    pre = ",";
                }
                return true;
            }
        }
 
        /// <summary>
        /// A range of slot indices, written as <c>min</c>, <c>min-max</c> or <c>min-*</c>.
        /// A single index <c>min</c> is stored with <see cref="Max"/> equal to <see cref="Min"/>.
        /// </summary>
        internal sealed class Range
        {
            [Argument(ArgumentType.Required, HelpText = "First index in the range")]
            public int Min;

            // If null, it means int.MaxValue - 1. There are two reasons for this:
            // 1. max is an index, so it has to be strictly less than int.MaxValue.
            // 2. to prevent overflows when adding 1 to it.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Last index in the range")]
            public int? Max;

            /// <summary>
            /// Parses the textual form of a range.
            /// </summary>
            /// <returns>The parsed range, or <see langword="null"/> if <paramref name="str"/> is not well formed.</returns>
            internal static Range Parse(string str)
            {
                Contracts.CheckNonWhiteSpace(str, nameof(str));

                var range = new Range();
                return range.TryParse(str) ? range : null;
            }

            /// <summary>
            /// Fills <see cref="Min"/> and <see cref="Max"/> from <paramref name="str"/>.
            /// <see cref="Max"/> is left null for the open-ended form <c>min-*</c>.
            /// </summary>
            private bool TryParse(string str)
            {
                Contracts.AssertNonEmpty(str);

                int dash = str.IndexOf('-');
                if (dash < 0)
                {
                    // No separator: a single index, so the range is [Min, Min].
                    bool parsedSingle = int.TryParse(str, out Min);
                    if (parsedSingle)
                        Max = Min;
                    return parsedSingle;
                }

                // The separator needs text on both sides.
                bool dashAtEdge = dash == 0 || dash >= str.Length - 1;
                if (dashAtEdge)
                    return false;

                if (!int.TryParse(str.Substring(0, dash), out Min))
                    return false;

                string upper = str.Substring(dash + 1);
                if (upper == "*")
                    return true;

                if (!int.TryParse(upper, out int parsedMax))
                    return false;
                Max = parsedMax;
                return true;
            }

            /// <summary>
            /// Appends the textual form of this range to <paramref name="sb"/>. Always succeeds.
            /// </summary>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.CheckValue(sb, nameof(sb));
                sb.Append(Min);
                if (Max == null)
                    sb.Append("-*");
                else if (Max != Min)
                    sb.Append("-").Append(Max);
                return true;
            }

            /// <summary>
            /// Returns true if the range is valid.
            /// </summary>
            public bool IsValid()
            {
                if (Min < 0)
                    return false;
                return Max == null || Min <= Max;
            }
        }
 
        /// <summary>
        /// Describes how the transformer handles one input-output column pair.
        /// </summary>
        [BestFriend]
        internal sealed class ColumnOptions
        {
            public readonly string Name;
            public readonly string InputColumnName;
            public readonly (int min, int? max)[] Slots;

            /// <summary>
            /// Describes how the transformer handles one input-output column pair.
            /// </summary>
            /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
            /// <param name="inputColumnName">Name of the column to transform.
            /// If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param>
            /// <param name="slots">Ranges of indices in the input column to be dropped. A null max is later
            /// replaced by int.MaxValue - 1, i.e. the range extends to the end of the column.
            /// If no range is given (null or empty), everything is dropped.</param>
            public ColumnOptions(string name, string inputColumnName = null, params (int min, int? max)[] slots)
            {
                Contracts.CheckValue(name, nameof(name));
                Name = name;
                // Cannot be null once name has been checked.
                InputColumnName = inputColumnName ?? name;

                // By default drop everything: the single default range is (0, null).
                // An explicit null array is treated the same as an empty one.
                Slots = (slots != null && slots.Length > 0) ? slots : new (int min, int? max)[1];
                foreach (var (min, max) in Slots)
                    Contracts.Assert(min >= 0 && (max == null || min <= max));
            }

            /// <summary>
            /// Builds the options from an argument-style <see cref="Column"/> specification.
            /// </summary>
            /// <param name="column">The column specification; its source defaults to its name.</param>
            internal ColumnOptions(Column column)
            {
                Contracts.CheckValue(column, nameof(column));
                Name = column.Name;
                Contracts.CheckValue(Name, nameof(Name));
                InputColumnName = column.Source ?? column.Name;
                Contracts.CheckValue(InputColumnName, nameof(InputColumnName));
                Slots = column.Slots.Select(range => (range.Min, range.Max)).ToArray();
                foreach (var (min, max) in Slots)
                    Contracts.Assert(min >= 0 && (max == null || min <= max));
            }
        }
 
        // Name used when registering the host for this transformer.
        private const string RegistrationName = "DropSlots";
        internal const string Summary = "Removes the selected slots from the column.";
        internal const string FriendlyName = "Drop Slots Transform";
        // Also stored in the model's version info, see GetVersionInfo.
        internal const string LoaderSignature = "DropSlotsTransform";

        // Store the lower (SlotsMin) and upper (SlotsMax) bounds of ranges of slots to be dropped for each column pair.
        // SlotsMin[i] and SlotsMax[i] are the bounds of the ranges for the i-th column pair.
        // Both bounds are inclusive. AreRangesValid requires that, within one column pair, the ranges are
        // sorted in increasing order and are neither overlapping nor adjacent.
        internal readonly int[][] SlotsMin;
        internal readonly int[][] SlotsMax;
 
        /// <summary>
        /// Version information written to, and expected from, a saved model of this transformer.
        /// </summary>
        private static VersionInfo GetVersionInfo()
            => new VersionInfo(
                modelSignature: "DROPSLOT",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(SlotsDroppingTransformer).Assembly.FullName);
 
        /// <summary>
        /// Initializes a new <see cref="SlotsDroppingTransformer"/> object that drops a single range of slots
        /// from one column.
        /// </summary>
        /// <param name="env">The environment to use.</param>
        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
        /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
        /// <param name="min">Specifies the lower bound of the range of slots to be dropped. The lower bound is inclusive. </param>
        /// <param name="max">Specifies the upper bound of the range of slots to be dropped. The upper bound is inclusive.
        /// If <see langword="null"/>, int.MaxValue - 1 is used, so every slot from <paramref name="min"/> onwards is dropped.</param>
        internal SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int min = default, int? max = null)
            : this(env, new ColumnOptions(outputColumnName, inputColumnName, (min, max)))
        {
        }
 
        /// <summary>
        /// Initializes a new <see cref="SlotsDroppingTransformer"/> object from per-column options.
        /// The ranges of each column are sorted and merged, then validated.
        /// </summary>
        /// <param name="env">The environment to use.</param>
        /// <param name="columns">Specifies the ranges of slots to drop for each column pair.</param>
        internal SlotsDroppingTransformer(IHostEnvironment env, params ColumnOptions[] columns)
            : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
        {
            const string invalidRangeMessage = "The range min and max must be non-negative and min must be less than or equal to max.";

            Host.AssertNonEmpty(ColumnPairs);
            GetSlotsMinMax(columns, out SlotsMin, out SlotsMax);

            bool rangesAreValid = AreRangesValid(SlotsMin, SlotsMax);
            Host.CheckUserArg(rangesAreValid, nameof(columns), invalidRangeMessage);
        }
 
        /// <summary>
        /// Deserialization constructor: reads the column pairs (via the base class) and then the slot
        /// ranges of each column pair from <paramref name="ctx"/>.
        /// </summary>
        /// <param name="env">The environment to use.</param>
        /// <param name="ctx">The model load context, positioned at this transformer's data.</param>
        private SlotsDroppingTransformer(IHostEnvironment env, ModelLoadContext ctx)
            : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), ctx)
        {
            Host.AssertValue(ctx);
            // *** Binary format ***
            // <base>
            // for each added column
            //   int[]: slotsMin
            //   int[]: slotsMax (no count)

            Host.AssertNonEmpty(ColumnPairs);
            var size = ColumnPairs.Length;
            SlotsMin = new int[size][];
            SlotsMax = new int[size][];
            for (int i = 0; i < size; i++)
            {
                SlotsMin[i] = ctx.Reader.ReadIntArray();
                Host.CheckDecode(Utils.Size(SlotsMin[i]) > 0);
                // slotsMax is stored without its own count; it has as many entries as slotsMin.
                SlotsMax[i] = ctx.Reader.ReadIntArray(SlotsMin[i].Length);
            }
            // The ranges come from the model file, so validate them as decoded data (like the check
            // above) rather than with an assertion that is meant for internal invariants.
            Host.CheckDecode(AreRangesValid(SlotsMin, SlotsMax));
        }
 
        // Factory method for SignatureLoadModel.
        /// <summary>
        /// Loads a <see cref="SlotsDroppingTransformer"/> from a saved model.
        /// </summary>
        /// <param name="env">The environment to use; must not be null.</param>
        /// <param name="ctx">The model load context; must not be null and must match this loader's version info.</param>
        internal static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            // Validate ctx before dereferencing it so that a null argument is reported by name.
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            return new SlotsDroppingTransformer(env, ctx);
        }
 
        // Factory method for SignatureDataTransform.
        // Converts each argument-style column into ColumnOptions, builds the transformer and
        // applies it to the input data.
        private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            ColumnOptions[] columnOptions = options.Columns
                .Select(spec => new ColumnOptions(spec))
                .ToArray();
            var transformer = new SlotsDroppingTransformer(env, columnOptions);
            return transformer.MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadDataTransform.
        // Loads the transformer from the model and applies it to the input data.
        private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            SlotsDroppingTransformer loaded = Create(env, ctx);
            return loaded.MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadRowMapper.
        // Loads the transformer from the model and creates its row mapper for the given schema.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
        {
            SlotsDroppingTransformer loaded = Create(env, ctx);
            return loaded.MakeRowMapper(inputSchema);
        }
 
        /// <summary>
        /// Writes the version info, the column pairs and the slot ranges of each column pair to <paramref name="ctx"/>.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // <base>
            // for each added column
            //   int[]: slotsMin
            //   int[]: slotsMax (no count)

            SaveColumns(ctx);

            Host.Assert(AreRangesValid(SlotsMin, SlotsMax));
            int columnCount = ColumnPairs.Length;
            for (int col = 0; col < columnCount; col++)
            {
                int[] mins = SlotsMin[col];
                int[] maxs = SlotsMax[col];
                Host.Assert(mins.Length == maxs.Length);
                // The count is written once, with the lower bounds; the upper bounds share it.
                ctx.Writer.WriteIntArray(mins);
                ctx.Writer.WriteIntsNoCount(maxs);
            }
        }
 
        /// <summary>
        /// Converts the ranges of <paramref name="col"/> into two parallel arrays of inclusive bounds,
        /// sorted by lower bound, with overlapping or adjacent ranges merged into one.
        /// </summary>
        /// <param name="col">The column whose <see cref="Column.Slots"/> are converted. Every range must be valid.</param>
        /// <param name="slotsMin">The lower bounds of the merged ranges.</param>
        /// <param name="slotsMax">The upper bounds of the merged ranges; same length as <paramref name="slotsMin"/>.</param>
        private void GetSlotsMinMax(Column col, out int[] slotsMin, out int[] slotsMax)
        {
            slotsMin = new int[col.Slots.Length];
            slotsMax = new int[col.Slots.Length];
            for (int j = 0; j < col.Slots.Length; j++)
            {
                var range = col.Slots[j];
                Host.CheckUserArg(range.IsValid(), nameof(col.Slots), "The range min and max must be non-negative and min must be less than or equal to max.");
                slotsMin[j] = range.Min;
                // There are two reasons for setting the max to int.MaxValue - 1:
                // 1. max is an index, so it has to be strictly less than int.MaxValue.
                // 2. to prevent overflows when adding 1 to it.
                slotsMax[j] = range.Max ?? int.MaxValue - 1;
            }
            // Sort by lower bound, keeping each upper bound paired with its lower bound.
            Array.Sort(slotsMin, slotsMax);
            // Merge in place: iDst is the index of the last merged range written so far.
            var iDst = 0;
            for (int j = 1; j < col.Slots.Length; j++)
            {
                if (slotsMin[j] <= slotsMax[iDst] + 1)
                    slotsMax[iDst] = Math.Max(slotsMax[iDst], slotsMax[j]);
                else
                {
                    iDst++;
                    slotsMin[iDst] = slotsMin[j];
                    slotsMax[iDst] = slotsMax[j];
                }
            }
            // Drop the stale entries left behind the merged ranges, as the ColumnOptions[] overload does.
            // Without this the arrays would still contain the unmerged tail.
            iDst++;
            Array.Resize(ref slotsMin, iDst);
            Array.Resize(ref slotsMax, iDst);
        }
 
        /// <summary>
        /// For every column, converts its slot ranges into two parallel arrays of inclusive bounds,
        /// sorted by lower bound, with overlapping or adjacent ranges merged into one.
        /// </summary>
        /// <param name="columns">The per-column options holding the ranges.</param>
        /// <param name="slotsMin">For each column, the lower bounds of its merged ranges.</param>
        /// <param name="slotsMax">For each column, the upper bounds of its merged ranges.</param>
        private static void GetSlotsMinMax(ColumnOptions[] columns, out int[][] slotsMin, out int[][] slotsMax)
        {
            slotsMin = new int[columns.Length][];
            slotsMax = new int[columns.Length][];
            for (int col = 0; col < columns.Length; col++)
            {
                var ranges = columns[col].Slots;
                var mins = new int[ranges.Length];
                var maxs = new int[ranges.Length];
                for (int k = 0; k < ranges.Length; k++)
                {
                    mins[k] = ranges[k].min;
                    // There are two reasons for setting the max to int.MaxValue - 1:
                    // 1. max is an index, so it has to be strictly less than int.MaxValue.
                    // 2. to prevent overflows when adding 1 to it.
                    maxs[k] = ranges[k].max ?? int.MaxValue - 1;
                }

                // Order by lower bound; the upper bounds follow their lower bounds.
                Array.Sort(mins, maxs);

                // Merge in place. 'last' is the index of the most recently written merged range.
                int last = 0;
                for (int k = 1; k < ranges.Length; k++)
                {
                    bool touchesPrevious = mins[k] <= maxs[last] + 1;
                    if (touchesPrevious)
                    {
                        maxs[last] = Math.Max(maxs[last], maxs[k]);
                        continue;
                    }

                    last++;
                    mins[last] = mins[k];
                    maxs[last] = maxs[k];
                }

                // Keep only the merged ranges.
                int mergedCount = last + 1;
                Array.Resize(ref mins, mergedCount);
                Array.Resize(ref maxs, mergedCount);
                slotsMin[col] = mins;
                slotsMax[col] = maxs;
            }
        }
 
        /// <summary>
        /// Projects the column options onto (output name, input name) pairs; the input name falls back
        /// to the output name when it is null.
        /// </summary>
        private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnOptions[] columns)
        {
            return columns
                .Select(options => (options.Name, options.InputColumnName ?? options.Name))
                .ToArray();
        }
 
        /// <summary>
        /// Checks that, for every column, the ranges are well formed: all bounds lie in [0, int.MaxValue),
        /// each lower bound is not greater than its upper bound, and consecutive ranges are in increasing
        /// order without overlapping or touching each other.
        /// </summary>
        /// <param name="slotsMin">For each column, the lower bounds of its ranges.</param>
        /// <param name="slotsMax">For each column, the upper bounds of its ranges.</param>
        /// <returns><see langword="true"/> if all ranges of all columns are valid.</returns>
        private static bool AreRangesValid(int[][] slotsMin, int[][] slotsMax)
        {
            if (slotsMin.Length != slotsMax.Length)
                return false;

            for (int col = 0; col < slotsMin.Length; col++)
            {
                int[] mins = slotsMin[col];
                // -2 lets a first range starting at 0 pass the "gap of at least one slot" test below.
                int previousMax = -2;
                for (int k = 0; k < mins.Length; k++)
                {
                    int lower = mins[k];
                    if (lower < 0 || lower == int.MaxValue)
                        return false;

                    int upper = slotsMax[col][k];
                    if (upper < 0 || upper == int.MaxValue)
                        return false;

                    if (lower > upper)
                        return false;

                    // Ranges must be strictly separated: lower - 1 has to lie beyond the previous upper bound.
                    if (lower - 1 <= previousMax)
                        return false;

                    previousMax = upper;
                }
            }
            return true;
        }
 
        /// <summary>
        /// Creates the row mapper that applies this transformer to rows of the given schema.
        /// </summary>
        private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
        {
            return new Mapper(this, schema);
        }
 
        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
        {
            // Cached method infos for the generic getter factories of this class. Each is created from the
            // <int> instantiation of the method.
            // NOTE(review): presumably these are re-instantiated for the actual item type at runtime; the
            // code that uses them is not shown here - confirm.
            private static readonly FuncInstanceMethodInfo1<Mapper, Delegate> _makeOneTrivialGetterMethodInfo
                = FuncInstanceMethodInfo1<Mapper, Delegate>.Create(target => target.MakeOneTrivialGetter<int>);

            private static readonly FuncInstanceMethodInfo1<Mapper, Delegate> _makeVecTrivialGetterMethodInfo
                = FuncInstanceMethodInfo1<Mapper, Delegate>.Create(target => target.MakeVecTrivialGetter<int>);

            private static readonly FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delegate> _makeVecGetterMethodInfo
                = FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delegate>.Create(target => target.MakeVecGetter<int>);

            private static readonly FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delegate> _getSrcGetterMethodInfo
                = FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delegate>.Create(target => target.GetSrcGetter<int>);

            // The transformer this mapper was created from.
            private readonly SlotsDroppingTransformer _parent;
            // The arrays below all have one entry per column pair of the parent.
            // Index of the source column in the input schema.
            private readonly int[] _cols;
            // Type of the source column.
            private readonly DataViewType[] _srcTypes;
            // Type of the output column, as computed by ComputeType.
            private readonly DataViewType[] _dstTypes;
            // Slot dropper built from the parent's ranges and the source column's value count.
            private readonly SlotDropper[] _slotDropper;
            // Track if all the slots of the column are to be dropped.
            private readonly bool[] _suppressed;
            // Categorical ranges as produced by ComputeType.
            private readonly int[][] _categoricalRanges;
 
            /// <summary>
            /// Binds the parent's column pairs to <paramref name="inputSchema"/>: resolves each source
            /// column, checks that its item type is supported, builds its slot dropper and computes the
            /// output type.
            /// </summary>
            /// <param name="parent">The transformer that owns this mapper.</param>
            /// <param name="inputSchema">The schema of the rows to be mapped.</param>
            public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema)
                : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
            {
                _parent = parent;

                int pairCount = _parent.ColumnPairs.Length;
                _cols = new int[pairCount];
                _srcTypes = new DataViewType[pairCount];
                _dstTypes = new DataViewType[pairCount];
                _slotDropper = new SlotDropper[pairCount];
                _suppressed = new bool[pairCount];
                _categoricalRanges = new int[pairCount][];

                for (int i = 0; i < pairCount; i++)
                {
                    string sourceName = _parent.ColumnPairs[i].inputColumnName;

                    bool found = InputSchema.TryGetColumnIndex(sourceName, out _cols[i]);
                    if (!found)
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);

                    _srcTypes[i] = inputSchema[_cols[i]].Type;
                    VectorDataViewType sourceVector = _srcTypes[i] as VectorDataViewType;

                    // For vectors the item type is validated, for scalars the column type itself.
                    var itemType = sourceVector?.ItemType ?? _srcTypes[i];
                    if (!IsValidColumnType(itemType))
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", sourceName);

                    // A scalar column is treated as having a single slot.
                    int slotCount = sourceVector?.Size ?? 1;
                    _slotDropper[i] = new SlotDropper(slotCount, _parent.SlotsMin[i], _parent.SlotsMax[i]);
                    ComputeType(inputSchema, i, _slotDropper[i], out _suppressed[i], out _dstTypes[i], out _categoricalRanges[i]);
                }
            }
 
            /// <summary>
            /// Tells whether <paramref name="type"/> is an acceptable item type. Both scalars and vectors are
            /// acceptable columns, but the item type must have a default value, which means it must be
            /// a key (with a count in (0, Utils.ArrayMaxSize)), a float, a double or a string.
            /// </summary>
            private static bool IsValidColumnType(DataViewType type)
            {
                if (type is KeyDataViewType keyType && 0 < keyType.Count && keyType.Count < Utils.ArrayMaxSize)
                    return true;
                if (type == NumberDataViewType.Single)
                    return true;
                if (type == NumberDataViewType.Double)
                    return true;
                return type is TextDataViewType;
            }
 
            /// <summary>
            /// Computes the output column type and whether the column is suppressed (all slots dropped).
            /// The slotsMin and slotsMax arrays should be sorted and the intervals should not overlap.
            /// <list type="bullet">
            /// <item>Scalar source: the type is unchanged; the column is suppressed if the first range starts at slot 0.</item>
            /// <item>Vector of unknown size: the type is unchanged and the column is never suppressed.</item>
            /// <item>Vector of known size: the output is a vector of the same item type whose size is the number
            /// of remaining slots, or 1 if no slot remains (in which case the column is suppressed).</item>
            /// </list>
            /// </summary>
            /// <param name="input">The input schema</param>
            /// <param name="iinfo">The index of the column pair in the parent's ColumnPairs</param>
            /// <param name="slotDropper">The slots to be dropped.</param>
            /// <param name="suppressed">Whether the column is suppressed (all slots dropped)</param>
            /// <param name="type">The column type</param>
            /// <param name="categoricalRanges">Categorical feature indices. Always set to null by this method.</param>
            private void ComputeType(DataViewSchema input, int iinfo, SlotDropper slotDropper,
                out bool suppressed, out DataViewType type, out int[] categoricalRanges)
            {
                var slotsMin = _parent.SlotsMin[iinfo];
                var slotsMax = _parent.SlotsMax[iinfo];
                Host.AssertValue(slotDropper);
                Host.AssertValue(input);
                Host.AssertNonEmpty(slotsMin);
                Host.AssertNonEmpty(slotsMax);
                Host.Assert(slotsMin.Length == slotsMax.Length);
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);

                categoricalRanges = null;
                var typeSrc = _srcTypes[iinfo];
                if (!(typeSrc is VectorDataViewType vectorType))
                {
                    // A scalar has only slot 0; the ranges are sorted, so it is enough to look at the first one.
                    type = typeSrc;
                    suppressed = slotsMin.Length > 0 && slotsMin[0] == 0;
                }
                else if (!vectorType.IsKnownSize)
                {
                    // The number of remaining slots cannot be determined up front.
                    type = typeSrc;
                    suppressed = false;
                }
                else
                {
                    Host.Assert(vectorType.IsKnownSize);
                    var dstLength = slotDropper.DstLength;
                    // When every slot is dropped the output is a vector of length 1.
                    type = new VectorDataViewType(vectorType.ItemType, Math.Max(dstLength, 1));
                    suppressed = dstLength == 0;
                }
            }
 
            /// <summary>
            /// Reads the slot names of the source column and writes to <paramref name="dst"/> the names
            /// that remain after applying the column's slot dropper.
            /// </summary>
            private void GetSlotNames(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
            {
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);

                VBuffer<ReadOnlyMemory<char>> sourceNames = default;
                var sourceColumn = InputSchema[_cols[iinfo]];
                sourceColumn.GetSlotNames(ref sourceNames);
                _slotDropper[iinfo].DropSlots(ref sourceNames, ref dst);
            }
 
            /// <summary>
            /// Computes the categorical slot ranges of the output column, if the column has any
            /// categorical ranges recorded; otherwise <paramref name="dst"/> is left untouched.
            /// </summary>
            private void GetCategoricalSlotRanges(int iinfo, ref VBuffer<int> dst)
            {
                int[] sourceRanges = _categoricalRanges[iinfo];
                if (sourceRanges == null)
                    return;

                GetCategoricalSlotRangesCore(iinfo, _parent.SlotsMin[iinfo],
                    _parent.SlotsMax[iinfo], sourceRanges, ref dst);
            }
 
            /// <summary>
            /// Recomputes the categorical slot ranges of a column after the slots in the given drop ranges
            /// have been removed: ranges that are fully dropped disappear, partially dropped ranges shrink,
            /// and all surviving indices are shifted down by the number of slots dropped before them.
            /// </summary>
            /// <param name="iinfo">The index of the column pair; only used for an assertion.</param>
            /// <param name="slotsMin">Lower bounds (inclusive) of the drop ranges.</param>
            /// <param name="slotsMax">Upper bounds (inclusive) of the drop ranges; same length as <paramref name="slotsMin"/>.</param>
            /// <param name="catRanges">The categorical ranges of the source column as a flat, non-empty array of
            /// (first, last) index pairs. The array is not modified.</param>
            /// <param name="dst">Receives the new flat array of (first, last) pairs. It is only assigned when at
            /// least one range survives; otherwise it is left as passed in.</param>
            private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slotsMax, int[] catRanges, ref VBuffer<int> dst)
            {
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
                Host.Assert(slotsMax != null && slotsMin != null);
                Host.Assert(slotsMax.Length == slotsMin.Length);

                Contracts.Assert(catRanges.Length > 0 && catRanges.Length % 2 == 0);

                // Work on a copy: the loop below trims the current range in place.
                var ranges = new int[catRanges.Length];
                catRanges.CopyTo(ranges, 0);
                // Index of the current categorical range (always even: points at the 'first' of a pair).
                int rangesIndex = 0;
                // Index of the current drop range.
                int dropSlotsIndex = 0;
                // Used to detect that the loop moved past a drop range, see the end of the loop body.
                int previousDropSlotsIndex = 0;
                // Total size of the drop ranges that have been passed so far, i.e. the shift to apply.
                int droppedSlotsCount = 0;
                // True while [min, max] holds the already shifted left part of a range that was split by a
                // drop range lying strictly inside it; the remainder still has to be appended to it.
                bool combine = false;
                int min = -1;
                int max = -1;
                List<int> newCategoricalSlotRanges = new List<int>();

                // Six possible ways a drop slot range interacts with categorical slots range.
                //
                //                    +--------------Drop-------------+
                //                    |                               |
                //
                //                +---Drop---+   +---Drop---+   +---Drop---+
                //  +---Drop---+  |          |   |          |   |          |  +---Drop---+
                //  |          |       |____________Range____________|        |          |
                //
                // The below code is better understood as a state machine.

                while (dropSlotsIndex < slotsMin.Length && rangesIndex < ranges.Length)
                {
                    Contracts.Assert(rangesIndex % 2 == 0);
                    Contracts.Assert(ranges[rangesIndex] <= ranges[rangesIndex + 1]);

                    // 1. The drop range ends before the categorical range starts: move to the next drop range.
                    if (slotsMax[dropSlotsIndex] < ranges[rangesIndex])
                        dropSlotsIndex++;
                    // 2. The drop range starts after the categorical range ends: the (remaining) range
                    //    survives as is; emit it shifted, merged with the pending left part if there is one.
                    else if (slotsMin[dropSlotsIndex] > ranges[rangesIndex + 1])
                    {
                        if (combine)
                        {
                            CombineRanges(min, max, ranges[rangesIndex] - droppedSlotsCount,
                                ranges[rangesIndex + 1] - droppedSlotsCount, out min, out max);
                        }
                        else
                        {
                            Contracts.Assert(min == -1 && max == -1);
                            min = ranges[rangesIndex] - droppedSlotsCount;
                            max = ranges[rangesIndex + 1] - droppedSlotsCount;
                        }

                        newCategoricalSlotRanges.Add(min);
                        newCategoricalSlotRanges.Add(max);
                        min = max = -1;
                        rangesIndex += 2;
                        combine = false;
                    }
                    // 3. The drop range covers the whole (remaining) categorical range: nothing more of it
                    //    survives; only a pending left part, if any, is emitted.
                    else if (slotsMin[dropSlotsIndex] <= ranges[rangesIndex] &&
                             slotsMax[dropSlotsIndex] >= ranges[rangesIndex + 1])
                    {
                        rangesIndex += 2;
                        if (combine)
                        {
                            Contracts.Assert(min >= 0 && min <= max);
                            newCategoricalSlotRanges.Add(min);
                            newCategoricalSlotRanges.Add(max);
                            min = max = -1;
                            combine = false;
                        }

                        Contracts.Assert(min == -1 && max == -1);

                    }
                    // 4. The drop range lies strictly inside the categorical range: remember the part to the
                    //    left of the drop range as pending, continue with the part to its right.
                    else if (slotsMin[dropSlotsIndex] > ranges[rangesIndex] &&
                             slotsMax[dropSlotsIndex] < ranges[rangesIndex + 1])
                    {
                        if (combine)
                        {
                            CombineRanges(min, max, ranges[rangesIndex] - droppedSlotsCount,
                                slotsMin[dropSlotsIndex] - 1 - droppedSlotsCount, out min, out max);
                        }
                        else
                        {
                            Contracts.Assert(min == -1 && max == -1);

                            min = ranges[rangesIndex] - droppedSlotsCount;
                            max = slotsMin[dropSlotsIndex] - 1 - droppedSlotsCount;
                            combine = true;
                        }

                        ranges[rangesIndex] = slotsMax[dropSlotsIndex] + 1;
                        dropSlotsIndex++;
                    }
                    // 5. The drop range covers the beginning of the categorical range: trim the start.
                    else if (slotsMax[dropSlotsIndex] < ranges[rangesIndex + 1])
                    {
                        ranges[rangesIndex] = slotsMax[dropSlotsIndex] + 1;
                        dropSlotsIndex++;
                    }
                    // 6. The drop range covers the end of the categorical range: trim the end. The drop range
                    //    is not advanced, so the next iteration handles the trimmed range in branch 2.
                    else
                        ranges[rangesIndex + 1] = slotsMin[dropSlotsIndex] - 1;

                    // Whenever a drop range has been passed, add its size to the shift applied to later indices.
                    if (previousDropSlotsIndex < dropSlotsIndex)
                    {
                        Contracts.Assert(dropSlotsIndex - previousDropSlotsIndex == 1);
                        droppedSlotsCount += slotsMax[previousDropSlotsIndex] - slotsMin[previousDropSlotsIndex] + 1;
                        previousDropSlotsIndex = dropSlotsIndex;
                    }
                }

                Contracts.Assert(rangesIndex % 2 == 0);

                // The drop ranges ran out while a split range was pending: append its remainder.
                if (combine)
                {
                    Contracts.Assert(rangesIndex < ranges.Length - 1);
                    CombineRanges(min, max, ranges[rangesIndex] - droppedSlotsCount,
                        ranges[rangesIndex + 1] - droppedSlotsCount, out min, out max);

                    newCategoricalSlotRanges.Add(min);
                    newCategoricalSlotRanges.Add(max);
                    rangesIndex += 2;
                    combine = false;
                    min = max = -1;
                }

                Contracts.Assert(min == -1 && max == -1);

                // The categorical ranges after the last drop range are unaffected except for the shift.
                for (int i = rangesIndex; i < ranges.Length; i++)
                    newCategoricalSlotRanges.Add(ranges[i] - droppedSlotsCount);

                Contracts.Assert(newCategoricalSlotRanges.Count % 2 == 0);
                Contracts.Assert(newCategoricalSlotRanges.TrueForAll(x => x >= 0));
                Contracts.Assert(0 <= droppedSlotsCount && droppedSlotsCount <= slotsMax[slotsMax.Length - 1] + 1);

                if (newCategoricalSlotRanges.Count > 0)
                    dst = new VBuffer<int>(newCategoricalSlotRanges.Count, newCategoricalSlotRanges.ToArray());
            }
 
            /// <summary>
            /// Merges two adjacent inclusive ranges into one range that starts where the first
            /// range starts and ends where the second range ends.
            /// </summary>
            /// <param name="minRange1">Inclusive start of the first range.</param>
            /// <param name="maxRange1">Inclusive end of the first range.</param>
            /// <param name="minRange2">Inclusive start of the second range; asserted to equal <paramref name="maxRange1"/> + 1.</param>
            /// <param name="maxRange2">Inclusive end of the second range.</param>
            /// <param name="newRangeMin">Receives <paramref name="minRange1"/>.</param>
            /// <param name="newRangeMax">Receives <paramref name="maxRange2"/>.</param>
            private void CombineRanges(
                int minRange1, int maxRange1, int minRange2, int maxRange2,
                out int newRangeMin, out int newRangeMax)
            {
                // Each range must be non-negative and ordered (min <= max).
                Contracts.Assert(0 <= minRange2 && minRange2 <= maxRange2);
                Contracts.Assert(0 <= minRange1 && minRange1 <= maxRange1);

                // The second range has to begin immediately after the first one ends.
                Contracts.Assert(minRange2 - maxRange1 == 1);

                newRangeMax = maxRange2;
                newRangeMin = minRange1;
            }
 
            /// <summary>
            /// Creates the value getter for the output column at position <paramref name="iinfo"/>.
            /// </summary>
            /// <remarks>
            /// Four cases are handled, depending on whether the source type is a vector and whether
            /// the column is suppressed:
            /// scalar + suppressed yields a getter producing the default value,
            /// scalar + not suppressed yields the source getter unchanged,
            /// vector + suppressed yields the trivial vector getter,
            /// vector + not suppressed yields a getter that applies the slot dropper.
            /// </remarks>
            /// <param name="input">The row to read source values from.</param>
            /// <param name="iinfo">Index into the transformer's column pairs.</param>
            /// <param name="activeOutput">Not used by this implementation.</param>
            /// <param name="disposer">Always set to <see langword="null"/>.</param>
            /// <returns>A <c>ValueGetter</c> delegate for the output column.</returns>
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
                disposer = null;

                var typeSrc = _srcTypes[iinfo];

                if (!(typeSrc is VectorDataViewType))
                {
                    if (_suppressed[iinfo])
                        return MakeOneTrivialGetter(input, iinfo);

                    // GetSrcGetter takes the column-pair index and resolves the source column through
                    // _cols itself. Passing _cols[iinfo] here would index _cols twice and read the wrong
                    // column (or go out of range), so the pair index is passed as-is.
                    return GetSrcGetter(typeSrc, input, iinfo);
                }
                if (_suppressed[iinfo])
                    return MakeVecTrivialGetter(input, iinfo);
                return MakeVecGetter(input, iinfo);
            }
 
            /// <summary>
            /// Builds the getter used for a suppressed, non-vector column by dispatching on the
            /// raw type of the source column.
            /// </summary>
            /// <param name="input">The input row; only validated here.</param>
            /// <param name="iinfo">Index into the transformer's column pairs.</param>
            private Delegate MakeOneTrivialGetter(DataViewRow input, int iinfo)
            {
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);

                var scalarType = _srcTypes[iinfo];
                Host.Assert(!(scalarType is VectorDataViewType));
                Host.Assert(_suppressed[iinfo]);

                // Presumably the method info refers to the generic MakeOneTrivialGetter<TDst> overload.
                var rawType = scalarType.RawType;
                return Utils.MarshalInvoke(_makeOneTrivialGetterMethodInfo, this, rawType);
            }
 
            /// <summary>
            /// Returns a getter bound to <c>OneTrivialGetter</c> for the given value type.
            /// </summary>
            private ValueGetter<TDst> MakeOneTrivialGetter<TDst>() => OneTrivialGetter;
 
            /// <summary>
            /// Getter body for a suppressed scalar column: always writes the default value of
            /// <typeparamref name="TDst"/>.
            /// </summary>
            // Kept as an instance method on purpose: delegates onto instance methods are more
            // efficient than delegates onto static methods.
            private void OneTrivialGetter<TDst>(ref TDst value) => value = default(TDst);
 
            /// <summary>
            /// Builds the getter used for a suppressed vector column by dispatching on the raw
            /// type of the vector's items.
            /// </summary>
            /// <param name="input">The input row; only validated here.</param>
            /// <param name="iinfo">Index into the transformer's column pairs.</param>
            private Delegate MakeVecTrivialGetter(DataViewRow input, int iinfo)
            {
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);

                // The cast doubles as a check that the source column really is a vector.
                var srcVectorType = (VectorDataViewType)_srcTypes[iinfo];
                Host.Assert(_suppressed[iinfo]);

                var itemRawType = srcVectorType.ItemType.RawType;
                return Utils.MarshalInvoke(_makeVecTrivialGetterMethodInfo, this, itemRawType);
            }
 
            /// <summary>
            /// Returns a getter bound to <c>VecTrivialGetter</c> for the given item type.
            /// </summary>
            private ValueGetter<VBuffer<TDst>> MakeVecTrivialGetter<TDst>() => VecTrivialGetter;
 
            /// <summary>
            /// Getter body for a suppressed vector column: resizes <paramref name="value"/> with the
            /// arguments (1, 0), which matches the documented behaviour of producing a length-one
            /// vector holding the default value.
            /// </summary>
            // Kept as an instance method on purpose: delegates onto instance methods are more
            // efficient than delegates onto static methods.
            private void VecTrivialGetter<TDst>(ref VBuffer<TDst> value) => VBufferUtils.Resize(ref value, 1, 0);
 
            /// <summary>
            /// Builds the getter for a vector column that is not suppressed by dispatching on the
            /// raw type of the vector's items.
            /// </summary>
            /// <param name="input">The row to read source values from.</param>
            /// <param name="iinfo">Index into the transformer's column pairs.</param>
            private Delegate MakeVecGetter(DataViewRow input, int iinfo)
            {
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);

                // The cast doubles as a check that the source column really is a vector.
                var srcVectorType = (VectorDataViewType)_srcTypes[iinfo];
                Host.Assert(!_suppressed[iinfo]);

                var itemRawType = srcVectorType.ItemType.RawType;
                return Utils.MarshalInvoke(_makeVecGetterMethodInfo, this, itemRawType, input, iinfo);
            }
 
            /// <summary>
            /// Creates the typed getter for a non-suppressed vector column.
            /// </summary>
            /// <remarks>
            /// When the destination type is a known-size vector whose size equals the source value
            /// count, the source getter is returned as-is. Otherwise the returned getter reads the
            /// source into a scratch buffer and lets the column's slot dropper write the result.
            /// </remarks>
            /// <param name="input">The row to read source values from.</param>
            /// <param name="iinfo">Index into the transformer's column pairs.</param>
            private ValueGetter<VBuffer<TDst>> MakeVecGetter<TDst>(DataViewRow input, int iinfo)
            {
                ValueGetter<VBuffer<TDst>> getSource = GetSrcGetter<VBuffer<TDst>>(input, iinfo);
                var outputType = _dstTypes[iinfo];
                int inputValueCount = _srcTypes[iinfo].GetValueCount();

                bool sizeUnchanged = outputType is VectorDataViewType outputVector
                    && outputVector.IsKnownSize
                    && outputVector.Size == inputValueCount;
                if (sizeUnchanged)
                    return getSource;

                // Scratch buffer captured by the delegate and reused on every call.
                VBuffer<TDst> scratch = default(VBuffer<TDst>);
                ValueGetter<VBuffer<TDst>> dropping = (ref VBuffer<TDst> value) =>
                {
                    getSource(ref scratch);
                    _slotDropper[iinfo].DropSlots(ref scratch, ref value);
                };
                return dropping;
            }
 
            /// <summary>
            /// Gets the typed getter of the source column that backs column pair <paramref name="iinfo"/>.
            /// </summary>
            /// <param name="input">The row to read from; the source column is asserted to be active.</param>
            /// <param name="iinfo">Index into the transformer's column pairs (not a schema column index).</param>
            private ValueGetter<T> GetSrcGetter<T>(DataViewRow input, int iinfo)
            {
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);

                // _cols maps the column-pair index to the column index in the input schema.
                int schemaIndex = _cols[iinfo];
                var sourceColumn = input.Schema[schemaIndex];
                Host.Assert(input.IsColumnActive(sourceColumn));

                return input.GetGetter<T>(sourceColumn);
            }
 
            /// <summary>
            /// Untyped counterpart of the generic source getter: dispatches on the raw type of
            /// <paramref name="typeDst"/> and forwards <paramref name="row"/> and <paramref name="iinfo"/>.
            /// </summary>
            /// <param name="typeDst">Type whose raw type selects the generic instantiation.</param>
            /// <param name="row">The row to read source values from.</param>
            /// <param name="iinfo">Index forwarded unchanged to the invoked method.</param>
            private Delegate GetSrcGetter(DataViewType typeDst, DataViewRow row, int iinfo)
            {
                Host.CheckValue(typeDst, nameof(typeDst));
                Host.CheckValue(row, nameof(row));

                var rawType = typeDst.RawType;
                Delegate getter = Utils.MarshalInvoke(_getSrcGetterMethodInfo, this, rawType, row, iinfo);
                return getter;
            }
 
            /// <summary>
            /// Builds the detached output columns, one per column pair, together with their annotations.
            /// </summary>
            /// <remarks>
            /// For each column the following annotations may be attached:
            /// slot names (known-size vector source that has slot names, with at least one slot kept),
            /// categorical slot ranges (non-suppressed column whose input carries categorical feature
            /// indices, when the recomputed ranges are non-empty), and the KeyValues / IsNormalized
            /// annotations copied from the input column.
            /// As a side effect, <c>_categoricalRanges[iinfo]</c> is populated for non-suppressed columns.
            /// </remarks>
            /// <returns>An array with one <see cref="DataViewSchema.DetachedColumn"/> per column pair.</returns>
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
                for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                {
                    // Copy the loop variable so the lambdas below capture a per-iteration value.
                    int iinfo = i;

                    InputSchema.TryGetColumnIndex(_parent.ColumnPairs[iinfo].inputColumnName, out int colIndex);
                    Host.Assert(colIndex >= 0);
                    var builder = new DataViewSchema.Annotations.Builder();

                    // Add SlotNames metadata.
                    if (_srcTypes[iinfo] is VectorDataViewType vectorType && vectorType.IsKnownSize)
                    {
                        var dstLength = _slotDropper[iinfo].DstLength;
                        var hasSlotNames = InputSchema[_cols[iinfo]].HasSlotNames(vectorType.Size);

                        if (hasSlotNames && dstLength > 0)
                        {
                            // Add slot name metadata.
                            ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) => GetSlotNames(iinfo, ref dst);
                            builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, dstLength), slotNamesGetter);
                        }
                    }

                    // Add CategoricalSlotRanges metadata.
                    if (!_suppressed[iinfo])
                    {
                        if (AnnotationUtils.TryGetCategoricalFeatureIndices(InputSchema, _cols[iinfo], out _categoricalRanges[iinfo]))
                        {
                            // Compute the ranges once up front only to learn how many remain; the
                            // registered getter recomputes them on demand.
                            VBuffer<int> dst = default(VBuffer<int>);
                            GetCategoricalSlotRangesCore(iinfo, _slotDropper[iinfo].SlotsMin, _slotDropper[iinfo].SlotsMax, _categoricalRanges[iinfo], ref dst);
                            // REVIEW: cache dst as opposed to calculating it again.
                            if (dst.Length > 0)
                            {
                                // Ranges are stored as consecutive (min, max) pairs.
                                Contracts.Assert(dst.Length % 2 == 0);
                                // Add categorical slot ranges metadata.
                                ValueGetter<VBuffer<int>> categoricalSlotRangesGetter = (ref VBuffer<int> dest) => GetCategoricalSlotRanges(iinfo, ref dest);
                                builder.Add(AnnotationUtils.Kinds.CategoricalSlotRanges, AnnotationUtils.GetCategoricalType(dst.Length / 2), categoricalSlotRangesGetter);
                            }
                        }
                    }

                    // Add IsNormalized and KeyValues metadata.
                    builder.Add(InputSchema[_cols[iinfo]].Annotations, x => x == AnnotationUtils.Kinds.KeyValues || x == AnnotationUtils.Kinds.IsNormalized);

                    result[iinfo] = new DataViewSchema.DetachedColumn(_parent.ColumnPairs[iinfo].outputColumnName, _dstTypes[iinfo], builder.ToAnnotations());
                }
                return result;
            }
 
            /// <summary>
            /// Reports whether this mapper can be exported to ONNX; it unconditionally answers yes.
            /// </summary>
            /// <param name="ctx">The ONNX export context; not inspected.</param>
            public bool CanSaveOnnx(OnnxContext ctx)
            {
                return true;
            }
 
            /// <summary>
            /// Exports each column pair to ONNX, skipping pairs whose input column is not present
            /// in <paramref name="ctx"/>.
            /// </summary>
            /// <param name="ctx">The ONNX export context that receives variables and nodes.</param>
            public void SaveAsOnnx(OnnxContext ctx)
            {
                Host.CheckValue(ctx, nameof(ctx));

                int columnCount = _cols.Length;
                for (int iinfo = 0; iinfo < columnCount; iinfo++)
                {
                    string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
                    if (ctx.ContainsColumn(inputColumnName))
                    {
                        string srcVariableName = ctx.GetVariableName(inputColumnName);
                        string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
                        string dstVariableName = ctx.AddIntermediateVariable(_dstTypes[iinfo], outputColumnName);

                        // Undo the registration of the output if the export of this column failed.
                        bool exported = SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName);
                        if (!exported)
                            ctx.RemoveColumn(dstVariableName);
                    }
                }
            }
 
            /// <summary>
            /// Emits the ONNX node(s) for column pair <paramref name="iinfo"/>.
            /// </summary>
            /// <remarks>
            /// When at least one slot is preserved, a <c>GatherElements</c> node (axis 1) selects the
            /// preserved slots from the source variable. When no slot is preserved, a constant
            /// initializer of shape [1, 1] is forwarded to the output through an <c>Identity</c> node.
            /// </remarks>
            /// <param name="ctx">The ONNX export context.</param>
            /// <param name="iinfo">Index into the transformer's column pairs.</param>
            /// <param name="srcVariableName">ONNX variable name of the input column.</param>
            /// <param name="dstVariableName">ONNX variable name of the output column.</param>
            /// <returns>Always <see langword="true"/>.</returns>
            public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
            {
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                string opType;
                var slots = _slotDropper[iinfo].GetPreservedSlots();
                // Count once: the value is needed both for the branch and for the initializer shape,
                // and re-counting may enumerate the preserved-slot sequence again.
                int preservedSlotCount = slots.Count();
                // vector column is not suppressed
                if (preservedSlotCount > 0)
                {
                    opType = "GatherElements";
                    var slotsVar = ctx.AddInitializer(slots, new long[] { 1, preservedSlotCount }, "PreservedSlots");
                    var node = ctx.CreateNode(opType, new[] { srcVariableName, slotsVar }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
                    node.AddAttribute("axis", 1);
                }
                // When the vector/scalar column is suppressed, we simply create a single-element constant output
                else
                {
                    string constVal;
                    var type = _srcTypes[iinfo].GetItemType();
                    if (type == TextDataViewType.Instance)
                        constVal = ctx.AddInitializer(new string[] { "" }, new long[] { 1, 1 });
                    else if (type == NumberDataViewType.Single)
                        constVal = ctx.AddInitializer(new float[] { 0 }, new long[] { 1, 1 });
                    else
                    {
                        // NOTE(review): every remaining item type gets a double constant; confirm that
                        // this matches the declared output type for non-double columns.
                        constVal = ctx.AddInitializer(new double[] { 0 }, new long[] { 1, 1 });
                    }

                    opType = "Identity";
                    ctx.CreateNode(opType, constVal, dstVariableName, ctx.GetNodeName(opType), "");
                }
                return true;
            }
 
        }
    }
}