File: Transforms\RowToRowTransformerBase.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Base class for transformer which produce new columns, but doesn't affect existing ones.
    /// </summary>
    public abstract class RowToRowTransformerBase : ITransformer
    {
        [BestFriend]
        private protected readonly IHost Host;
 
        [BestFriend]
        private protected RowToRowTransformerBase(IHost host)
        {
            Contracts.AssertValue(host);
            Host = host;
        }
 
        void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
 
        private protected abstract void SaveModel(ModelSaveContext ctx);
 
        bool ITransformer.IsRowToRowMapper => true;
 
        IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema), MakeRowMapper);
        }
 
        [BestFriend]
        private protected abstract IRowMapper MakeRowMapper(DataViewSchema schema);
 
        public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            var mapper = MakeRowMapper(inputSchema);
            return RowToRowMapperTransform.GetOutputSchema(inputSchema, mapper);
        }
 
        public IDataView Transform(IDataView input) => MakeDataTransform(input);
 
        [BestFriend]
        private protected RowToRowMapperTransform MakeDataTransform(IDataView input)
        {
            Host.CheckValue(input, nameof(input));
            return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema), MakeRowMapper);
        }
 
        [BestFriend]
        private protected abstract class MapperBase : IRowMapper
        {
            protected readonly IHost Host;
            protected readonly DataViewSchema InputSchema;
            protected readonly Lazy<DataViewSchema.DetachedColumn[]> OutputColumns;
            private readonly RowToRowTransformerBase _parent;
 
            protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
            {
                Contracts.CheckValue(host, nameof(host));
                Contracts.CheckValue(inputSchema, nameof(inputSchema));
                Host = host;
                InputSchema = inputSchema;
                _parent = parent;
                OutputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);
            }
 
            protected abstract DataViewSchema.DetachedColumn[] GetOutputColumnsCore();
 
            DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => OutputColumns.Value;
 
            public virtual Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
            {
                // REVIEW: it used to be that the mapper's input schema in the constructor was required to be reference-equal to the schema
                // of the input row.
                // It still has to be the same schema, but because we may make a transition from lazy to eager schema, the reference-equality
                // is no longer always possible. So, we relax the assert as below.
                Contracts.Assert(input.Schema == InputSchema);
                int n = OutputColumns.Value.Length;
                var result = new Delegate[n];
                var disposers = new Action[n];
                for (int i = 0; i < n; i++)
                {
                    if (!activeOutput(i))
                        continue;
                    result[i] = MakeGetter(input, i, activeOutput, out disposers[i]);
                }
                if (disposers.Any(x => x != null))
                {
                    disposer = () =>
                    {
                        foreach (var act in disposers)
                            act();
                    };
                }
                else
                    disposer = null;
                return result;
            }
 
            protected abstract Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer);
 
            Func<int, bool> IRowMapper.GetDependencies(Func<int, bool> activeOutput)
                => GetDependenciesCore(activeOutput);
 
            [BestFriend]
            private protected abstract Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput);
 
            void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
 
            private protected abstract void SaveModel(ModelSaveContext ctx);
 
            public ITransformer GetTransformer() => _parent;
        }
    }
}