// File: DataView\RowToRowMapperTransform.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
 
// Registers RowToRowMapperTransform as loadable through SignatureLoadDataTransform (see its static Create method),
// under the loader signature RowToRowMapperTransform.LoaderSignature, so saved models can be deserialized.
[assembly: LoadableClass(typeof(RowToRowMapperTransform), null, typeof(SignatureLoadDataTransform),
    "", RowToRowMapperTransform.LoaderSignature)]
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// This interface is used to create a <see cref="RowToRowMapperTransform"/>.
    /// Implementations should be given a <see cref="DataViewSchema"/> in their constructor, and should have a
    /// ctor or Create method with <see cref="SignatureLoadRowMapper"/>, along with a corresponding
    /// <see cref="LoadableClassAttribute"/>.
    /// </summary>
    [BestFriend]
    internal interface IRowMapper : ICanSaveModel
    {
        /// <summary>
        /// Returns the input columns needed for the requested output columns.
        /// </summary>
        /// <param name="activeOutput">Predicate over the mapper's output column indices (positions in the array returned
        /// by <see cref="GetOutputColumns"/>); returns true for each output column that is requested.</param>
        /// <returns>A predicate over input schema column indices that is true for every input column required to
        /// compute the requested outputs.</returns>
        Func<int, bool> GetDependencies(Func<int, bool> activeOutput);

        /// <summary>
        /// Returns the getters for the output columns given an active set of output columns. The length of the getters
        /// array should be equal to the number of columns added by the IRowMapper. It should contain the getter for the
        /// i'th output column if activeOutput(i) is true, and null otherwise. If creating a <see cref="DataViewRow"/> or
        /// <see cref="DataViewRowCursor"/> out of this, the <paramref name="disposer"/> delegate (if non-null) should be called
        /// from the dispose of either of those instances.
        /// </summary>
        /// <param name="input">The input row (or cursor) the getters read from.</param>
        /// <param name="activeOutput">Predicate over the mapper's output column indices selecting which getters to create.</param>
        /// <param name="disposer">Receives an optional cleanup action to run when the consuming row/cursor is disposed.</param>
        /// <returns>One entry per added output column: a <c>ValueGetter&lt;T&gt;</c> when active, otherwise null.</returns>
        Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer);

        /// <summary>
        /// Returns information about the output columns, including their name, type and any metadata information.
        /// </summary>
        DataViewSchema.DetachedColumn[] GetOutputColumns();

        /// <summary>
        /// DO NOT USE IT!
        /// Purpose of this method is to enable legacy loading and unwrapping of RowToRowTransform.
        /// It should be removed as soon as we get rid of <see cref="TrainedWrapperEstimatorBase"/>
        /// Returns parent transformer which uses this mapper.
        /// </summary>
        ITransformer GetTransformer();
    }
    /// <summary>
    /// Loader signature that <see cref="IRowMapper"/> implementations expose (as a constructor or Create method) so they
    /// can be deserialized from a <see cref="ModelLoadContext"/> and bound to the given input <see cref="DataViewSchema"/>.
    /// </summary>
    [BestFriend]
    internal delegate void SignatureLoadRowMapper(ModelLoadContext ctx, DataViewSchema schema);
 
    /// <summary>
    /// This class is a transform that can add any number of output columns, that depend on any number of input columns.
    /// It does so with the help of an <see cref="IRowMapper"/>, that is given a schema in its constructor, and has methods
    /// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns.
    /// </summary>
    [BestFriend]
    internal sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper,
        ITransformCanSaveOnnx, ITransformCanSavePfa, ITransformTemplate
    {
        private readonly IRowMapper _mapper;

        // Combines the source schema with the mapper's output columns and maps output indices back to
        // either a source column index or an added-column index.
        private readonly ColumnBindings _bindings;

        // If this is not null, the transform is re-appliable without save/load.
        private readonly Func<DataViewSchema, IRowMapper> _mapperFactory;

        public const string RegistrationName = "RowToRowMapperTransform";
        public const string LoaderSignature = "RowToRowMapper";

        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "ROW MPPR",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(RowToRowMapperTransform).Assembly.FullName);
        }

        /// <summary>
        /// The output schema, built from the source schema and the columns returned by the mapper's
        /// <see cref="IRowMapper.GetOutputColumns"/>.
        /// </summary>
        public override DataViewSchema OutputSchema => _bindings.Schema;

        // Export is only performed by SaveAsOnnx/SaveAsPfa below when the mapper implements ISaveAsOnnx/ISaveAsPfa,
        // so support is reported against those same interfaces; otherwise a mapper could be reported as exportable
        // and then be silently skipped during export.
        bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _mapper is ISaveAsOnnx onnxMapper && onnxMapper.CanSaveOnnx(ctx);

        bool ICanSavePfa.CanSavePfa => _mapper is ISaveAsPfa pfaMapper && pfaMapper.CanSavePfa;

        /// <summary>
        /// Creates the transform over <paramref name="input"/>, using <paramref name="mapper"/> to compute the added columns.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The source data view.</param>
        /// <param name="mapper">The mapper bound to the schema of <paramref name="input"/>. Must not be null.</param>
        /// <param name="mapperFactory">Optional factory used when re-applying this transform to new data, to create a mapper
        /// for the new schema without a save/load round trip. May be null.</param>
        public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func<DataViewSchema, IRowMapper> mapperFactory)
            : base(env, RegistrationName, input)
        {
            Contracts.CheckValue(mapper, nameof(mapper));
            Contracts.CheckValueOrNull(mapperFactory);
            _mapper = mapper;
            _mapperFactory = mapperFactory;
            _bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns());
        }

        /// <summary>
        /// Computes the schema this transform would produce for <paramref name="inputSchema"/> and <paramref name="mapper"/>
        /// without constructing the transform.
        /// </summary>
        public static DataViewSchema GetOutputSchema(DataViewSchema inputSchema, IRowMapper mapper)
        {
            Contracts.CheckValue(inputSchema, nameof(inputSchema));
            Contracts.CheckValue(mapper, nameof(mapper));
            return new ColumnBindings(inputSchema, mapper.GetOutputColumns()).Schema;
        }

        // Deserialization constructor: loads the mapper from the "Mapper" sub-model, bound to the input schema.
        private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            // *** Binary format ***
            // _mapper

            ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
            _bindings = new ColumnBindings(input.Schema, _mapper.GetOutputColumns());
        }

        /// <summary>
        /// Loader entry point (see <see cref="SignatureLoadDataTransform"/>): validates the model version and rebuilds the
        /// transform over <paramref name="input"/>.
        /// </summary>
        public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register(RegistrationName);
            h.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            h.CheckValue(input, nameof(input));
            return h.Apply("Loading Model", ch => new RowToRowMapperTransform(h, ctx, input));
        }

        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // _mapper

            ctx.SaveModel(_mapper, "Mapper");
        }

        /// <summary>
        /// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount),
        /// and the needed active input columns, given a predicate for the needed active output columns.
        /// </summary>
        /// <param name="predicate">Predicate over output schema column indices.</param>
        /// <param name="inputColumns">Receives the input columns that must be active: those requested directly plus
        /// those the mapper depends on for the requested added columns.</param>
        private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<DataViewSchema.Column> inputColumns)
        {
            int n = _bindings.Schema.Count;
            var active = Utils.BuildArray(n, predicate);
            Contracts.Assert(active.Length == n);

            var activeInput = _bindings.GetActiveInput(predicate);
            Contracts.Assert(activeInput.Length == _bindings.InputSchema.Count);

            // Get a predicate that determines which outputs are active.
            var predicateOut = GetActiveOutputColumns(active);

            // Now map those to active input columns.
            var predicateIn = _mapper.GetDependencies(predicateOut);

            // Combine the two sets of input columns. Materialize once so consumers that enumerate the result
            // repeatedly do not re-run the mapper's dependency predicate each time.
            inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index] || predicateIn(col.Index)).ToArray();

            return active;
        }

        /// <summary>
        /// Converts the output-schema activity array into a predicate over the mapper's added-column indices,
        /// as expected by <see cref="IRowMapper.GetDependencies"/> and <see cref="IRowMapper.CreateGetters"/>.
        /// </summary>
        private Func<int, bool> GetActiveOutputColumns(bool[] active)
        {
            Contracts.AssertValue(active);
            Contracts.Assert(active.Length == _bindings.Schema.Count);

            return
                col =>
                {
                    Contracts.Assert(0 <= col && col < _bindings.AddedColumnIndices.Count);
                    return 0 <= col && col < _bindings.AddedColumnIndices.Count && active[_bindings.AddedColumnIndices[col]];
                };
        }

        // Requests parallel cursors whenever at least one added column is needed; otherwise defers to the default.
        protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
        {
            Host.AssertValue(predicate, "predicate");
            if (_bindings.AddedColumnIndices.Any(predicate))
                return true;
            return null;
        }

        protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            var active = GetActive(predicate, out IEnumerable<DataViewSchema.Column> inputCols);

            return new Cursor(Host, Source.GetRowCursor(inputCols, rand), this, active);
        }

        /// <summary>
        /// Creates up to <paramref name="n"/> cursors. If the source returns a single cursor while added columns are
        /// requested, that cursor is split so the mapper work can run in parallel.
        /// </summary>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.CheckValueOrNull(rand);

            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            var active = GetActive(predicate, out IEnumerable<DataViewSchema.Column> inputCols);

            var inputs = Source.GetRowCursorSet(inputCols, n, rand);
            Host.AssertNonEmpty(inputs);

            if (inputs.Length == 1 && n > 1 && _bindings.AddedColumnIndices.Any(predicate))
                inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n);
            Host.AssertNonEmpty(inputs);

            var cursors = new DataViewRowCursor[inputs.Length];
            for (int i = 0; i < inputs.Length; i++)
                cursors[i] = new Cursor(Host, inputs[i], this, active);
            return cursors;
        }

        void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            if (_mapper is ISaveAsOnnx onnx)
            {
                Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
                onnx.SaveAsOnnx(ctx);
            }
        }

        void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            if (_mapper is ISaveAsPfa pfa)
            {
                Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA.");
                pfa.SaveAsPfa(ctx);
            }
        }

        /// <summary>
        /// Given a set of output columns, return the input columns that are needed to generate those output columns.
        /// </summary>
        IEnumerable<DataViewSchema.Column> IRowToRowMapper.GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns)
        {
            var predicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, OutputSchema);
            GetActive(predicate, out var inputColumns);
            return inputColumns;
        }

        /// <summary>The schema of the source data view this transform is bound to.</summary>
        public DataViewSchema InputSchema => Source.Schema;

        /// <summary>
        /// Wraps <paramref name="input"/> in a row exposing the output schema, with getters created for the
        /// added columns selected by <paramref name="activeColumns"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">An active column's index is outside the output schema,
        /// or the input row's schema is not the bound source schema.</exception>
        DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
        {
            Host.CheckValue(input, nameof(input));
            Host.CheckValue(activeColumns, nameof(activeColumns));
            Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");

            using (var ch = Host.Start("GetEntireRow"))
            {
                var activeArr = new bool[OutputSchema.Count];
                foreach (var column in activeColumns)
                {
                    // Validated in release builds too (a debug-only assert would let this surface as an
                    // IndexOutOfRangeException). The message is only built on failure, and lists the column names.
                    if (column.Index < 0 || column.Index >= activeArr.Length)
                        throw Host.Except($"The columns {string.Join(", ", activeColumns.Select(c => c.Name))} are not suitable for the OutputSchema.");
                    activeArr[column.Index] = true;
                }
                var pred = GetActiveOutputColumns(activeArr);
                var getters = _mapper.CreateGetters(input, pred, out Action disp);
                return new RowImpl(input, this, OutputSchema, getters, disp);
            }
        }

        /// <summary>
        /// Re-applies this transform to <paramref name="newSource"/>: directly through the mapper factory when one was
        /// supplied, otherwise by serializing this transform to memory and loading it back over the new source.
        /// </summary>
        IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
        {
            Contracts.CheckValue(env, nameof(env));

            Contracts.CheckValue(newSource, nameof(newSource));
            if (_mapperFactory != null)
            {
                var newMapper = _mapperFactory(newSource.Schema);
                return new RowToRowMapperTransform(env.Register(nameof(RowToRowMapperTransform)), newSource, newMapper, _mapperFactory);
            }
            // Revert to serialization. This was how it worked in all the cases, now it's only when we can't re-create the mapper.
            using (var stream = new MemoryStream())
            {
                using (var rep = RepositoryWriter.CreateNew(stream, env))
                {
                    ModelSaveContext.SaveModel(rep, this, "model");
                    rep.Commit();
                }

                stream.Position = 0;
                using (var rep = RepositoryReader.Open(stream, env))
                {
                    IDataTransform newData;
                    ModelLoadContext.LoadModel<IDataTransform, SignatureLoadDataTransform>(env,
                        out newData, rep, "model", newSource);
                    return newData;
                }
            }
        }

        /// <summary>
        /// Row returned by <see cref="IRowToRowMapper.GetRow"/>: source columns are forwarded to the wrapped input row,
        /// added columns are served from the getters produced by the mapper.
        /// </summary>
        private sealed class RowImpl : WrappingRow
        {
            private readonly Delegate[] _getters;
            private readonly RowToRowMapperTransform _parent;
            private readonly Action _disposer;

            public override DataViewSchema Schema { get; }

            public RowImpl(DataViewRow input, RowToRowMapperTransform parent, DataViewSchema schema, Delegate[] getters, Action disposer)
                : base(input)
            {
                _parent = parent;
                Schema = schema;
                _getters = getters;
                _disposer = disposer;
            }

            protected override void DisposeCore(bool disposing)
            {
                if (disposing)
                    _disposer?.Invoke();
            }

            /// <summary>
            /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
            /// This throws if the column is not active in this row, or if the type
            /// <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                bool isSrc;
                int index = _parent._bindings.MapColumnIndex(out isSrc, column.Index);
                if (isSrc)
                    return Input.GetGetter<TValue>(Input.Schema[index]);

                var originFn = _getters[index];
                // The mapper leaves inactive added columns without a getter. Fail with a descriptive error here
                // rather than a NullReferenceException while building the type-mismatch message below.
                if (originFn == null)
                    throw Contracts.Except($"Column '{column.Name}' is not active in this row.");
                var fn = originFn as ValueGetter<TValue>;
                if (fn == null)
                    throw Contracts.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}', " +
                        $"expected type: '{originFn.GetType().GetGenericArguments().First()}'.");
                return fn;
            }

            /// <summary>
            /// Returns whether the given column is active in this row.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                bool isSrc;
                int index = _parent._bindings.MapColumnIndex(out isSrc, column.Index);
                // 'index' is an index into the INPUT schema when isSrc is true; the output schema can place added
                // columns in between, so the column handle must come from Input.Schema (as in GetGetter).
                if (isSrc)
                    return Input.IsColumnActive(Input.Schema[index]);
                return _getters[index] != null;
            }
        }

        /// <summary>
        /// Cursor that pairs an input cursor with mapper getters created for the active added columns.
        /// </summary>
        private sealed class Cursor : SynchronizedCursorBase
        {
            private readonly Delegate[] _getters;
            private readonly bool[] _active;
            private readonly ColumnBindings _bindings;
            private readonly Action _disposer;
            private bool _disposed;

            public override DataViewSchema Schema => _bindings.Schema;

            public Cursor(IChannelProvider provider, DataViewRowCursor input, RowToRowMapperTransform parent, bool[] active)
                : base(provider, input)
            {
                var pred = parent.GetActiveOutputColumns(active);
                _getters = parent._mapper.CreateGetters(input, pred, out _disposer);
                _active = active;
                _bindings = parent._bindings;
            }

            /// <summary>
            /// Returns whether the given column is active in this row.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                Ch.Check(column.Index < _bindings.Schema.Count);
                return _active[column.Index];
            }

            /// <summary>
            /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
            /// This throws if the column is not active in this row, or if the type
            /// <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Ch.Check(IsColumnActive(column));

                bool isSrc;
                int index = _bindings.MapColumnIndex(out isSrc, column.Index);
                if (isSrc)
                    return Input.GetGetter<TValue>(Input.Schema[index]);

                Ch.AssertValue(_getters);
                var getter = _getters[index];
                Ch.Assert(getter != null);
                var fn = getter as ValueGetter<TValue>;
                if (fn == null)
                    throw Contracts.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}', " +
                        $"expected type: '{getter.GetType().GetGenericArguments().First()}'.");
                return fn;
            }

            // Runs the mapper-supplied disposer at most once, then disposes the base cursor.
            protected override void Dispose(bool disposing)
            {
                if (_disposed)
                    return;
                if (disposing)
                    _disposer?.Invoke();
                _disposed = true;
                base.Dispose(disposing);
            }
        }

        /// <summary>
        /// Legacy hook: returns the transformer that owns the mapper (see <see cref="IRowMapper.GetTransformer"/>).
        /// </summary>
        internal ITransformer GetTransformer()
        {
            return _mapper.GetTransformer();
        }
    }
}