// File: Transforms\ColumnCopying.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Registration for SignatureDataTransform: ColumnCopyingTransformer is exposed as an IDataTransform
// configured through ColumnCopyingTransformer.Options (see Create(IHostEnvironment, Options, IDataView)).
[assembly: LoadableClass(ColumnCopyingTransformer.Summary, typeof(IDataTransform), typeof(ColumnCopyingTransformer),
    typeof(ColumnCopyingTransformer.Options), typeof(SignatureDataTransform),
    ColumnCopyingTransformer.UserName, "CopyColumns", "CopyColumnsTransform", ColumnCopyingTransformer.ShortName,
    DocName = "transform/CopyColumnsTransformer.md")]

// Registration for SignatureLoadDataTransform: an IDataTransform rebuilt from a saved model
// (see Create(IHostEnvironment, ModelLoadContext, IDataView)).
[assembly: LoadableClass(ColumnCopyingTransformer.Summary, typeof(IDataTransform), typeof(ColumnCopyingTransformer), null, typeof(SignatureLoadDataTransform),
    ColumnCopyingTransformer.UserName, ColumnCopyingTransformer.LoaderSignature)]

// Registration for SignatureLoadModel: the transformer itself rebuilt from a saved model
// (see Create(IHostEnvironment, ModelLoadContext)).
[assembly: LoadableClass(ColumnCopyingTransformer.Summary, typeof(ColumnCopyingTransformer), null, typeof(SignatureLoadModel),
    ColumnCopyingTransformer.UserName, ColumnCopyingTransformer.LoaderSignature)]

