// File: OptionalColumnTransform.cs
// Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
[assembly: LoadableClass(OptionalColumnTransform.Summary, typeof(OptionalColumnTransform),
    typeof(OptionalColumnTransform.Arguments), typeof(SignatureDataTransform),
    OptionalColumnTransform.UserName, OptionalColumnTransform.LoaderSignature, OptionalColumnTransform.ShortName)]
 
[assembly: LoadableClass(typeof(OptionalColumnTransform), null, typeof(SignatureLoadDataTransform),
    OptionalColumnTransform.UserName, OptionalColumnTransform.LoaderSignature)]
 
[assembly: EntryPointModule(typeof(OptionalColumnTransform))]
 
namespace Microsoft.ML.Transforms
{
    /// <include file='doc.xml' path='doc/members/member[@name="OptionalColumnTransform"]/*' />
    [BestFriend]
    internal sealed class OptionalColumnTransform : RowToRowMapperTransformBase, ITransformCanSaveOnnx
    {
        /// <summary>
        /// Arguments for <see cref="OptionalColumnTransform"/>.
        /// </summary>
        public sealed class Arguments : TransformInputBase
        {
            // Names of the columns to treat as optional. When the transform is built from these arguments,
            // each name is looked up in the input schema and must be found there.
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public string[] Columns;
        }
 
        private sealed class Bindings : ColumnBindingsBase
        {
            // Type of each added column, indexed by iinfo.
            public readonly DataViewType[] ColumnTypes;
            // Index of each added column's source in the current input schema; -1 when a loaded transform
            // is bound to an input that does not contain a column of that name.
            public readonly int[] SrcCols;

            // Holds the annotations of the added columns; populated and sealed by SetMetadata.
            private readonly MetadataDispatcher _metadata;
            // Owning transform; its Host is used when annotation values are fetched.
            private readonly OptionalColumnTransform _parent;
            // The input schema of the original data view that contains the source columns. We need this
            // so that we can have the metadata even when we load this transform with new data that does not have
            // these columns.
            private readonly DataViewSchema _inputWithOptionalColumn;
            // Index of each added column's source within _inputWithOptionalColumn.
            private readonly int[] _srcColsWithOptionalColumn;
 
            /// <summary>
            /// Stores the per-column information and builds the annotations of the added columns.
            /// </summary>
            /// <param name="parent">Owning transform.</param>
            /// <param name="columnTypes">Type of each added column; one entry per added column.</param>
            /// <param name="srcCols">Source index of each added column in <paramref name="input"/>.</param>
            /// <param name="srcColsWithOptionalColumn">Source index of each added column in <paramref name="inputWithOptionalColumn"/>.</param>
            /// <param name="input">Schema of the current input.</param>
            /// <param name="inputWithOptionalColumn">Schema that contains the optional columns; annotations are taken from it.</param>
            /// <param name="user">Forwarded to the base bindings constructor.</param>
            /// <param name="names">Names of the added columns.</param>
            private Bindings(OptionalColumnTransform parent, DataViewType[] columnTypes, int[] srcCols,
                int[] srcColsWithOptionalColumn, DataViewSchema input, DataViewSchema inputWithOptionalColumn, bool user, string[] names)
                : base(input, user, names)
            {
                Contracts.AssertValue(parent);
                Contracts.AssertValue(inputWithOptionalColumn);
                Contracts.Assert(Utils.Size(columnTypes) == InfoCount);
                Contracts.Assert(Utils.Size(srcCols) == InfoCount);

                _parent = parent;
                ColumnTypes = columnTypes;
                SrcCols = srcCols;
                _inputWithOptionalColumn = inputWithOptionalColumn;
                _srcColsWithOptionalColumn = srcColsWithOptionalColumn;
                _metadata = new MetadataDispatcher(InfoCount);

                // Must come last: it reads _metadata, _inputWithOptionalColumn and _srcColsWithOptionalColumn.
                SetMetadata();
            }
 
