File: DataView\CompositeRowToRowMapper.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// A row-to-row mapper that is the result of a chained application of multiple mappers.
    /// </summary>
    /// <remarks>
    /// The mappers in <see cref="InnerMappers"/> are applied in order: the row produced by mapper <c>i</c>
    /// is the input of mapper <c>i + 1</c>. With no inner mappers the composite is a pass-through.
    /// Disposing this object disposes every inner mapper that implements <see cref="IDisposable"/>.
    /// </remarks>
    [BestFriend]
    internal sealed class CompositeRowToRowMapper : IRowToRowMapper, IDisposable
    {
        /// <summary>
        /// The wrapped mappers, in application order. Never <c>null</c>; empty when the composite is a no-op.
        /// </summary>
        [BestFriend]
        internal IRowToRowMapper[] InnerMappers { get; }

        /// <summary>
        /// The schema that rows passed to <c>GetRow</c> must have.
        /// </summary>
        public DataViewSchema InputSchema { get; }

        /// <summary>
        /// The schema of rows produced by <c>GetRow</c>: the output schema of the last inner mapper,
        /// or <see cref="InputSchema"/> when there are no inner mappers.
        /// </summary>
        public DataViewSchema OutputSchema { get; }

        /// <summary>
        /// Out of a series of mappers, construct a seemingly unitary mapper that is able to apply them in sequence.
        /// </summary>
        /// <param name="inputSchema">The input schema. Must not be <c>null</c>.</param>
        /// <param name="mappers">The sequence of mappers to wrap. An empty or <c>null</c> argument
        /// is legal, and counts as being a no-op application.</param>
        public CompositeRowToRowMapper(DataViewSchema inputSchema, IRowToRowMapper[] mappers)
        {
            Contracts.CheckValue(inputSchema, nameof(inputSchema));
            Contracts.CheckValueOrNull(mappers);
            InnerMappers = Utils.Size(mappers) > 0 ? mappers : Array.Empty<IRowToRowMapper>();
            InputSchema = inputSchema;
            OutputSchema = InnerMappers.Length > 0 ? InnerMappers[InnerMappers.Length - 1].OutputSchema : inputSchema;
        }

        /// <summary>
        /// Given a set of columns, return the input columns that are needed to generate those output columns.
        /// </summary>
        /// <param name="columnsNeeded">Columns of <see cref="OutputSchema"/> the caller wants computed.</param>
        /// <returns>The columns of <see cref="InputSchema"/> required to produce them. With no inner
        /// mappers, <paramref name="columnsNeeded"/> is returned unchanged.</returns>
        IEnumerable<DataViewSchema.Column> IRowToRowMapper.GetDependencies(IEnumerable<DataViewSchema.Column> columnsNeeded)
        {
            // Walk the chain backwards: what the last mapper needs from its input is exactly what the
            // previous mapper must produce, and so on down to the first mapper.
            for (int i = InnerMappers.Length - 1; i >= 0; --i)
                columnsNeeded = InnerMappers[i].GetDependencies(columnsNeeded);

            return columnsNeeded;
        }

        /// <summary>
        /// Runs <paramref name="input"/> through every inner mapper, activating only what is needed to
        /// produce <paramref name="activeColumns"/> of <see cref="OutputSchema"/>.
        /// </summary>
        /// <param name="input">A row whose schema is <see cref="InputSchema"/>.</param>
        /// <param name="activeColumns">The output columns that should be active in the returned row.</param>
        /// <returns>The mapped row. With no inner mappers, this is <paramref name="input"/> itself, or a
        /// wrapper over it that reports only <paramref name="activeColumns"/> as active.</returns>
        DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
        {
            Contracts.CheckValue(input, nameof(input));
            Contracts.CheckValue(activeColumns, nameof(activeColumns));
            Contracts.CheckParam(input.Schema == InputSchema, nameof(input), "Schema did not match original schema");

            // A hash set keeps the per-column membership tests below O(1) rather than a linear scan each.
            var activeIndices = new HashSet<int>(activeColumns.Select(c => c.Index));
            if (InnerMappers.Length == 0)
                return GetPassThroughRow(input, activeIndices);

            // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
            // what we need from them. The last one will just have the requested outputs, but the rest will need
            // to be computed based on the dependencies of the next one in the chain.
            var deps = new IEnumerable<DataViewSchema.Column>[InnerMappers.Length];
            // Materialize once so repeated enumeration by the inner mappers does not redo the filter.
            deps[deps.Length - 1] = OutputSchema.Where(c => activeIndices.Contains(c.Index)).ToArray();
            for (int i = deps.Length - 1; i >= 1; --i)
                deps[i - 1] = InnerMappers[i].GetDependencies(deps[i]);

            DataViewRow result = input;
            for (int i = 0; i < InnerMappers.Length; ++i)
                result = InnerMappers[i].GetRow(result, deps[i]);

            return result;
        }

        /// <summary>
        /// Handles the no-mapper case: validates that every requested column is active in <paramref name="input"/>,
        /// and hides columns that are active in <paramref name="input"/> but were not requested.
        /// </summary>
        private static DataViewRow GetPassThroughRow(DataViewRow input, HashSet<int> activeIndices)
        {
            bool differentActive = false;
            for (int c = 0; c < input.Schema.Count; ++c)
            {
                var column = input.Schema[c];
                bool wantsActive = activeIndices.Contains(c);
                bool isActive = input.IsColumnActive(column);
                differentActive |= wantsActive != isActive;

                if (wantsActive && !isActive)
                    throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{column.Name}' active but it was not.");
            }

            // Past the loop, any difference means extra active columns in the input; wrap the row so that
            // IsColumnActive reflects exactly the requested set, as the GetRow contract asks.
            return differentActive ? new SubsetActive(input, activeIndices.Contains) : input;
        }

        /// <summary>
        /// A view over another row that reports a caller-chosen subset of columns as active.
        /// All other members delegate to the wrapped row.
        /// </summary>
        private sealed class SubsetActive : DataViewRow
        {
            private readonly DataViewRow _row;
            private readonly Func<int, bool> _pred;

            /// <param name="row">The wrapped row. It is not owned by this wrapper.</param>
            /// <param name="pred">Given a column index, returns whether that column is reported as active.</param>
            public SubsetActive(DataViewRow row, Func<int, bool> pred)
            {
                Contracts.AssertValue(row);
                Contracts.AssertValue(pred);
                _row = row;
                _pred = pred;
            }

            public override DataViewSchema Schema => _row.Schema;
            public override long Position => _row.Position;
            public override long Batch => _row.Batch;

            /// <summary>
            /// Returns the wrapped row's value getter for <paramref name="column"/>. This wrapper adds no
            /// activity or type check of its own; any such checks are those of the wrapped row.
            /// </summary>
            /// <typeparam name="TValue">The column's content type.</typeparam>
            /// <param name="column">The column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => _row.GetGetter<TValue>(column);

            /// <summary>
            /// Returns the wrapped row's id getter.
            /// </summary>
            public override ValueGetter<DataViewRowId> GetIdGetter() => _row.GetIdGetter();

            /// <summary>
            /// Returns whether the given column is active in this row, as decided by the predicate
            /// supplied at construction.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column) => _pred(column.Index);
        }

        #region IDisposable Support
        private bool _disposed;

        /// <summary>
        /// Disposes each inner mapper that implements <see cref="IDisposable"/>. Subsequent calls are no-ops.
        /// </summary>
        public void Dispose()
        {
            if (_disposed)
                return;

            foreach (var mapper in InnerMappers)
                (mapper as IDisposable)?.Dispose();

            _disposed = true;
        }
        #endregion
    }
}