File: Scorers\RowToRowScorerBase.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Base class for scoring rows independently. This assumes that all columns produced by the
    /// underlying <see cref="ISchemaBoundRowMapper"/> should be exposed, as well as zero or more
    /// "derived" columns.
    /// </summary>
    internal abstract class RowToRowScorerBase : RowToRowMapperTransformBase, IDataScorerTransform
    {
        // Cached method info for the generic GetGetterFromRow<T> overload. The non-generic
        // GetGetterFromRow(row, col) hands it to Utils.MarshalInvoke together with the column's
        // raw type in order to obtain a getter for that type.
        // NOTE(review): the <int> type argument looks like a placeholder that is substituted per call — confirm.
        private static readonly FuncStaticMethodInfo1<DataViewRow, int, Delegate> _getGetterFromRowMethodInfo
            = new FuncStaticMethodInfo1<DataViewRow, int, Delegate>(GetGetterFromRow<int>);
 
        /// <summary>
        /// Scorer bindings whose mapper is an <see cref="ISchemaBoundRowMapper"/>. The mapper is passed on
        /// to <see cref="ScorerBindingsBase"/> and is additionally kept, strongly typed, in <see cref="RowMapper"/>.
        /// </summary>
        [BestFriend]
        private protected abstract class BindingsBase : ScorerBindingsBase
        {
            /// <summary>
            /// The row mapper; the same instance that was given to the base class constructor.
            /// </summary>
            public readonly ISchemaBoundRowMapper RowMapper;

            protected BindingsBase(DataViewSchema schema, ISchemaBoundRowMapper mapper, string suffix, bool user, params string[] namesDerived)
                : base(schema, mapper, suffix, user, namesDerived)
                => RowMapper = mapper;
        }
 
        /// <summary>
        /// The bindable mapper wrapped by this scorer. It is either supplied by the caller or loaded
        /// from the "SchemaBindableMapper" sub-model.
        /// </summary>
        [BestFriend]
        private protected readonly ISchemaBindableMapper Bindable;

        /// <summary>
        /// Creates a scorer over <paramref name="input"/> that uses the supplied bindable mapper.
        /// </summary>
        [BestFriend]
        private protected RowToRowScorerBase(IHostEnvironment env, IDataView input, string registrationName, ISchemaBindableMapper bindable)
            : base(env, registrationName, input)
        {
            Contracts.AssertValue(bindable);
            Bindable = bindable;
        }

        /// <summary>
        /// Creates a scorer over <paramref name="input"/> whose bindable mapper is read from the
        /// "SchemaBindableMapper" sub-model of <paramref name="ctx"/>.
        /// </summary>
        [BestFriend]
        private protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
            => ctx.LoadModel<ISchemaBindableMapper, SignatureLoadModel>(host, out Bindable, "SchemaBindableMapper");
 
        /// <summary>
        /// Saves this scorer: first the bindable mapper, as the "SchemaBindableMapper" sub-model,
        /// then whatever the derived class writes in <see cref="SaveCore"/>.
        /// </summary>
        /// <param name="ctx">The context to save to.</param>
        private protected sealed override void SaveModel(ModelSaveContext ctx)
        {
            Contracts.AssertValue(ctx);
            ctx.CheckAtModel();
            // The loading constructor reads the mapper back under this same sub-model name.
            ctx.SaveModel(Bindable, "SchemaBindableMapper");
            SaveCore(ctx);
        }
 
        /// <summary>
        /// The main save method handles saving <see cref="Bindable"/>. This should do everything else.
        /// </summary>
        /// <param name="ctx">The save context; the bindable mapper has already been written to it
        /// by the time this is called.</param>
        [BestFriend]
        private protected abstract void SaveCore(ModelSaveContext ctx);
 
        /// <summary>
        /// Derived classes build the transform that applies this scorer to <paramref name="newSource"/>.
        /// Both ApplyToData entry points of this class delegate here.
        /// </summary>
        private protected abstract IDataTransform ApplyToDataCore(IHostEnvironment env, IDataView newSource);

        /// <summary>
        /// Explicit <see cref="ITransformTemplate"/> implementation; forwards to <see cref="ApplyToDataCore"/>.
        /// </summary>
        IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
        {
            return ApplyToDataCore(env, newSource);
        }

        /// <summary>
        /// Forwards to <see cref="ApplyToDataCore"/>, exposing the result as an <see cref="IDataView"/>.
        /// </summary>
        internal IDataView ApplyToData(IHostEnvironment env, IDataView newSource)
        {
            return ApplyToDataCore(env, newSource);
        }
 
        /// <summary>
        /// Derived classes provide the specific bindings object.
        /// </summary>
        /// <returns>The bindings this class consults to determine active columns, to map column
        /// indices, and to reach the row mapper.</returns>
        [BestFriend]
        private protected abstract BindingsBase GetBindings();
 
        /// <summary>
        /// Produces the set of active columns for the scorer (as a bool[] of length bindings.ColumnCount),
        /// together with the input columns and the row mapper output columns needed to produce them.
        /// </summary>
        /// <param name="bindings">The scorer bindings.</param>
        /// <param name="columns">The scorer output columns the caller needs.</param>
        /// <param name="inputColumns">The active input columns, unioned with the input columns the
        /// row mapper reports as dependencies of its active outputs.</param>
        /// <param name="activeRowMapperCols">The row mapper output columns that are active. The sequence
        /// is materialized, so callers may enumerate it repeatedly without re-running the filter.</param>
        /// <returns>The activity flags returned by <c>bindings.GetActive</c>.</returns>
        private static bool[] GetActive(BindingsBase bindings,
            IEnumerable<DataViewSchema.Column> columns,
            out IEnumerable<DataViewSchema.Column> inputColumns,
            out IEnumerable<DataViewSchema.Column> activeRowMapperCols)
        {
            var active = bindings.GetActive(columns);
            Contracts.Assert(active.Length == bindings.ColumnCount);

            var activeInput = bindings.GetActiveInput(columns);
            Contracts.Assert(activeInput.Length == bindings.Input.Count);

            // Get a predicate that determines which Mapper outputs are active.
            var predicateMapper = bindings.GetActiveMapperColumns(active);

            // Get the active output columns. This is evaluated once here: the result is consumed both
            // by the dependency lookup below and by every row/cursor the caller creates from it.
            activeRowMapperCols = bindings.RowMapper.OutputSchema.Where(c => predicateMapper(c.Index)).ToArray();
            var colsInputForMapper = bindings.RowMapper.GetDependenciesForNewColumns(activeRowMapperCols);

            var activeInCols = bindings.Input.Where(c => c.Index < activeInput.Length && activeInput[c.Index]);
            inputColumns = activeInCols.Union(colsInputForMapper);

            return active;
        }
 
        /// <summary>
        /// Translates <see cref="WantParallelCursors"/> into a tri-state answer: "true" when parallel
        /// cursors are wanted, "null" otherwise. "false" is never returned, since every derived class
        /// must support (but not necessarily prefer) parallel cursors.
        /// </summary>
        protected sealed override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
        {
            Host.AssertValue(predicate);
            return WantParallelCursors(predicate) ? true : (bool?)null;
        }
 
        /// <summary>
        /// This should return true iff parallel cursors are advantageous. Typically, this
        /// will return true iff some columns added by this scorer are active.
        /// </summary>
        /// <param name="predicate">The column predicate. In GetRowCursorSet it is built from the needed
        /// columns and the output schema; presumably it maps an output column index to whether that
        /// column is needed — confirm.</param>
        protected abstract bool WantParallelCursors(Func<int, bool> predicate);
 
        /// <summary>
        /// Opens a cursor on the source over the input columns required for <paramref name="columnsNeeded"/>
        /// and wraps it in a scoring <see cref="Cursor"/>.
        /// </summary>
        protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Contracts.AssertValueOrNull(rand);

            IEnumerable<DataViewSchema.Column> sourceColumns;
            IEnumerable<DataViewSchema.Column> mapperColumns;
            var activeFlags = GetActive(GetBindings(), columnsNeeded, out sourceColumns, out mapperColumns);

            var sourceCursor = Source.GetRowCursor(sourceColumns, rand);
            return new Cursor(Host, this, sourceCursor, activeFlags, mapperColumns);
        }
 
        /// <summary>
        /// Opens a set of source cursors over the input columns required for <paramref name="columnsNeeded"/>
        /// and wraps each of them in a scoring <see cref="Cursor"/>. When the source hands back a single
        /// cursor although more than one was requested, parallel cursors are wanted, and the source row
        /// count (taken as int.MaxValue when unknown) exceeds <paramref name="n"/>, that cursor is split.
        /// </summary>
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.CheckValueOrNull(rand);

            var neededPredicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);

            IEnumerable<DataViewSchema.Column> sourceColumns;
            IEnumerable<DataViewSchema.Column> mapperColumns;
            var activeFlags = GetActive(GetBindings(), columnsNeeded, out sourceColumns, out mapperColumns);

            var sourceCursors = Source.GetRowCursorSet(sourceColumns, n, rand);
            Contracts.AssertNonEmpty(sourceCursors);

            bool gotSingleCursor = sourceCursors.Length == 1 && n > 1;
            if (gotSingleCursor && WantParallelCursors(neededPredicate) && (Source.GetRowCount() ?? int.MaxValue) > n)
                sourceCursors = DataViewUtils.CreateSplitCursors(Host, sourceCursors[0], n);
            Contracts.AssertNonEmpty(sourceCursors);

            // One scoring cursor per source cursor, in the same order.
            return Array.ConvertAll<DataViewRowCursor, DataViewRowCursor>(
                sourceCursors, c => new Cursor(Host, this, c, activeFlags, mapperColumns));
        }
 
        /// <summary>
        /// Creates the getters for the active scorer columns on top of <paramref name="input"/>.
        /// </summary>
        /// <param name="input">The input row handed to the row mapper.</param>
        /// <param name="activeColumns">The scorer output columns that are active.</param>
        /// <param name="disp">Set to the Dispose method of the row obtained from the row mapper, so the
        /// caller can release it once the getters are no longer needed.</param>
        /// <returns>The getters produced by <see cref="GetGetters"/> for the mapper's output row.</returns>
        protected override Delegate[] CreateGetters(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns, out Action disp)
        {
            var bindings = GetBindings();
            GetActive(bindings, activeColumns, out _, out IEnumerable<DataViewSchema.Column> activeMapperColumns);
            var output = bindings.RowMapper.GetRow(input, activeMapperColumns);
            try
            {
                var activeIndices = new HashSet<int>(activeColumns.Select(c => c.Index));
                Func<int, bool> activeInfos = iinfo => activeIndices.Contains(bindings.MapIinfoToCol(iinfo));
                var getters = GetGetters(output, activeInfos);
                disp = output.Dispose;
                return getters;
            }
            catch (Exception)
            {
                // The caller never receives 'disp' when we throw, so the mapper row must be released
                // here; this mirrors the handling in the Cursor constructor.
                output.Dispose();
                throw;
            }
        }
 
        /// <summary>
        /// Returns the input columns required to produce <paramref name="columns"/>, as computed by
        /// <c>GetActive</c>.
        /// </summary>
        protected override IEnumerable<DataViewSchema.Column> GetDependenciesCore(IEnumerable<DataViewSchema.Column> columns)
        {
            // Only the input-column result is of interest; the flags and the mapper columns are dropped.
            IEnumerable<DataViewSchema.Column> requiredInputs;
            IEnumerable<DataViewSchema.Column> unusedMapperColumns;
            GetActive(GetBindings(), columns, out requiredInputs, out unusedMapperColumns);
            return requiredInputs;
        }
 
        /// <summary>
        /// Create and fill an array of getters of size InfoCount. The indices of the non-null entries in the
        /// result should be exactly those for which predicate(iinfo) is true.
        /// </summary>
        /// <param name="output">The row obtained from the bindings' row mapper.</param>
        /// <param name="predicate">Given an iinfo index, tells whether a getter is required for it.</param>
        /// <returns>The getters, indexed by iinfo.</returns>
        protected abstract Delegate[] GetGetters(DataViewRow output, Func<int, bool> predicate);
 
        /// <summary>
        /// Builds an array with one slot per column of <paramref name="row"/>; the slot holds a getter
        /// for the column when <paramref name="predicate"/> accepts its index and stays null otherwise.
        /// </summary>
        protected static Delegate[] GetGettersFromRow(DataViewRow row, Func<int, bool> predicate)
        {
            Contracts.AssertValue(row);
            Contracts.AssertValue(predicate);

            var result = new Delegate[row.Schema.Count];
            for (int i = 0; i < result.Length; i++)
            {
                if (!predicate(i))
                    continue;
                result[i] = GetGetterFromRow(row, i);
            }
            return result;
        }
 
        /// <summary>
        /// Returns, as an untyped delegate, the getter of column <paramref name="col"/> of
        /// <paramref name="row"/>. The generic overload is invoked through Utils.MarshalInvoke with the
        /// column's raw type.
        /// </summary>
        protected static Delegate GetGetterFromRow(DataViewRow row, int col)
        {
            Contracts.AssertValue(row);
            Contracts.Assert(0 <= col && col < row.Schema.Count);
            var column = row.Schema[col];
            Contracts.Assert(row.IsColumnActive(column));

            return Utils.MarshalInvoke(_getGetterFromRowMethodInfo, column.Type.RawType, row, col);
        }
 
        /// <summary>
        /// Returns the typed getter of column <paramref name="col"/> of <paramref name="output"/>.
        /// </summary>
        /// <typeparam name="T">The type requested from the row's GetGetter.</typeparam>
        protected static ValueGetter<T> GetGetterFromRow<T>(DataViewRow output, int col)
        {
            Contracts.AssertValue(output);
            Contracts.Assert(0 <= col && col < output.Schema.Count);
            var column = output.Schema[col];
            Contracts.Assert(output.IsColumnActive(column));
            return output.GetGetter<T>(column);
        }
 
        /// <summary>
        /// Delegates the column index mapping to the current bindings.
        /// </summary>
        protected override int MapColumnIndex(out bool isSrc, int col)
            => GetBindings().MapColumnIndex(out isSrc, col);
 
        private sealed class Cursor : SynchronizedCursorBase
        {
            private readonly BindingsBase _bindings;
            private readonly bool[] _active;
            private readonly Delegate[] _getters;
            private readonly DataViewRow _output;
            private bool _disposed;
 
            public override DataViewSchema Schema { get; }
 
            /// <summary>
            /// Wraps <paramref name="input"/>: asks the parent's row mapper for an output row over it and
            /// lets the parent create the getters for the active columns.
            /// </summary>
            public Cursor(IChannelProvider provider, RowToRowScorerBase parent, DataViewRowCursor input, bool[] active, IEnumerable<DataViewSchema.Column> activeMapperColumns)
                : base(provider, input)
            {
                Ch.AssertValue(parent);
                Ch.AssertValue(active);
                Ch.AssertValue(activeMapperColumns);

                _bindings = parent.GetBindings();
                Schema = parent.OutputSchema;
                Ch.Assert(active.Length == _bindings.ColumnCount);
                _active = active;

                // Translates an iinfo index into a scorer column index and reads its activity flag.
                Func<int, bool> isInfoActive = iinfo => active[_bindings.MapIinfoToCol(iinfo)];

                _output = _bindings.RowMapper.GetRow(input, activeMapperColumns);
                try
                {
                    Ch.Assert(_output.Schema == _bindings.RowMapper.OutputSchema);
                    _getters = parent.GetGetters(_output, isInfoActive);
                }
                catch (Exception)
                {
                    // Construction failed, so Dispose will never run for this cursor; release the row now.
                    _output.Dispose();
                    throw;
                }
            }
 
            /// <summary>
            /// Disposes the mapper output row (only when <paramref name="disposing"/> is true) and then
            /// the base cursor. Calls after the first one do nothing.
            /// </summary>
            protected override void Dispose(bool disposing)
            {
                if (!_disposed)
                {
                    if (disposing)
                        _output.Dispose();
                    _disposed = true;
                    base.Dispose(disposing);
                }
            }
 
            /// <summary>
            /// Tells whether <paramref name="column"/> is active in this cursor, according to the
            /// activity flags given at construction.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                int index = column.Index;
                Ch.Check(index < _bindings.ColumnCount);
                return _active[index];
            }
 
            /// <summary>
            /// Returns the value getter for <paramref name="column"/>. Columns that the bindings map to the
            /// source are served by the input cursor; the others come from the getters created at construction.
            /// This throws if the column is not active in this row, or if the type
            /// <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Ch.Check(IsColumnActive(column));

                int index = _bindings.MapColumnIndex(out bool isSrc, column.Index);
                if (isSrc)
                    return Input.GetGetter<TValue>(Input.Schema[index]);

                Ch.AssertValue(_getters);
                Delegate getter = _getters[index];
                Ch.Assert(getter != null);
                if (getter is ValueGetter<TValue> typedGetter)
                    return typedGetter;

                throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}', " +
                        $"expected type: '{getter.GetType().GetGenericArguments().First()}'.");
            }
        }
    }
 
    /// <summary>
    /// Base class for scorer arguments. It declares the one option shared by scorers: the suffix
    /// for the output column names.
    /// </summary>
    [BestFriend]
    internal abstract class ScorerArgumentsBase
    {
        // Output columns.
 
        // Optional suffix; per the help text it is appended to the default output column names.
        [Argument(ArgumentType.AtMostOnce, HelpText = "Output column: The suffix to append to the default column names", ShortName = "ex")]
        public string Suffix;
    }
 
    /// <summary>
    /// Base bindings for a scorer based on an <see cref="ISchemaBoundMapper"/>. This assumes that input schema columns
    /// are echoed, followed by zero or more derived columns, followed by the mapper generated columns.
    /// The names of the derived columns and mapper generated columns have an optional suffix appended.
    /// </summary>
    [BestFriend]
    internal abstract class ScorerBindingsBase : ColumnBindingsBase
    {
        /// <summary>
        /// The schema bound mapper.
        /// </summary>
        public readonly ISchemaBoundMapper Mapper;
 
        /// <summary>
        /// The column name suffix. Non-null, but may be empty.
        /// </summary>
        public readonly string Suffix;
 
        /// <summary>
        /// The number of derived columns. InfoCount == DerivedColumnCount + Mapper.OutputSchema.Count.
        /// </summary>
        public readonly int DerivedColumnCount;
 
        private readonly uint _crtScoreSet;
        private readonly AnnotationUtils.AnnotationGetter<uint> _getScoreColumnSetId;
 
        /// <summary>
        /// Creates bindings over <paramref name="input"/> whose added columns are the derived columns
        /// followed by the mapper's output columns, each with <paramref name="suffix"/> appended.
        /// </summary>
        /// <param name="input">The input schema.</param>
        /// <param name="mapper">The schema bound mapper.</param>
        /// <param name="suffix">The column name suffix; null is stored as the empty string.</param>
        /// <param name="user">Passed through to the base bindings constructor.</param>
        /// <param name="namesDerived">The names of the derived columns, without suffix.</param>
        protected ScorerBindingsBase(DataViewSchema input, ISchemaBoundMapper mapper, string suffix, bool user, params string[] namesDerived)
            : base(input, user, GetOutputNames(mapper, suffix, namesDerived))
        {
            Contracts.AssertValue(mapper);
            Contracts.AssertValueOrNull(suffix);
            Contracts.AssertValue(namesDerived);

            Mapper = mapper;
            DerivedColumnCount = namesDerived.Length;
            Suffix = suffix ?? "";

            // The score set id of these bindings is one past the maximum reported for the input schema;
            // 'checked' turns a wrap-around into an exception. The column index output is not needed.
            var maxExistingId = input.GetMaxAnnotationKind(out int _, AnnotationUtils.Kinds.ScoreColumnSetId);
            _crtScoreSet = checked(maxExistingId + 1);
            _getScoreColumnSetId = GetScoreColumnSetId;
        }
 
        /// <summary>
        /// Builds the names of the added columns: the derived names first, then the names of the
        /// mapper's output columns, each with <paramref name="suffix"/> appended.
        /// </summary>
        private static string[] GetOutputNames(ISchemaBoundMapper mapper, string suffix, string[] namesDerived)
        {
            Contracts.AssertValue(mapper);
            Contracts.AssertValueOrNull(suffix);
            Contracts.AssertValue(namesDerived);

            var outputSchema = mapper.OutputSchema;
            int derivedCount = namesDerived.Length;
            var names = new string[derivedCount + outputSchema.Count];
            for (int i = 0; i < names.Length; i++)
            {
                // Slots below derivedCount take the derived names; the remaining ones take the mapper columns.
                string baseName = i < derivedCount ? namesDerived[i] : outputSchema[i - derivedCount].Name;
                names[i] = baseName + suffix;
            }
            return names;
        }
 
        /// <summary>
        /// Reads the suffix and the (role, column name) pairs of the input columns from <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">The context to load from.</param>
        /// <param name="suffix">Receives the loaded suffix.</param>
        /// <returns>One pair per input column role, in the order in which they were read.</returns>
        protected static KeyValuePair<RoleMappedSchema.ColumnRole, string>[] LoadBaseInfo(
            ModelLoadContext ctx, out string suffix)
        {
            // *** Binary format ***
            // int: id of the suffix
            // int: the number of input column roles
            // for each input column:
            //   int: id of the column role
            //   int: id of the column name
            suffix = ctx.LoadString();

            int roleCount = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(roleCount >= 0);

            var pairs = new KeyValuePair<RoleMappedSchema.ColumnRole, string>[roleCount];
            int filled = 0;
            while (filled < roleCount)
            {
                // The role comes first, then the column name.
                var roleName = ctx.LoadNonEmptyString();
                var columnName = ctx.LoadNonEmptyString();
                pairs[filled++] = RoleMappedSchema.CreatePair(roleName, columnName);
            }

            return pairs;
        }
 
        /// <summary>
        /// Writes the suffix and the mapper's input column roles to <paramref name="ctx"/>.
        /// </summary>
        protected void SaveBase(ModelSaveContext ctx)
        {
            Contracts.AssertValue(ctx);

            // *** Binary format ***
            // int: id of the suffix
            // int: the number of input column roles
            // for each input column:
            //   int: id of the column role
            //   int: id of the column name
            ctx.SaveString(Suffix);

            var roles = Mapper.GetInputColumnRoles().ToArray();
            ctx.Writer.Write(roles.Length);
            for (int i = 0; i < roles.Length; i++)
            {
                ctx.SaveNonEmptyString(roles[i].Key.Value);
                ctx.SaveNonEmptyString(roles[i].Value);
            }
        }
 
        /// <summary>
        /// Derived bindings implement this to write themselves to <paramref name="ctx"/>.
        /// NOTE(review): presumably implementations call SaveBase for the shared part — confirm.
        /// </summary>
        internal abstract void SaveModel(ModelSaveContext ctx);
 
        /// <summary>
        /// Returns the type of the mapper output column behind <paramref name="iinfo"/>. Only indices
        /// at or past the derived columns are expected here.
        /// </summary>
        protected override DataViewType GetColumnTypeCore(int iinfo)
        {
            Contracts.Assert(DerivedColumnCount <= iinfo && iinfo < InfoCount);
            int mapperCol = iinfo - DerivedColumnCount;
            return Mapper.OutputSchema[mapperCol].Type;
        }
 
        /// <summary>
        /// Enumerates the annotation kinds and types of column <paramref name="iinfo"/>: always the
        /// score column set id, followed, for mapper columns, by the annotations of the mapper column.
        /// </summary>
        protected override IEnumerable<KeyValuePair<string, DataViewType>> GetAnnotationTypesCore(int iinfo)
        {
            Contracts.Assert(0 <= iinfo && iinfo < InfoCount);

            yield return AnnotationUtils.ScoreColumnSetIdType.GetPair(AnnotationUtils.Kinds.ScoreColumnSetId);

            // Derived columns carry nothing beyond the score column set id.
            if (iinfo >= DerivedColumnCount)
            {
                var annotationSchema = Mapper.OutputSchema[iinfo - DerivedColumnCount].Annotations.Schema;
                foreach (var annotation in annotationSchema)
                    yield return new KeyValuePair<string, DataViewType>(annotation.Name, annotation.Type);
            }
        }
 
        /// <summary>
        /// Returns the type of annotation <paramref name="kind"/> on column <paramref name="iinfo"/>,
        /// or null when the column has no such annotation.
        /// </summary>
        protected override DataViewType GetAnnotationTypeCore(string kind, int iinfo)
        {
            Contracts.Assert(0 <= iinfo && iinfo < InfoCount);

            if (kind == AnnotationUtils.Kinds.ScoreColumnSetId)
                return AnnotationUtils.ScoreColumnSetIdType;

            if (iinfo >= DerivedColumnCount)
            {
                // Mapper columns expose whatever annotations the mapper's own column has.
                var mapperColumn = Mapper.OutputSchema[iinfo - DerivedColumnCount];
                return mapperColumn.Annotations.Schema.GetColumnOrNull(kind)?.Type;
            }

            return null;
        }
 
        /// <summary>
        /// Fetches annotation <paramref name="kind"/> of column <paramref name="iinfo"/> into
        /// <paramref name="value"/>. The score column set id is served by these bindings; any other
        /// kind is read from the mapper column, and throws for derived columns.
        /// </summary>
        protected override void GetAnnotationCore<TValue>(string kind, int iinfo, ref TValue value)
        {
            Contracts.Assert(0 <= iinfo && iinfo < InfoCount);

            if (kind == AnnotationUtils.Kinds.ScoreColumnSetId)
            {
                _getScoreColumnSetId.Marshal(iinfo, ref value);
                return;
            }

            if (iinfo < DerivedColumnCount)
                throw AnnotationUtils.ExceptGetAnnotation();
            Mapper.OutputSchema[iinfo - DerivedColumnCount].Annotations.GetValue(kind, ref value);
        }
 
        /// <summary>
        /// Builds a predicate over Mapper output column indices that tells whether the corresponding
        /// scorer column is flagged in <paramref name="active"/>. Virtual, so that scorers with computed
        /// columns can substitute their own logic.
        /// </summary>
        public virtual Func<int, bool> GetActiveMapperColumns(bool[] active)
        {
            Contracts.AssertValue(active);
            Contracts.Assert(active.Length == ColumnCount);

            bool IsMapperColumnActive(int col)
            {
                bool inRange = 0 <= col && col < Mapper.OutputSchema.Count;
                Contracts.Assert(inRange);
                // Mapper columns follow the derived columns in iinfo order.
                return inRange && active[MapIinfoToCol(col + DerivedColumnCount)];
            }

            return IsMapperColumnActive;
        }
 
        /// <summary>
        /// Getter for the score column set id annotation: writes the id computed in the constructor
        /// to <paramref name="dst"/>. The value does not depend on <paramref name="iinfo"/>.
        /// </summary>
        protected void GetScoreColumnSetId(int iinfo, ref uint dst)
        {
            Contracts.Assert(0 <= iinfo && iinfo < InfoCount);
            dst = _crtScoreSet;
        }
    }
}