// File: Transforms\ColumnSelecting.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Registers construction of the transform from ColumnSelectingTransformer.Options under SignatureDataTransform.
[assembly: LoadableClass(ColumnSelectingTransformer.Summary, typeof(IDataTransform), typeof(ColumnSelectingTransformer),
                typeof(ColumnSelectingTransformer.Options), typeof(SignatureDataTransform),
                ColumnSelectingTransformer.UserName, "SelectColumns", "SelectColumnsTransform", ColumnSelectingTransformer.ShortName, DocName = "transform/SelectTransforms.md")]

// Registers loading a saved SelectColumns model as a data view applied to an input (SignatureLoadDataTransform).
[assembly: LoadableClass(ColumnSelectingTransformer.Summary, typeof(IDataView), typeof(ColumnSelectingTransformer), null, typeof(SignatureLoadDataTransform),
                            ColumnSelectingTransformer.UserName, ColumnSelectingTransformer.LoaderSignature)]

// Registers loading a saved SelectColumns model as a standalone transformer (SignatureLoadModel).
[assembly: LoadableClass(ColumnSelectingTransformer.Summary, typeof(ColumnSelectingTransformer), null, typeof(SignatureLoadModel),
                            ColumnSelectingTransformer.UserName, ColumnSelectingTransformer.LoaderSignature)]

// Back-compat to handle loading of the Drop and Keep Transformer
[assembly: LoadableClass("", typeof(IDataView), typeof(ColumnSelectingTransformer), null, typeof(SignatureLoadDataTransform),
    "", ColumnSelectingTransformer.DropLoaderSignature)]