            /// <summary>
            /// Builds bindings from user arguments. Every requested column must exist in <paramref name="input"/>;
            /// the input schema doubles as the schema that contains the optional columns.
            /// </summary>
            public static Bindings Create(Arguments args, DataViewSchema input, OptionalColumnTransform parent)
            {
                int count = args.Columns.Length;
                var names = new string[count];
                var types = new DataViewType[count];
                var indices = new int[count];

                for (int i = 0; i < count; i++)
                {
                    string name = args.Columns[i];
                    names[i] = name;

                    int srcIndex;
                    bool found = input.TryGetColumnIndex(name, out srcIndex);
                    Contracts.CheckUserArg(found, nameof(args.Columns));

                    types[i] = input[srcIndex].Type;
                    indices[i] = srcIndex;
                }

                // The same index array serves as both the current and the "with optional column" mapping.
                return new Bindings(parent, types, indices, indices, input, input, true, names);
            }
 
            /// <summary>
            /// Rebuilds bindings from a saved model. Columns missing from <paramref name="input"/> get a source
            /// index of -1; every column must be present in the schema stored with the model.
            /// </summary>
            /// <param name="env">Environment used for decode checks and for the binary loader/saver.</param>
            /// <param name="ctx">Model context positioned at the bindings.</param>
            /// <param name="input">Schema of the data the transform is being applied to.</param>
            /// <param name="parent">Owning transform.</param>
            public static Bindings Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema input, OptionalColumnTransform parent)
            {
                Contracts.AssertValue(ctx);
                Contracts.AssertValue(input);

                // *** Binary format ***
                // Schema of the data view containing the optional columns
                // int: number of added columns
                // for each added column
                //   int: id of output column name
                //   ColumnType: the type of the column

                byte[] buffer = null;
                if (!ctx.TryLoadBinaryStream("Schema.idv", r => buffer = r.ReadByteArray()))
                    throw env.ExceptDecode();
                var strm = new MemoryStream(buffer, writable: false);
                var loader = new BinaryLoader(env, new BinaryLoader.Arguments(), strm);

                int size = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(size > 0);

                var saver = new BinarySaver(env, new BinarySaver.Arguments());
                var names = new string[size];
                var columnTypes = new DataViewType[size];
                var srcCols = new int[size];
                var srcColsWithOptionalColumn = new int[size];
                for (int i = 0; i < size; i++)
                {
                    names[i] = ctx.LoadNonEmptyString();
                    columnTypes[i] = saver.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream);
                    // A null type means the description could not be decoded; fail here as a decode error
                    // instead of leaving a null entry in ColumnTypes for later code to trip over.
                    env.CheckDecode(columnTypes[i] != null);

                    int col;
                    bool success = input.TryGetColumnIndex(names[i], out col);
                    // -1 marks a column that the current input does not provide.
                    srcCols[i] = success ? col : -1;

                    success = loader.Schema.TryGetColumnIndex(names[i], out var colWithOptionalColumn);
                    env.CheckDecode(success);
                    srcColsWithOptionalColumn[i] = colWithOptionalColumn;
                }

                return new Bindings(parent, columnTypes, srcCols, srcColsWithOptionalColumn, input, loader.Schema, false, names);
            }
 
            /// <summary>
            /// Writes the bindings: the schema that contains the optional columns, followed by the name and
            /// type of every added column.
            /// </summary>
            /// <param name="env">Environment used to create the empty data view and the binary savers.</param>
            /// <param name="ctx">Model context to write to.</param>
            public void Save(IHostEnvironment env, ModelSaveContext ctx)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // Schema of the data view containing the optional columns
                // int: number of added columns
                // for each added column
                //   int: id of output column name
                //   ColumnType: the type of the column

                // Persist only the schema: a row-less view over the schema that holds the optional columns.
                var noRows = new EmptyDataView(env, _inputWithOptionalColumn);
                var saverArgs = new BinarySaver.Arguments();
                saverArgs.Silent = true;
                var saver = new BinarySaver(env, saverArgs);
                using (var strm = new MemoryStream())
                {
                    saver.SaveData(strm, noRows, _srcColsWithOptionalColumn);
                    ctx.SaveBinaryStream("Schema.idv", w => w.WriteByteArray(strm.ToArray()));
                }