// Registration for SignatureLoadRowMapper: an IRowMapper rebuilt from a saved model
// (see Create(IHostEnvironment, ModelLoadContext, DataViewSchema)).
[assembly: LoadableClass(typeof(IRowMapper), typeof(ColumnCopyingTransformer), null, typeof(SignatureLoadRowMapper),
    ColumnCopyingTransformer.UserName, ColumnCopyingTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// <see cref="IEstimator{TTransformer}"/> for the <see cref="ColumnCopyingTransformer"/>.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    ///
    /// ###  Estimator Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Does this estimator need to look at the data to train its parameters? | No |
    /// | Input column data type | Any |
    /// | Output column data type | The same as the data type in the input column |
    /// | Exportable to ONNX | Yes |
    ///
    /// The resulting [ColumnCopyingTransformer](xref:Microsoft.ML.Transforms.ColumnCopyingTransformer) creates a new column, named as specified in the output column name parameters, and
    /// copies the data from the input column to this new column.
    /// Check the See Also section for links to usage examples.
    /// ]]>
    /// </format>
    /// </remarks>
    /// <seealso cref="TransformExtensionsCatalog.CopyColumns(TransformsCatalog, string, string)" />
    public sealed class ColumnCopyingEstimator : TrivialEstimator<ColumnCopyingTransformer>
    {
        // Convenience overload for a single (output, input) pair; forwards to the multi-pair constructor.
        [BestFriend]
        internal ColumnCopyingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName)
            : this(env, new[] { (outputColumnName, inputColumnName) })
        {
        }

        // Builds the estimator together with the transformer it wraps. The environment is
        // null-checked (and a host registered) before the transformer is constructed.
        [BestFriend]
        internal ColumnCopyingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
            : base(
                Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnCopyingEstimator)),
                new ColumnCopyingTransformer(env, columns))
        {
        }

        /// <summary>
        /// Computes the <see cref="SchemaShape"/> produced by the transformer: every input column,
        /// plus one column per configured pair that mirrors the kind, item type, key-ness and
        /// annotations of its source column. An output name that matches an existing column replaces it.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data; must not be null.</param>
        /// <returns>The shape of the data after the copy is applied.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Seed with the input columns, keyed by name, so that copies can add or overwrite entries.
            var columnsByName = inputSchema.ToDictionary(column => column.Name);

            foreach (var pair in Transformer.Columns)
            {
                // Sources are always resolved against the original input schema, not against
                // columns added by earlier pairs.
                bool found = inputSchema.TryFindColumn(pair.inputColumnName, out var source);
                if (!found)
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", pair.inputColumnName);

                columnsByName[pair.outputColumnName] = new SchemaShape.Column(
                    pair.outputColumnName,
                    source.Kind,
                    source.ItemType,
                    source.IsKey,
                    source.Annotations);
            }

            return new SchemaShape(columnsByName.Values);
        }
    }
 
    /// <summary>
    /// <see cref="ITransformer"/> resulting from fitting a <see cref="ColumnCopyingEstimator"/>.
    /// </summary>
    public sealed class ColumnCopyingTransformer : OneToOneTransformerBase
    {
        // Identifiers used by the LoadableClass registrations and by the model version info.
        [BestFriend]
        internal const string LoaderSignature = "CopyTransform";
        internal const string Summary = "Copy a source column to a new column.";
        internal const string UserName = "Copy Columns Transform";
        internal const string ShortName = "Copy";

        /// <summary>
        /// Names of output and input column pairs on which the transformation is applied.
        /// </summary>
        internal IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();

        // Version information stamped on saved models and verified when loading them.
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "COPYCOLT",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(ColumnCopyingTransformer).Assembly.FullName);
        }

        // Creates a transformer that copies each inputColumnName to the paired outputColumnName.
        // The environment is null-checked before a host is registered on it.
        internal ColumnCopyingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnCopyingTransformer)), columns)
        {
        }

        // Command-line representation of a single copy (the "name:src" form mentioned in Options).
        internal sealed class Column : OneToOneColumn
        {
            // Parses the textual form of a column; returns null when the text cannot be parsed.
            internal static Column Parse(string str)
            {
                Contracts.AssertNonEmpty(str);

                var res = new Column();
                if (res.TryParse(str))
                    return res;
                return null;
            }

            // Writes the textual form of this column into sb; returns false when that is not possible.
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                return TryUnparseCore(sb);
            }
        }

        // Arguments accepted by the SignatureDataTransform factory.
        internal sealed class Options : TransformInputBase
        {
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)",
                Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;
        }

        // Factory method corresponding to SignatureDataTransform.
        internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            // Fail with an error naming the offending argument instead of the generic
            // "source" ArgumentNullException that Enumerable.Select would raise.
            env.CheckValue(options.Columns, nameof(options.Columns));

            var transformer = new ColumnCopyingTransformer(env, options.Columns.Select(x => (x.Name, x.Source)).ToArray());
            return transformer.MakeDataTransform(input);
        }

        // Factory method for SignatureLoadModel.
        private static ColumnCopyingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: number of added columns
            // for each added column
            //   string: output column name
            //   string: input column name

            var length = ctx.Reader.ReadInt32();
            // A negative count can only come from a corrupt or foreign stream; report it as a
            // decode failure rather than letting the array allocation below throw.
            env.CheckDecode(length >= 0);

            var columns = new (string outputColumnName, string inputColumnName)[length];
            for (int i = 0; i < length; i++)
            {
                columns[i].outputColumnName = ctx.LoadNonEmptyString();
                columns[i].inputColumnName = ctx.LoadNonEmptyString();
            }
            return new ColumnCopyingTransformer(env, columns);
        }

        // Factory method for SignatureLoadDataTransform.
        private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
            => Create(env, ctx).MakeDataTransform(input);

        // Factory method for SignatureLoadRowMapper.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
            => Create(env, ctx).MakeRowMapper(inputSchema);

        // Stamps the version info and writes the column pairs through the base-class helper.
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            ctx.SetVersionInfo(GetVersionInfo());
            SaveColumns(ctx);
        }

        // Creates the row mapper that performs the copy against the given input schema.
        private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema)
            => new Mapper(this, inputSchema, ColumnPairs);

        // Row mapper: each output column exposes the getter of its source column unchanged.
        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
        {
            // Cached method info for the generic MakeGetter<T>; the <int> instantiation only
            // identifies the method, the real type argument is supplied per call by MarshalInvoke.
            private static readonly FuncStaticMethodInfo1<DataViewRow, int, Delegate> _makeGetterMethodInfo
                = new FuncStaticMethodInfo1<DataViewRow, int, Delegate>(MakeGetter<int>);

            private readonly DataViewSchema _schema;
            private readonly (string outputColumnName, string inputColumnName)[] _columns;

            public bool CanSaveOnnx(OnnxContext ctx) => true;

            internal Mapper(ColumnCopyingTransformer parent, DataViewSchema inputSchema, (string outputColumnName, string inputColumnName)[] columns)
                : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
            {
                _schema = inputSchema;
                _columns = columns;
            }

            // Returns a getter for output column iinfo, which is simply the typed getter of the
            // corresponding input column on the supplied row. No disposer is needed.
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _columns.Length);
                disposer = null;

                // Do not ignore a failed lookup: the out value would default to 0 and the
                // getter would silently read whatever column happens to be first.
                var inputColumnName = _columns[iinfo].inputColumnName;
                if (!input.Schema.TryGetColumnIndex(inputColumnName, out int colIndex))
                    throw Host.ExceptSchemaMismatch(nameof(input), "input", inputColumnName);

                var type = input.Schema[colIndex].Type;
                return Utils.MarshalInvoke(_makeGetterMethodInfo, type.RawType, input, colIndex);
            }

            // Typed helper invoked through MarshalInvoke with T set to the column's raw type.
            private static Delegate MakeGetter<T>(DataViewRow row, int index)
                => row.GetGetter<T>(row.Schema[index]);

            // Each output column carries the name from the pair and the type and annotations of its source.
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                var result = new DataViewSchema.DetachedColumn[_columns.Length];
                for (int i = 0; i < _columns.Length; i++)
                {
                    var srcCol = _schema[_columns[i].inputColumnName];
                    result[i] = new DataViewSchema.DetachedColumn(_columns[i].outputColumnName, srcCol.Type, srcCol.Annotations);
                }
                return result;
            }

            // Exports every copy as an ONNX "Identity" node from the source variable to a new
            // intermediate variable. Pairs whose input column is unknown to the context are skipped.
            public void SaveAsOnnx(OnnxContext ctx)
            {
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                const string opType = "Identity";

                foreach (var column in _columns)
                {
                    if (!ctx.ContainsColumn(column.inputColumnName))
                        continue;

                    var srcVariableName = ctx.GetVariableName(column.inputColumnName);
                    // Look the source column up by name, the same way GetOutputColumnsCore does,
                    // instead of discarding the result of TryGetColumnIndex.
                    var srcType = _schema[column.inputColumnName].Type;
                    var dstVariableName = ctx.AddIntermediateVariable(srcType, column.outputColumnName);
                    ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
                }
            }
        }
    }
}