// Back-compat to handle loading of the Choose Columns Transformer
[assembly: LoadableClass("", typeof(IDataView), typeof(ColumnSelectingTransformer), null, typeof(SignatureLoadDataTransform),
    "", ColumnSelectingTransformer.ChooseLoaderSignature, ColumnSelectingTransformer.ChooseLoaderSignatureOld)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// Keeps or drops selected columns from an <see cref="IDataView"/>.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    ///
    /// ###  Estimator Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Does this estimator need to look at the data to train its parameters? | No |
    /// | Input columns data type | Any |
    /// | Exportable to ONNX | Yes |
    ///
    /// The resulting <xref:Microsoft.ML.Transforms.ColumnSelectingTransformer>
    /// operates on the schema of a given <xref:Microsoft.ML.IDataView> by dropping or keeping selected columns from the schema.
    ///
    /// It is commonly used to remove unwanted columns before serializing a dataset or writing it to a file.
    /// It is not necessary to drop unused columns before training or performing transforms,
    /// as the <xref:Microsoft.ML.IDataView> is lazily evaluated and will not actually materialize the columns until needed.
    /// In the case of serialization, every column in the schema will be written out. If there are columns
    /// that should not be saved, this estimator can be used to remove them.
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]></format>
    /// </remarks>
    /// <seealso cref="TransformExtensionsCatalog.DropColumns(TransformsCatalog, string[])"/>
    /// <seealso cref="TransformExtensionsCatalog.SelectColumns(TransformsCatalog, string[])"/>
    /// <seealso cref="TransformExtensionsCatalog.SelectColumns(TransformsCatalog, string[], bool)"/>
    public sealed class ColumnSelectingEstimator : TrivialEstimator<ColumnSelectingTransformer>
    {
        [BestFriend]
        internal static class Defaults
        {
            public const bool KeepHidden = false;
            public const bool IgnoreMissing = false;
        }

        /// <summary>
        /// Constructs the Select Columns Estimator in "keep mode".
        /// </summary>
        /// <param name="env">Instance of the host environment.</param>
        /// <param name="keepColumns">The array of column names to keep.</param>
        private ColumnSelectingEstimator(IHostEnvironment env, params string[] keepColumns)
            : this(env, keepColumns, null, Defaults.KeepHidden, Defaults.IgnoreMissing)
        { }

        /// <summary>
        /// Constructs the Select Columns Estimator.
        /// </summary>
        /// <param name="env">Instance of the host environment.</param>
        /// <param name="keepColumns">The array of column names to keep, cannot be set with <paramref name="dropColumns"/>.</param>
        /// <param name="dropColumns">The array of column names to drop, cannot be set with <paramref name="keepColumns"/>.</param>
        /// <param name="keepHidden">If true will keep hidden columns and false will remove hidden columns. The argument is
        /// ignored if the Estimator is in "drop mode".</param>
        /// <param name="ignoreMissing">If false will check for any columns given in <paramref name="keepColumns"/>
        ///     or <paramref name="dropColumns"/> that are missing from the input. If a missing column exists a
        ///     SchemaMismatch exception is thrown. If true, the check is not made.</param>
        internal ColumnSelectingEstimator(IHostEnvironment env, string[] keepColumns,
                                    string[] dropColumns, bool keepHidden = Defaults.KeepHidden,
                                    bool ignoreMissing = Defaults.IgnoreMissing)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnSelectingEstimator)),
                  new ColumnSelectingTransformer(env, keepColumns, dropColumns, keepHidden, ignoreMissing))
        {
            // The keep/drop decision is owned by the transformer (which treats an empty array like null).
            // GetOutputSchema reads it back from there instead of re-deriving it from the raw arrays, so the
            // estimator cannot disagree with the transformer it produces.
        }

        /// <summary>
        /// KeepColumns is used to select a list of columns that the user wants to keep on a given input. Any column not specified
        /// will be dropped from the output schema.
        /// </summary>
        /// <param name="env">Instance of the host environment.</param>
        /// <param name="columnsToKeep">The array of column names to keep.</param>
        internal static ColumnSelectingEstimator KeepColumns(IHostEnvironment env, params string[] columnsToKeep)
        {
            return new ColumnSelectingEstimator(env, columnsToKeep);
        }

        /// <summary>
        /// DropColumns is used to select a list of columns that user wants to drop from a given input. Any column not specified will
        /// be maintained in the output schema.
        /// </summary>
        /// <param name="env">Instance of the host environment.</param>
        /// <param name="columnsToDrop">The array of column names to drop.</param>
        internal static ColumnSelectingEstimator DropColumns(IHostEnvironment env, params string[] columnsToDrop)
        {
            return new ColumnSelectingEstimator(env, null, columnsToDrop);
        }

        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">The shape of the data that will be fed to the transformer.</param>
        /// <returns>The input columns that survive the keep/drop selection, in input order.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            if (!Transformer.IgnoreMissing && !ColumnSelectingTransformer.IsSchemaValid(inputSchema.Select(x => x.Name),
                                                                                    Transformer.SelectColumns,
                                                                                    out IEnumerable<string> invalidColumns))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", string.Join(",", invalidColumns));
            }

            // A column survives when its membership in the selected set matches the mode:
            // present in keep mode, absent in drop mode.
            // NOTE(review): in keep mode the transformer orders its output by the keep list, whereas this
            // shape follows the input order.
            var selected = new HashSet<string>(Transformer.SelectColumns);
            bool keep = Transformer.KeepColumns;
            var columns = inputSchema.Where(c => selected.Contains(c.Name) == keep);
            return new SchemaShape(columns);
        }
    }
 
    /// <summary>
    /// <see cref="ITransformer"/> resulting from fitting an <see cref="ColumnSelectingEstimator"/>.
    /// </summary>
    public sealed class ColumnSelectingTransformer : ITransformer
    {
        // Descriptive metadata and the loader signature used when registering and saving this transformer.
        internal const string Summary = "Selects which columns from the dataset to keep.";
        internal const string UserName = "Select Columns Transform";
        internal const string ShortName = "Select";
        internal const string LoaderSignature = "SelectColumnsTransform";

        // Back-compat signatures to support loading Drop/Keep and Choose Transforms
        internal const string DropLoaderSignature = "DropColumnsTransform";
        internal const string ChooseLoaderSignature = "ChooseColumnsTransform";
        internal const string ChooseLoaderSignatureOld = "ChooseColumnsFunction";

        private readonly IHost _host;
        // Names to keep (in output order) when KeepColumns is true; otherwise the names to drop.
        private readonly string[] _selectedColumns;

        bool ITransformer.IsRowToRowMapper => true;

        /// <summary>Read-only view of the selected column names; whether they are kept or dropped is given by <see cref="KeepColumns"/>.</summary>
        internal IEnumerable<string> SelectColumns => _selectedColumns.AsReadOnly();

        /// <summary>True when the selected columns are kept; false when they are dropped.</summary>
        internal bool KeepColumns { get; }

        /// <summary>Whether hidden columns are kept. Only consulted in keep mode.</summary>
        internal bool KeepHidden { get; }
        /// <summary>When true, selected columns that are absent from the input are not reported as a schema mismatch.</summary>
        internal bool IgnoreMissing { get; }
 
        /// <summary>
        /// Version header written by SaveModel and required when loading the current SelectColumns format.
        /// </summary>
        private static VersionInfo GetVersionInfo()
            => new VersionInfo(
                modelSignature: "SELCOLST",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(ColumnSelectingTransformer).Assembly.FullName);
 
        /// <summary>
        /// Version header of the legacy DropColumns/KeepColumns format. It is only used to recognize such models on load;
        /// this transformer never writes it.
        /// </summary>
        private static VersionInfo GetDropVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "DRPCOLST",
                // verWrittenCur: 0x00010001, // Initial
                verWrittenCur: 0x00010002, // Added KeepColumns
                verReadableCur: 0x00010002,
                verWeCanReadBack: 0x00010002,
                // Name the legacy loader (the one the back-compat LoadableClass registration uses), consistent with
                // GetChooseVersionInfo, rather than the current SelectColumns loader signature.
                loaderSignature: DropLoaderSignature,
                loaderAssemblyName: typeof(ColumnSelectingTransformer).Assembly.FullName);
        }
 
        /// <summary>
        /// Version header of the legacy ChooseColumns format (under either of its loader signatures);
        /// only used to recognize such models on load.
        /// </summary>
        private static VersionInfo GetChooseVersionInfo()
            => new VersionInfo(
                modelSignature: "CHSCOLSF",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: ChooseLoaderSignature,
                loaderSignatureAlt: ChooseLoaderSignatureOld,
                loaderAssemblyName: typeof(ColumnSelectingTransformer).Assembly.FullName);
 
        /// <summary>
        /// Arguments for building the transform through the SignatureDataTransform factory.
        /// Exactly one of <see cref="KeepColumns"/> and <see cref="DropColumns"/> must be non-empty;
        /// the transformer constructor rejects any other combination.
        /// </summary>
        internal sealed class Options : TransformInputBase
        {
            /// <summary>Columns to keep; the output follows this order.</summary>
            [Argument(ArgumentType.Multiple, HelpText = "List of columns to keep.", ShortName = "keepcol", SortOrder = 1)]
            public string[] KeepColumns;

            /// <summary>Columns to drop; the remaining columns keep their input order.</summary>
            [Argument(ArgumentType.Multiple, HelpText = "List of columns to drop.", ShortName = "dropcol", SortOrder = 2)]
            public string[] DropColumns;

            /// <summary>Whether hidden columns are kept. Only consulted in keep mode.</summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Specifies whether to keep or remove hidden columns.", ShortName = "hidden", SortOrder = 3)]
            public bool KeepHidden = ColumnSelectingEstimator.Defaults.KeepHidden;

            /// <summary>When true, selected columns absent from the input are not treated as a schema mismatch.</summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Specifies whether to ignore columns that are missing from the input.", ShortName = "ignore", SortOrder = 4)]
            public bool IgnoreMissing = ColumnSelectingEstimator.Defaults.IgnoreMissing;
        }
 
        /// <summary>
        /// Creates a transformer that either keeps or drops the named columns.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="keepColumns">Names of the columns to keep, in output order. Mutually exclusive with <paramref name="dropColumns"/>.</param>
        /// <param name="dropColumns">Names of the columns to drop. Mutually exclusive with <paramref name="keepColumns"/>.</param>
        /// <param name="keepHidden">Whether hidden columns are kept; only consulted in keep mode.</param>
        /// <param name="ignoreMissing">When false, selected columns missing from an input schema cause a schema mismatch.</param>
        internal ColumnSelectingTransformer(IHostEnvironment env, string[] keepColumns, string[] dropColumns,
                                        bool keepHidden = ColumnSelectingEstimator.Defaults.KeepHidden, bool ignoreMissing = ColumnSelectingEstimator.Defaults.IgnoreMissing)
        {
            _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnSelectingTransformer));
            _host.CheckValueOrNull(keepColumns);
            _host.CheckValueOrNull(dropColumns);

            // A list counts as "set" only when it is non-empty; an empty array behaves like null.
            bool keepValid = Utils.Size(keepColumns) > 0;
            bool dropValid = Utils.Size(dropColumns) > 0;

            // Exactly one of the two lists must be set: reject both, then reject neither.
            _host.Check(!(keepValid && dropValid), "Both " + nameof(keepColumns) + " and " + nameof(dropColumns) + " are set. Exactly one can be specified.");
            _host.Check(keepValid || dropValid, "Neither " + nameof(keepColumns) + " nor " + nameof(dropColumns) + " is set. Exactly one must be specified.");

            // Copy the caller's array so later mutations by the caller cannot change what this transformer selects.
            _selectedColumns = (string[])(keepValid ? keepColumns : dropColumns).Clone();
            KeepColumns = keepValid;
            KeepHidden = keepHidden;
            IgnoreMissing = ignoreMissing;
        }
 
        /// <summary>
        /// Probes whether the model at the current position of <paramref name="ctx"/> carries the header described by
        /// <paramref name="versionInfo"/>.
        /// </summary>
        /// <returns><c>true</c> if the version check succeeds; <c>false</c> if it throws for any reason.</returns>
        private static bool CheckModelVersion(ModelLoadContext ctx, VersionInfo versionInfo)
        {
            bool matches;
            try
            {
                ctx.CheckVersionInfo(versionInfo);
                matches = true;
            }
            catch (Exception)
            {
                // Used purely as a probe: a failed check only means "not this format".
                matches = false;
            }
            return matches;
        }
 
        /// <summary>
        /// Back-compatibility function that handles loading the DropColumns (and KeepColumns) Transform.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="ctx">Load context positioned at the legacy model payload.</param>
        /// <param name="input">Unused; kept for signature parity with the other loaders.</param>
        /// <returns>An equivalent transformer. The legacy transform always preserved hidden columns, so this one does too.</returns>
        private static ColumnSelectingTransformer LoadDropColumnsTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            // *** Binary format ***
            // int: sizeof(Float)
            // bindings
            int cbFloat = ctx.Reader.ReadInt32();
            env.CheckDecode(cbFloat == sizeof(float));

            // *** Binary format ***
            // bool: whether to keep (vs drop) the named columns
            // int: number of names
            // int[]: the ids of the names
            bool keep = ctx.Reader.ReadBoolByte();
            int count = ctx.Reader.ReadInt32();
            env.CheckDecode(count > 0);

            // Keep the names in their stored order: in keep mode that order becomes the output column order, and
            // HashSet enumeration order is not guaranteed. The set is only used to reject duplicates.
            var names = new string[count];
            var seen = new HashSet<string>();
            for (int i = 0; i < count; i++)
            {
                string name = ctx.LoadNonEmptyString();
                env.CheckDecode(seen.Add(name));
                names[i] = name;
            }

            // Note for backward compatibility, Drop/Keep Columns always preserves hidden columns.
            return keep
                ? new ColumnSelectingTransformer(env, names, null, true)
                : new ColumnSelectingTransformer(env, null, names, true);
        }
 
        /// <summary>
        /// Back-compatibility that is handling the HiddenColumnOption from ChooseColumns.
        /// </summary>
        /// <remarks>
        /// Legacy models store these values as raw bytes, so the numeric values must not change.
        /// Only <c>Drop</c> and <c>Keep</c> map onto this transformer; <c>Rename</c> is rejected when loading.
        /// </remarks>
        private enum HiddenColumnOption : byte
        {
            Drop = 1,
            Keep = 2,
            Rename = 3
        };
 
        /// <summary>
        /// Translates a legacy <see cref="HiddenColumnOption"/> into the boolean keep-hidden flag used by this transformer.
        /// </summary>
        /// <returns><c>true</c> for Keep, <c>false</c> for Drop.</returns>
        /// <remarks>Any other value (Rename or an unknown byte) has no equivalent and throws.</remarks>
        private static bool GetHiddenOption(IHostEnvironment env, HiddenColumnOption option)
        {
            if (option == HiddenColumnOption.Keep)
                return true;
            if (option == HiddenColumnOption.Drop)
                return false;
            throw env.Except("Unsupported hide option specified");
        }
 
        /// <summary>
        /// Backwards compatibility helper function that loads a Choose Column Transform.
        /// </summary>
        /// <remarks>
        /// Only the part of ChooseColumns that maps onto plain column selection is supported. Every entry must keep its
        /// source name (no renames), and every entry must use the same hidden-column option as the default.
        /// </remarks>
        /// <param name="env">The host environment.</param>
        /// <param name="ctx">Load context positioned at the legacy model payload.</param>
        /// <param name="input">Unused; kept for signature parity with the other loaders.</param>
        private static ColumnSelectingTransformer LoadChooseColumnsTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            const string renameNotSupportedMsg = "Rename for ChooseColumns is not backwards compatible with the SelectColumnsTranform";
            const string differentHideColumnNotSupportedMsg = "Setting a hide option different from default is not compatible with SelectColumnsTransform";

            // *** Binary format ***
            // int: sizeof(Float)
            // bindings
            // The value is not needed here; it is read only to advance the stream.
            ctx.Reader.ReadInt32();

            // *** Binary format ***
            // byte: default HiddenColumnOption value
            // int: number of raw column infos
            // for each raw column info
            //   int: id of output column name
            //   int: id of input column name
            //   byte: HiddenColumnOption
            var hiddenOption = (HiddenColumnOption)ctx.Reader.ReadByte();
            // The byte comes from the model file, so an undefined value is a decode error, not a debug-only assertion.
            env.CheckDecode(Enum.IsDefined(typeof(HiddenColumnOption), hiddenOption));
            env.Check(HiddenColumnOption.Rename != hiddenOption, renameNotSupportedMsg);
            bool keepHidden = GetHiddenOption(env, hiddenOption);

            int count = ctx.Reader.ReadInt32();
            env.CheckDecode(count >= 0);

            // The list preserves the stored order, which becomes the output column order.
            // The set only detects duplicate names.
            var names = new List<string>();
            var seen = new HashSet<string>();
            for (int colIdx = 0; colIdx < count; ++colIdx)
            {
                string dst = ctx.LoadNonEmptyString();
                env.CheckDecode(seen.Add(dst));
                string src = ctx.LoadNonEmptyString();
                // Columns are selected by name. A renaming entry (dst != src) would therefore select the column named dst
                // instead of src's data, so it is rejected.
                env.Check(src == dst, renameNotSupportedMsg);
                names.Add(dst);

                var colHiddenOption = (HiddenColumnOption)ctx.Reader.ReadByte();
                env.CheckDecode(Enum.IsDefined(typeof(HiddenColumnOption), colHiddenOption));
                env.Check(colHiddenOption != HiddenColumnOption.Rename, renameNotSupportedMsg);
                env.Check(GetHiddenOption(env, colHiddenOption) == keepHidden, differentHideColumnNotSupportedMsg);
            }

            return new ColumnSelectingTransformer(env, names.ToArray(), null, keepHidden);
        }
 
        // Factory method for SignatureLoadModel: reads back a transformer in the layout written by SaveModel.
        private static ColumnSelectingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // bool: keep columns flag
            // bool: keep hidden flag
            // bool: ignore missing flag
            // int: number of selected columns
            // for each selected column
            //   string: selected column name
            bool keepColumns = ctx.Reader.ReadBoolByte();
            bool keepHidden = ctx.Reader.ReadBoolByte();
            bool ignoreMissing = ctx.Reader.ReadBoolByte();
            int length = ctx.Reader.ReadInt32();
            // The constructor never accepts an empty selection, so a saved model lists at least one column.
            // A non-positive count means the stream is corrupt.
            env.CheckDecode(length > 0);

            var columns = new string[length];
            for (int i = 0; i < length; i++)
                columns[i] = ctx.LoadNonEmptyString();

            // The flag decides whether the stored names are the ones to keep or the ones to drop.
            return keepColumns
                ? new ColumnSelectingTransformer(env, columns, null, keepHidden, ignoreMissing)
                : new ColumnSelectingTransformer(env, null, columns, keepHidden, ignoreMissing);
        }
 
        // Factory method for SignatureLoadDataTransform.
        // First probes for the two legacy formats (DropColumns, then ChooseColumns). If both probes fail, it reads the
        // current SelectColumns format, then applies the resulting transformer to the input.
        private static IDataView Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));

            ColumnSelectingTransformer transform =
                CheckModelVersion(ctx, GetDropVersionInfo()) ? LoadDropColumnsTransform(env, ctx, input)
                : CheckModelVersion(ctx, GetChooseVersionInfo()) ? LoadChooseColumnsTransform(env, ctx, input)
                : Create(env, ctx);

            return transform.Transform(input);
        }
 
        /// <summary>Applies a keep-mode selection of <paramref name="keepColumns"/> to <paramref name="input"/>.</summary>
        [BestFriend]
        internal static IDataView CreateKeep(IHostEnvironment env, IDataView input, string[] keepColumns, bool keepHidden = false)
        {
            var transformer = new ColumnSelectingTransformer(env, keepColumns, null, keepHidden);
            return transformer.Transform(input);
        }

        /// <summary>Applies a drop-mode selection of <paramref name="dropColumns"/> to <paramref name="input"/>.</summary>
        [BestFriend]
        internal static IDataView CreateDrop(IHostEnvironment env, IDataView input, params string[] dropColumns)
        {
            var transformer = new ColumnSelectingTransformer(env, keepColumns: null, dropColumns: dropColumns);
            return transformer.Transform(input);
        }
 
        // Factory method for SignatureDataTransform.
        private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(input, nameof(input));

            var transform = new ColumnSelectingTransformer(env, options.KeepColumns, options.DropColumns,
                                                            options.KeepHidden, options.IgnoreMissing);

            // Honor IgnoreMissing on this path as well. Without the check, selected columns absent from the input
            // would be silently skipped by the mapper even when IgnoreMissing is false.
            if (!transform.IgnoreMissing && !IsSchemaValid(input.Schema.Select(x => x.Name),
                                                            transform.SelectColumns, out IEnumerable<string> invalidColumns))
            {
                throw env.ExceptSchemaMismatch(nameof(input), "input", string.Join(",", invalidColumns));
            }

            return new SelectColumnsDataTransform(env, transform, new Mapper(transform, input.Schema), input);
        }
 
        void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);

        /// <summary>
        /// Serializes this transformer under the current SelectColumns version header.
        /// </summary>
        /// <param name="ctx">Save context, expected to be positioned at the start of the model payload.</param>
        internal void SaveModel(ModelSaveContext ctx)
        {
            _host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // bool: keep columns flag
            // bool: keep hidden flag
            // bool: ignore missing flag
            // int: number of selected columns
            // for each selected column
            //   string: selected column name
            ctx.Writer.WriteBoolByte(KeepColumns);
            ctx.Writer.WriteBoolByte(KeepHidden);
            ctx.Writer.WriteBoolByte(IgnoreMissing);
            ctx.Writer.Write(_selectedColumns.Length);
            foreach (var name in _selectedColumns)
                ctx.SaveNonEmptyString(name);
        }
 
        /// <summary>
        /// Checks that every name in <paramref name="selectColumns"/> appears in <paramref name="inputColumns"/>.
        /// </summary>
        /// <param name="inputColumns">Column names available in the input schema (duplicates allowed).</param>
        /// <param name="selectColumns">Column names the transformer selects.</param>
        /// <param name="invalidColumns">The selected names not found in the input, in selection order. This is a materialized
        /// snapshot, so it stays stable even if either source sequence changes later.</param>
        /// <returns><c>true</c> when no selected column is missing.</returns>
        internal static bool IsSchemaValid(IEnumerable<string> inputColumns,
                                         IEnumerable<string> selectColumns,
                                         out IEnumerable<string> invalidColumns)
        {
            // One pass over the input names into a set, instead of rescanning them for every selected name.
            var available = new HashSet<string>(inputColumns);
            var missing = selectColumns.Where(name => !available.Contains(name)).ToArray();
            invalidColumns = missing;
            return missing.Length == 0;
        }
 
        /// <summary>
        /// Computes the schema this transformer would produce for <paramref name="inputSchema"/>, without reading any data.
        /// </summary>
        /// <param name="inputSchema">Schema of the data that would be passed to <see cref="Transform(IDataView)"/>.</param>
        /// <returns>The input columns that survive the keep/drop selection.</returns>
        public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            var inputNames = inputSchema.Select(col => col.Name);
            if (!(IgnoreMissing || IsSchemaValid(inputNames, SelectColumns, out IEnumerable<string> missingColumns)))
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", string.Join(",", missingColumns));

            var mapper = new Mapper(this, inputSchema);
            return mapper.OutputSchema;
        }
 
        /// <summary>
        /// Builds a row-to-row mapper for <paramref name="inputSchema"/>. The mapper is backed by an empty view with that
        /// schema, so it carries the column selection without holding any data.
        /// </summary>
        /// <param name="inputSchema">The input schema the mapper should accept.</param>
        /// <returns>The row to row mapper.</returns>
        IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            if (!IgnoreMissing)
            {
                var inputNames = inputSchema.Select(col => col.Name);
                if (!IsSchemaValid(inputNames, SelectColumns, out IEnumerable<string> missingColumns))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", string.Join(",", missingColumns));
            }

            var mapper = new Mapper(this, inputSchema);
            var emptySource = new EmptyDataView(_host, inputSchema);
            return new SelectColumnsDataTransform(_host, this, mapper, emptySource);
        }
 
        /// <summary>
        /// Applies the column selection to <paramref name="input"/>.
        /// <see cref="IDataView"/> is lazy, so this only validates the schema and wraps the input; no rows are read here.
        /// </summary>
        public IDataView Transform(IDataView input)
        {
            _host.CheckValue(input, nameof(input));

            if (!IgnoreMissing)
            {
                IEnumerable<string> missingColumns;
                bool allPresent = IsSchemaValid(input.Schema.Select(col => col.Name), SelectColumns, out missingColumns);
                if (!allPresent)
                    throw _host.ExceptSchemaMismatch(nameof(input), "input", string.Join(",", missingColumns));
            }

            var mapper = new Mapper(this, input.Schema);
            return new SelectColumnsDataTransform(_host, this, mapper, input);
        }
 
        /// <summary>
        /// Resolves, for one input schema, which input column backs each output column of the transformer.
        /// </summary>
        private sealed class Mapper
        {
            private readonly IHost _host;
            private readonly DataViewSchema _inputSchema;

            // Entry i is the index of the input column that output column i is read from.
            public readonly int[] OutputToInputMap;

            public DataViewSchema InputSchema => _inputSchema;

            public DataViewSchema OutputSchema { get; }

            public Mapper(ColumnSelectingTransformer transform, DataViewSchema inputSchema)
            {
                _host = transform._host.Register(nameof(Mapper));
                _inputSchema = inputSchema;

                OutputToInputMap = BuildOutputToInputMap(transform.SelectColumns,
                                                         transform.KeepColumns,
                                                         transform.KeepHidden,
                                                         _inputSchema);
                OutputSchema = GenerateOutputSchema(OutputToInputMap, _inputSchema);
            }

            public int GetInputIndex(int outputIndex)
            {
                _host.Assert(0 <= outputIndex && outputIndex < OutputToInputMap.Length);
                return OutputToInputMap[outputIndex];
            }

            // Dispatches to the keep-mode or drop-mode mapping.
            private static int[] BuildOutputToInputMap(IEnumerable<string> selectedColumns,
                bool keepColumns,
                bool keepHidden,
                DataViewSchema inputSchema)
            {
                return keepColumns
                    ? BuildKeepMap(selectedColumns, keepHidden, inputSchema)
                    : BuildDropMap(selectedColumns, inputSchema);
            }

            // Keep mode: the output follows the order of selectedColumns. For example, with input columns ABC and
            // selected columns CA the output is CA. Every input column sharing a selected name is emitted (in input
            // order). Hidden columns are skipped unless keepHidden is set.
            private static int[] BuildKeepMap(IEnumerable<string> selectedColumns, bool keepHidden, DataViewSchema inputSchema)
            {
                var indicesByName = new Dictionary<string, List<int>>();
                for (int colIdx = 0; colIdx < inputSchema.Count; ++colIdx)
                {
                    var column = inputSchema[colIdx];
                    if (!keepHidden && column.IsHidden)
                        continue;

                    if (!indicesByName.TryGetValue(column.Name, out List<int> indices))
                    {
                        indices = new List<int>();
                        indicesByName.Add(column.Name, indices);
                    }
                    indices.Add(colIdx);
                }

                var map = new List<int>();
                foreach (var name in selectedColumns)
                {
                    // Names absent from the input contribute nothing; missing-column errors are reported elsewhere.
                    if (indicesByName.TryGetValue(name, out List<int> indices))
                        map.AddRange(indices);
                }
                return map.ToArray();
            }

            // Drop mode: the output keeps the input order minus every column whose name is selected.
            // For example, dropping B from ABC gives AC. Hidden columns are treated like any other (keepHidden is ignored).
            private static int[] BuildDropMap(IEnumerable<string> selectedColumns, DataViewSchema inputSchema)
            {
                // A set makes each membership test O(1) instead of rescanning the selected names per column.
                var dropped = new HashSet<string>(selectedColumns);
                var map = new List<int>(inputSchema.Count);
                for (int colIdx = 0; colIdx < inputSchema.Count; colIdx++)
                {
                    if (!dropped.Contains(inputSchema[colIdx].Name))
                        map.Add(colIdx);
                }
                return map.ToArray();
            }

            // Builds the output schema by detaching each mapped input column (including its annotations).
            private static DataViewSchema GenerateOutputSchema(IEnumerable<int> map,
                                                        DataViewSchema inputSchema)
            {
                var outputColumns = map.Select(x => new DataViewSchema.DetachedColumn(inputSchema[x]));
                return SchemaExtensions.MakeSchema(outputColumns);
            }
        }
 
        /// <summary>
        /// A row that exposes a wrapped input row through the column selection of a <see cref="Mapper"/>.
        /// </summary>
        private sealed class RowImpl : WrappingRow
        {
            private readonly Mapper _mapper;

            public RowImpl(DataViewRow input, Mapper mapper)
                : base(input)
            {
                _mapper = mapper;
            }

            public override DataViewSchema Schema => _mapper.OutputSchema;

            /// <summary>
            /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
            /// This throws if the column is not active in this row, or if the type
            /// <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
                => Input.GetGetter<TValue>(MapToInput(column));

            /// <summary>
            /// Returns whether the given column is active in this row. An output column is active exactly when the
            /// input column it is read from is active in the wrapped row.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
                => Input.IsColumnActive(MapToInput(column));

            // Translates an output column into the input column that backs it.
            private DataViewSchema.Column MapToInput(DataViewSchema.Column column)
                => Input.Schema[_mapper.GetInputIndex(column.Index)];
        }
 
        private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate, ITransformCanSaveOnnx
        {
            private readonly IHost _host;
            private readonly Mapper _mapper;
            private readonly ColumnSelectingTransformer _transform;

            /// <summary>
            /// Wraps <paramref name="input"/> so that it is viewed through the column selection described by <paramref name="mapper"/>.
            /// </summary>
            public SelectColumnsDataTransform(IHostEnvironment env, ColumnSelectingTransformer transform, Mapper mapper, IDataView input)
            {
                var checkedEnv = Contracts.CheckRef(env, nameof(env));
                _host = checkedEnv.Register(nameof(SelectColumnsDataTransform));
                _transform = transform;
                _mapper = mapper;
                Source = input;
            }
 
            /// <summary>The data view this transform reads from.</summary>
            public IDataView Source { get; }

            /// <summary>Shuffling support is passed through from <see cref="Source"/>.</summary>
            public bool CanShuffle
            {
                get { return Source.CanShuffle; }
            }

            /// <summary>Schema of the wrapped source.</summary>
            public DataViewSchema InputSchema
            {
                get { return Source.Schema; }
            }

            DataViewSchema IDataView.Schema
            {
                get { return OutputSchema; }
            }

            /// <summary>Schema after column selection, as computed by the mapper.</summary>
            public DataViewSchema OutputSchema
            {
                get { return _mapper.OutputSchema; }
            }

            /// <summary>Row count is passed through from the source.</summary>
            public long? GetRowCount()
            {
                return Source.GetRowCount();
            }
 
            public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
            {
                _host.AssertValueOrNull(rand);
 
                // Build out the active state for the input
                var inputCols = ((IRowToRowMapper)this).GetDependencies(columnsNeeded);
                var inputRowCursor = Source.GetRowCursor(inputCols, rand);
 
                // Build the active state for the output
                var active = Utils.BuildArray(_mapper.OutputSchema.Count, columnsNeeded);
                return new Cursor(_host, _mapper, inputRowCursor, active);
            }
 
            public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
            {
                _host.CheckValueOrNull(rand);
 
                // Build out the active state for the input
                var inputCols = ((IRowToRowMapper)this).GetDependencies(columnsNeeded);
                var inputs = Source.GetRowCursorSet(inputCols, n, rand);
 
                // Build out the active state for the output
                var active = Utils.BuildArray(_mapper.OutputSchema.Count, columnsNeeded);
                _host.AssertNonEmpty(inputs);
 
                // No need to split if this is given 1 input cursor.
                var cursors = new DataViewRowCursor[inputs.Length];
                for (int i = 0; i < inputs.Length; i++)
                    cursors[i] = new Cursor(_host, _mapper, inputs[i], active);
                return cursors;
            }
 
            void ICanSaveModel.Save(ModelSaveContext ctx) => _transform.SaveModel(ctx);
 
            /// <summary>
            /// Given a set of columns, return the input columns that are needed to generate those output columns.
            /// </summary>
            IEnumerable<DataViewSchema.Column> IRowToRowMapper.GetDependencies(IEnumerable<DataViewSchema.Column> columns)
            {
                var active = new bool[_mapper.InputSchema.Count];
                foreach (var column in columns)
                    active[_mapper.GetInputIndex(column.Index)] = true;
 
                return _mapper.InputSchema.Where(col => col.Index < active.Length && active[col.Index]);
            }
 
            DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
                => new RowImpl(input, _mapper);
 
            IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
                => new SelectColumnsDataTransform(env, _transform, new Mapper(_transform, newSource.Schema), newSource);
 
            public bool CanSaveOnnx(OnnxContext ctx) => true;
 
            public void SaveAsOnnx(OnnxContext ctx)
            {
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
 
                var outputToInputMap = _mapper.OutputToInputMap;
                for (int i = 0; i < outputToInputMap.Length; i++)
                {
                    var srcCol = InputSchema[outputToInputMap[i]];
                    var dstCol = OutputSchema[i];
                    if (!ctx.ContainsColumn(srcCol.Name) || dstCol.IsHidden)
                        continue;
 
                    var srcVariable = ctx.GetVariableName(srcCol.Name);
                    var dstVariable = ctx.AddIntermediateVariable(dstCol.Type, dstCol.Name);
                    string opType = "Identity";
                    ctx.CreateNode(opType, srcVariable, dstVariable, ctx.GetNodeName(opType), "");
                }
            }
        }
 
        /// <summary>
        /// Cursor that presents the selected output schema on top of an input cursor. Values are read from the
        /// mapped input columns; per-output-column activity comes from the flags supplied at construction.
        /// </summary>
        private sealed class Cursor : SynchronizedCursorBase
        {
            private readonly Mapper _mapper;
            private readonly DataViewRowCursor _inputCursor;
            private readonly bool[] _active;

            /// <summary>
            /// Creates the cursor.
            /// </summary>
            /// <param name="provider">Channel provider passed to the synchronized base cursor.</param>
            /// <param name="mapper">Supplies the output schema and the output-to-input column mapping.</param>
            /// <param name="input">The cursor being wrapped.</param>
            /// <param name="active">Activity flags indexed by output column index.</param>
            public Cursor(IChannelProvider provider, Mapper mapper, DataViewRowCursor input, bool[] active)
                : base(provider, input)
            {
                _active = active;
                _inputCursor = input;
                _mapper = mapper;
            }

            /// <summary>
            /// The output schema of the column-selection mapper.
            /// </summary>
            public override DataViewSchema Schema
            {
                get { return _mapper.OutputSchema; }
            }

            /// <summary>
            /// Returns a value getter for the given output column by translating it to its source column and
            /// delegating to the wrapped input cursor. Validation of activity and of <typeparamref name="TValue"/>
            /// is left to the input cursor.
            /// </summary>
            /// <typeparam name="TValue">The column's content type.</typeparam>
            /// <param name="column">The output column whose getter is requested.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                int sourceIndex = _mapper.GetInputIndex(column.Index);
                DataViewSchema.Column sourceColumn = _inputCursor.Schema[sourceIndex];
                return _inputCursor.GetGetter<TValue>(sourceColumn);
            }

            /// <summary>
            /// Returns the activity flag recorded for the given output column.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                return _active[column.Index];
            }
        }
    }
}