                int size = InfoCount;
                ctx.Writer.Write(size);

                saver = new BinarySaver(env, new BinarySaver.Arguments());
                for (int i = 0; i < size; i++)
                {
                    string name = GetColumnNameCore(i);
                    ctx.SaveNonEmptyString(name);

                    // The loader reads one type description per column. If nothing could be written the
                    // stream would no longer match the format above, so fail instead of saving it.
                    int written;
                    if (!saver.TryWriteTypeDescription(ctx.Writer.BaseStream, ColumnTypes[i], out written))
                        throw env.Except($"Could not serialize the type of column '{name}'.");
                }
            }
 
            /// <summary>
            /// Registers, for every added column, the annotations of its source column in the schema that
            /// contains the optional columns, then seals the dispatcher.
            /// </summary>
            private void SetMetadata()
            {
                for (int iinfo = 0; iinfo < InfoCount; iinfo++)
                {
                    int srcIndex = _srcColsWithOptionalColumn[iinfo];

                    // Pass-through only: the builder is opened and disposed without adding anything.
                    using (_metadata.BuildMetadata(iinfo, _inputWithOptionalColumn, srcIndex))
                    {
                    }
                }

                _metadata.Seal();
            }
 
            /// <summary>
            /// Returns the type of the added column at <paramref name="iinfo"/>.
            /// </summary>
            protected override DataViewType GetColumnTypeCore(int iinfo)
            {
                Contracts.Assert(iinfo >= 0 && iinfo < InfoCount);
                DataViewType type = ColumnTypes[iinfo];
                return type;
            }
 
            /// <summary>
            /// Lists the annotation kinds and types of the added column at <paramref name="iinfo"/>.
            /// </summary>
            protected override IEnumerable<KeyValuePair<string, DataViewType>> GetAnnotationTypesCore(int iinfo)
                => _metadata.GetMetadataTypes(iinfo);
 
            /// <summary>
            /// Returns the type of annotation <paramref name="kind"/> on the added column at
            /// <paramref name="iinfo"/>, as reported by the metadata dispatcher.
            /// </summary>
            protected override DataViewType GetAnnotationTypeCore(string kind, int iinfo)
                => _metadata.GetMetadataTypeOrNull(kind, iinfo);
 
            /// <summary>
            /// Fetches the value of annotation <paramref name="kind"/> on the added column at
            /// <paramref name="iinfo"/> into <paramref name="value"/>.
            /// </summary>
            protected override void GetAnnotationCore<TValue>(string kind, int iinfo, ref TValue value)
                => _metadata.GetMetadata(_parent.Host, kind, iinfo, ref value);
 
            /// <summary>
            /// Maps a predicate over output columns to a predicate over input columns. Every source column
            /// that exists in the input is marked as needed, regardless of <paramref name="predicate"/>.
            /// </summary>
            public Func<int, bool> GetDependencies(Func<int, bool> predicate)
            {
                Contracts.AssertValue(predicate);

                var needed = GetActiveInput(predicate);
                Contracts.Assert(needed.Length == Input.Count);

                for (int i = 0; i < SrcCols.Length; i++)
                {
                    int src = SrcCols[i];
                    // Negative entries mark columns that are absent from the input.
                    if (src < 0)
                        continue;
                    needed[src] = true;
                }

                return col => col >= 0 && col < needed.Length && needed[col];
            }
 
            /// <summary>
            /// Returns the input columns required to produce the given output columns.
            /// </summary>
            /// <param name="dependingColumns">Output columns that will be requested.</param>
            public IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns)
            {
                Contracts.AssertValue(dependingColumns);

                Func<int, bool> outputPredicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, AsSchema);
                Func<int, bool> isNeeded = GetDependencies(outputPredicate);

                return Input.Where(column => isNeeded(column.Index));
            }
        }
 
        // Strings referenced by the LoadableClass registrations and the entry-point attribute of this transform.
        internal const string Summary = "If the source column does not exist after deserialization," +
            " create a column with the right type and default values.";
        internal const string UserName = "Optional Column Transform";
        // Also passed as the loader signature in GetVersionInfo.
        public const string LoaderSignature = "OptColTransform";
        internal const string ShortName = "optional";
 
        /// <summary>
        /// Version information written to, and checked against, serialized models of this transform.
        /// </summary>
        private static VersionInfo GetVersionInfo()
            => new VersionInfo(
                modelSignature: "OPTCOL T",
                //verWrittenCur: 0x00010001, // Initial
                verWrittenCur: 0x00010002, // Save the input schema, for metadata
                verReadableCur: 0x00010002,
                verWeCanReadBack: 0x00010002,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(OptionalColumnTransform).Assembly.FullName);
 
        // Method handles passed to Utils.MarshalInvoke so the generic helpers can be invoked with a type
        // argument chosen at run time. The <int> in each lambda is presumably just a placeholder used to
        // name the generic method - TODO confirm against FuncInstanceMethodInfo1.
        private static readonly FuncInstanceMethodInfo1<OptionalColumnTransform, DataViewRow, int, Delegate> _getSrcGetterMethodInfo
            = FuncInstanceMethodInfo1<OptionalColumnTransform, DataViewRow, int, Delegate>.Create(target => target.GetSrcGetter<int>);

        private static readonly FuncInstanceMethodInfo1<OptionalColumnTransform, Delegate> _makeGetterOneMethodInfo
            = FuncInstanceMethodInfo1<OptionalColumnTransform, Delegate>.Create(target => target.MakeGetterOne<int>);

        private static readonly FuncInstanceMethodInfo1<OptionalColumnTransform, int, Delegate> _makeGetterVecMethodInfo
            = FuncInstanceMethodInfo1<OptionalColumnTransform, int, Delegate>.Create(target => target.MakeGetterVec<int>);

        // Column bindings: names, types and source indices of the optional columns.
        private readonly Bindings _bindings;

        // Name under which the host is registered.
        private const string RegistrationName = "OptionalColumn";
 
        /// <summary>
        /// Creates an <see cref="OptionalColumnTransform"/> over <paramref name="input"/> for the given column names.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
        /// <param name="columns">Names of the columns to treat as optional.</param>
        public OptionalColumnTransform(IHostEnvironment env, IDataView input, params string[] columns)
            : this(env, new Arguments { Columns = columns }, input)
        {
        }
 
        /// <summary>
        /// Constructor matching SignatureDataTransform. Requires at least one column in
        /// <paramref name="args"/> and binds each of them against the source schema.
        /// </summary>
        public OptionalColumnTransform(IHostEnvironment env, Arguments args, IDataView input)
            : base(env, RegistrationName, input)
        {
            Host.CheckValue(args, nameof(args));

            int columnCount = Utils.Size(args.Columns);
            Host.CheckUserArg(columnCount > 0, nameof(args.Columns));

            _bindings = Bindings.Create(args, Source.Schema, this);
        }
 
        /// <summary>
        /// Deserializing constructor: reads the bindings from <paramref name="ctx"/> and binds them to the
        /// schema of <paramref name="input"/>.
        /// </summary>
        private OptionalColumnTransform(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            Host.AssertValue(ctx);

            // *** Binary format ***
            // bindings
            var sourceSchema = Source.Schema;
            _bindings = Bindings.Create(host, ctx, sourceSchema, this);
        }
 
        /// <summary>
        /// Loads the transform from <paramref name="ctx"/> and applies it to <paramref name="input"/>.
        /// The model version is verified before anything is read.
        /// </summary>
        public static OptionalColumnTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));

            var host = env.Register(RegistrationName);
            host.CheckValue(ctx, nameof(ctx));
            host.CheckValue(input, nameof(input));

            ctx.CheckAtModel(GetVersionInfo());

            return host.Apply("Loading Model", ch => new OptionalColumnTransform(host, ctx, input));
        }
 
        /// <summary>
        /// Writes the version information followed by the bindings.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();

            var versionInfo = GetVersionInfo();
            ctx.SetVersionInfo(versionInfo);

            // *** Binary format ***
            // bindings
            _bindings.Save(Host, ctx);
        }
 
        /// <summary>
        /// Output schema as exposed by the bindings.
        /// </summary>
        public override DataViewSchema OutputSchema
        {
            get { return _bindings.AsSchema; }
        }
 
        /// <summary>
        /// Always returns null, i.e. this transform expresses no preference about parallel cursors.
        /// </summary>
        protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
        {
            Host.AssertValue(predicate, nameof(predicate));

            bool? noPreference = null;
            return noPreference;
        }
 
        /// <summary>
        /// Opens a cursor over the source restricted to the input columns needed for
        /// <paramref name="columnsNeeded"/> and wraps it in a <see cref="Cursor"/>.
        /// </summary>
        /// <param name="columnsNeeded">Output columns the caller will read.</param>
        /// <param name="rand">Optional randomness source; forwarded to the source so a shuffle request is honored.</param>
        protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Host.AssertValueOrNull(rand);
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);

            var inputPred = _bindings.GetDependencies(predicate);
            var active = _bindings.GetActive(predicate);

            var inputCols = Source.Schema.Where(x => inputPred(x.Index));
            // Pass rand through: dropping it would silently turn a shuffled read into a sequential one.
            var input = Source.GetRowCursor(inputCols, rand);
            return new Cursor(Host, _bindings, input, active);
        }
 
        /// <summary>
        /// Opens up to <paramref name="n"/> cursors over the source, each wrapped in a <see cref="Cursor"/>.
        /// A single wrapped cursor is returned when <paramref name="n"/> is at most 1 or the source yields
        /// only one cursor.
        /// </summary>
        /// <param name="columnsNeeded">Output columns the caller will read.</param>
        /// <param name="n">Suggested number of cursors.</param>
        /// <param name="rand">Optional randomness source; forwarded to the source so a shuffle request is honored.</param>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.CheckValueOrNull(rand);

            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            var inputPred = _bindings.GetDependencies(predicate);
            var inputCols = Source.Schema.Where(x => inputPred(x.Index));

            var active = _bindings.GetActive(predicate);
            DataViewRowCursor input;

            if (n > 1 && ShouldUseParallelCursors(predicate) != false)
            {
                // Pass rand through in both branches: dropping it would silently disable shuffling.
                var inputs = Source.GetRowCursorSet(inputCols, n, rand);
                Host.AssertNonEmpty(inputs);

                if (inputs.Length != 1)
                {
                    var cursors = new DataViewRowCursor[inputs.Length];
                    for (int i = 0; i < inputs.Length; i++)
                        cursors[i] = new Cursor(Host, _bindings, inputs[i], active);
                    return cursors;
                }
                input = inputs[0];
            }
            else
                input = Source.GetRowCursor(inputCols, rand);

            return new DataViewRowCursor[] { new Cursor(Host, _bindings, input, active) };
        }
 
        /// <summary>
        /// Returns the input columns needed to produce <paramref name="dependingColumns"/>, as computed by the bindings.
        /// </summary>
        protected override IEnumerable<DataViewSchema.Column> GetDependenciesCore(IEnumerable<DataViewSchema.Column> dependingColumns)
        {
            return _bindings.GetDependencies(dependingColumns);
        }
 
        /// <summary>
        /// Delegates the mapping of output column <paramref name="col"/> to the bindings;
        /// <paramref name="isSrc"/> reports whether the result refers to a source column.
        /// </summary>
        protected override int MapColumnIndex(out bool isSrc, int col)
            => _bindings.MapColumnIndex(out isSrc, col);
 
        /// <summary>
        /// Creates one getter per active added column: a default-value getter when the source column is
        /// missing, otherwise a getter that reads the source column from <paramref name="input"/>.
        /// Entries for inactive columns stay null.
        /// </summary>
        /// <param name="input">Row to read existing source columns from.</param>
        /// <param name="activeColumns">Output columns for which getters are requested.</param>
        /// <param name="disposer">Always set to null; nothing needs disposing.</param>
        protected override Delegate[] CreateGetters(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns, out Action disposer)
        {
            var activeIndices = new HashSet<int>(activeColumns.Select(c => c.Index));
            Func<int, bool> activeInfos =
                iinfo =>
                {
                    int col = _bindings.MapIinfoToCol(iinfo);
                    return activeIndices.Contains(col);
                };

            var getters = new Delegate[_bindings.InfoCount];
            disposer = null;
            using (var ch = Host.Start("CreateGetters"))
            {
                for (int iinfo = 0; iinfo < _bindings.InfoCount; iinfo++)
                {
                    if (!activeInfos(iinfo))
                        continue;
                    if (_bindings.SrcCols[iinfo] < 0)
                        getters[iinfo] = MakeGetter(iinfo);
                    else
                    {
                        // Instantiate with the column's own raw type, not its item type: GetSrcGetter<T>
                        // asks the row for a ValueGetter<T>, and for a vector column that must be the
                        // vector's raw type rather than the type of its items.
                        getters[iinfo] = Utils.MarshalInvoke(_getSrcGetterMethodInfo, this, _bindings.ColumnTypes[iinfo].RawType, input, iinfo);
                    }
                }
                return getters;
            }
        }
 
        /// <summary>
        /// Returns the getter of the source column backing the added column at <paramref name="iinfo"/>.
        /// </summary>
        private ValueGetter<T> GetSrcGetter<T>(DataViewRow input, int iinfo)
        {
            int srcIndex = _bindings.SrcCols[iinfo];
            var srcColumn = input.Schema[srcIndex];
            return input.GetGetter<T>(srcColumn);
        }
 
        /// <summary>
        /// Builds a default-value getter for the added column at <paramref name="iinfo"/>, dispatching on
        /// whether its type is a vector type.
        /// </summary>
        private Delegate MakeGetter(int iinfo)
        {
            DataViewType type = _bindings.ColumnTypes[iinfo];
            var asVector = type as VectorDataViewType;

            if (asVector == null)
                return Utils.MarshalInvoke(_makeGetterOneMethodInfo, this, type.RawType);

            return Utils.MarshalInvoke(_makeGetterVecMethodInfo, this, asVector.ItemType.RawType, asVector.Size);
        }
 
        /// <summary>
        /// Returns a getter that always yields <c>default(T)</c>.
        /// </summary>
        private Delegate MakeGetterOne<T>()
        {
            ValueGetter<T> getter = (ref T value) => value = default(T);
            return getter;
        }
 
        /// <summary>
        /// Returns a getter that resizes the caller's buffer via
        /// <c>VBufferUtils.Resize(ref value, length, 0)</c>.
        /// </summary>
        private Delegate MakeGetterVec<T>(int length)
        {
            ValueGetter<VBuffer<T>> getter = (ref VBuffer<T> value) =>
            {
                VBufferUtils.Resize(ref value, length, 0);
            };
            return getter;
        }
 
        private sealed class Cursor : SynchronizedCursorBase
        {
            // Method handles passed to Utils.MarshalInvoke so the generic default-getter factories of this
            // cursor can be invoked with a type argument chosen at run time.
            private static readonly FuncInstanceMethodInfo1<Cursor, Delegate> _makeGetterOneMethodInfo
                = FuncInstanceMethodInfo1<Cursor, Delegate>.Create(target => target.MakeGetterOne<int>);

            private static readonly FuncInstanceMethodInfo1<Cursor, int, Delegate> _makeGetterVecMethodInfo
                = FuncInstanceMethodInfo1<Cursor, int, Delegate>.Create(target => target.MakeGetterVec<int>);

            private readonly Bindings _bindings;
            // Activity flag per output column; null means every column is active.
            private readonly bool[] _active;
            // Default-value getter per added column; null where the source column exists in the input.
            private readonly Delegate[] _getters;
 
            /// <summary>
            /// Wraps <paramref name="input"/> and prepares default-value getters for the added columns whose
            /// source column is missing from the input.
            /// </summary>
            /// <param name="provider">Channel provider for the base cursor.</param>
            /// <param name="bindings">Column bindings of the owning transform.</param>
            /// <param name="input">Source cursor.</param>
            /// <param name="active">Activity flag per output column, or null.</param>
            public Cursor(IChannelProvider provider, Bindings bindings, DataViewRowCursor input, bool[] active)
                : base(provider, input)
            {
                Ch.CheckValue(bindings, nameof(bindings));
                Ch.CheckValue(input, nameof(input));
                Ch.CheckParam(active == null || active.Length == bindings.ColumnCount, nameof(active));

                _bindings = bindings;
                _active = active;

                int count = bindings.InfoCount;
                var getters = new Delegate[count];
                for (int iinfo = 0; iinfo < count; iinfo++)
                {
                    // Columns with a real source are read from the input; their entry stays null.
                    if (bindings.SrcCols[iinfo] >= 0)
                        continue;
                    getters[iinfo] = MakeGetter(iinfo);
                }
                _getters = getters;
            }
 
            /// <summary>
            /// Schema of the rows produced by this cursor, as exposed by the bindings.
            /// </summary>
            public override DataViewSchema Schema
            {
                get { return _bindings.AsSchema; }
            }
 
            /// <summary>
            /// Tells whether <paramref name="column"/> is active in this cursor. All columns count as
            /// active when no activity array was supplied.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                Ch.Check(column.Index < _bindings.ColumnCount);

                if (_active == null)
                    return true;
                return _active[column.Index];
            }
 
            /// <summary>
            /// Returns a value getter for the given output column. Pass-through columns and added columns
            /// whose source exists are read from the input cursor; added columns whose source is missing
            /// use the default-value getter created in the constructor.
            /// This throws if the column is not active in this row, or if the type
            /// <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Ch.Check(IsColumnActive(column));

                bool isSrc;
                int index = _bindings.MapColumnIndex(out isSrc, column.Index);
                if (isSrc)
                    return Input.GetGetter<TValue>(Input.Schema[index]);

                var originFn = _getters[index];
                if (originFn == null)
                {
                    // No default getter means the source column exists in the input. SrcCols holds indices
                    // into the input schema, so the column must be looked up in Input.Schema, not in the
                    // output schema of the bindings.
                    return Input.GetGetter<TValue>(Input.Schema[_bindings.SrcCols[index]]);
                }

                var fn = originFn as ValueGetter<TValue>;
                if (fn == null)
                    throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}', " +
                            $"expected type: '{originFn.GetType().GetGenericArguments().First()}'.");
                return fn;
            }
 
            /// <summary>
            /// Builds a default-value getter for the added column at <paramref name="iinfo"/>, dispatching
            /// on whether its type is a vector type.
            /// </summary>
            private Delegate MakeGetter(int iinfo)
            {
                DataViewType type = _bindings.ColumnTypes[iinfo];
                var asVector = type as VectorDataViewType;

                if (asVector == null)
                    return Utils.MarshalInvoke(_makeGetterOneMethodInfo, this, type.RawType);

                return Utils.MarshalInvoke(_makeGetterVecMethodInfo, this, asVector.ItemType.RawType, asVector.Size);
            }
 
            /// <summary>
            /// Returns a getter that always yields <c>default(T)</c>.
            /// </summary>
            private Delegate MakeGetterOne<T>()
            {
                ValueGetter<T> getter = (ref T value) => value = default(T);
                return getter;
            }
 
            /// <summary>
            /// Returns a getter that resizes the caller's buffer via
            /// <c>VBufferUtils.Resize(ref value, length, 0)</c>.
            /// </summary>
            private Delegate MakeGetterVec<T>(int length)
            {
                ValueGetter<VBuffer<T>> getter = (ref VBuffer<T> value) =>
                {
                    VBufferUtils.Resize(ref value, length, 0);
                };
                return getter;
            }
        }
 
        /// <summary>
        /// For every optional column whose source is known to <paramref name="ctx"/> under its own name,
        /// registers an initializer via <see cref="SaveAsOnnxCore"/>. The column is removed from the
        /// context when its type is not supported.
        /// </summary>
        /// <param name="ctx">ONNX export context.</param>
        public void SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(((ICanSaveOnnx)this).CanSaveOnnx(ctx));

            for (int iinfo = 0; iinfo < _bindings.ColumnTypes.Length; ++iinfo)
            {
                // A negative index marks a source column that the current input lacks; indexing
                // Source.Schema with it would throw, and there is no source name to look up.
                int srcCol = _bindings.SrcCols[iinfo];
                if (srcCol < 0)
                    continue;

                string inputColumnName = Source.Schema[srcCol].Name;
                if (!ctx.ContainsColumn(inputColumnName))
                    continue;

                // If there is already a column of this name, don't add this column as an OptionalColumn/Initializer
                var srcVariableName = ctx.GetVariableName(inputColumnName);
                if (srcVariableName != inputColumnName)
                    continue;

                if (!SaveAsOnnxCore(ctx, srcVariableName, _bindings.ColumnTypes[iinfo]))
                    ctx.RemoveColumn(inputColumnName, true);
            }
        }
 
        /// <summary>
        /// Always reports that this transform can be exported to ONNX.
        /// </summary>
        public bool CanSaveOnnx(OnnxContext ctx)
        {
            return true;
        }
 
        /// <summary>
        /// Adds a default-valued initializer named <paramref name="srcVariableName"/> of shape
        /// <c>{ 1, size }</c> for the given column type.
        /// </summary>
        /// <param name="ctx">ONNX export context; the op-set version is checked against 9.</param>
        /// <param name="srcVariableName">Name of the initializer to add.</param>
        /// <param name="columnType">Type of the column to create defaults for.</param>
        /// <returns>True when an initializer was added; false when the raw type is not handled.</returns>
        private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, DataViewType columnType)
        {
            const int minimumOpSetVersion = 9;
            ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

            // NOTE(review): the raw type is taken from the column type itself, not from its item type.
            // For a vector column this is presumably the vector buffer type, which matches none of the
            // checks below - confirm whether the item type was intended.
            Type type = columnType.RawType;

            // Known-size vectors get one element per slot; everything else gets a single element.
            int size = 1;
            if (columnType is VectorDataViewType && columnType.IsKnownSizeVector())
                size = columnType.GetVectorSize();

            var shape = new long[] { 1, size };

            if (type == typeof(int) || type == typeof(short) || type == typeof(ushort) ||
                type == typeof(sbyte) || type == typeof(byte))
            {
                ctx.AddInitializer(new int[size], type, shape, srcVariableName, false);
                return true;
            }

            if (type == typeof(uint) || type == typeof(ulong))
            {
                ctx.AddInitializer(new ulong[size], type == typeof(ulong), shape, srcVariableName, false);
                return true;
            }

            if (type == typeof(bool))
            {
                ctx.AddInitializer(new bool[size], shape, srcVariableName, false);
                return true;
            }

            if (type == typeof(long))
            {
                ctx.AddInitializer(new long[size], shape, srcVariableName, false);
                return true;
            }

            if (type == typeof(float))
            {
                ctx.AddInitializer(new float[size], shape, srcVariableName, false);
                return true;
            }

            if (type == typeof(double))
            {
                ctx.AddInitializer(new double[size], shape, srcVariableName, false);
                return true;
            }

            if (type == typeof(string) || columnType is TextDataViewType)
            {
                // Text defaults are empty strings rather than nulls.
                string[] values = Enumerable.Repeat("", size).ToArray();
                ctx.AddInitializer(values, shape, srcVariableName, false);
                return true;
            }

            return false;
        }
 
        /// <summary>
        /// Entry point: applies the transform to <c>input.Data</c> and returns both the resulting data
        /// view and a transform model wrapping it.
        /// </summary>
        [TlcModule.EntryPoint(Desc = Summary,
            Name = "Transforms.OptionalColumnCreator",
            UserName = UserName,
            ShortName = ShortName)]
        public static CommonOutputs.TransformOutput MakeOptional(IHostEnvironment env, Arguments input)
        {
            var host = EntryPointUtils.CheckArgsAndCreateHost(env, "OptionalColumn", input);
            var transform = new OptionalColumnTransform(host, input, input.Data);

            var output = new CommonOutputs.TransformOutput();
            output.Model = new TransformModelImpl(host, transform, input.Data);
            output.OutputData = transform;
            return output;
        }
    }
}