// File: MatrixFactorizationPredictor.cs
// Web Access
// Project: src\src\Microsoft.ML.Recommender\Microsoft.ML.Recommender.csproj (Microsoft.ML.Recommender)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Recommender;
using Microsoft.ML.Recommender.Internal;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.Recommender;
 
// Registers MatrixFactorizationModelParameters as a loadable class for the SignatureLoadModel signature,
// keyed by MatrixFactorizationModelParameters.LoaderSignature.
[assembly: LoadableClass(typeof(MatrixFactorizationModelParameters), null, typeof(SignatureLoadModel), "Matrix Factorization Predictor Executor", MatrixFactorizationModelParameters.LoaderSignature)]
 
// Registers MatrixFactorizationPredictionTransformer for the same signature, keyed by its own LoaderSignature.
// NOTE(review): the additional leading type argument presumably names the type that hosts the static
// Create entry point -- confirm against LoadableClassAttribute.
[assembly: LoadableClass(typeof(MatrixFactorizationPredictionTransformer), typeof(MatrixFactorizationPredictionTransformer),
    null, typeof(SignatureLoadModel), "", MatrixFactorizationPredictionTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Trainers.Recommender
{
    /// <summary>
    /// Model parameters for <see cref="MatrixFactorizationTrainer"/>.
    /// </summary>
    /// <remarks>
    /// <see cref="MatrixFactorizationModelParameters"/> stores two factor matrices, P and Q, for approximating the training matrix, R, by P * Q,
    /// where * is a matrix multiplication. This model expects two inputs, row index and column index, and produces the (approximated)
    /// value at the location specified by the two inputs in R. More specifically, if the input row and column indices are u and v, respectively,
    /// the output (a scalar) is the inner product of the u-th row in P and the v-th column in Q.
    /// </remarks>
    public sealed class MatrixFactorizationModelParameters : IPredictor, ICanSaveModel, ICanSaveInTextFormat, ISchemaBindableMapper
    {
        internal const string LoaderSignature = "MFPredictor";
        internal const string RegistrationName = "MatrixFactorizationPredictor";

        /// <summary>
        /// Written version from which the binary layout no longer carries the key-type Min values.
        /// The loading constructor compares the written version of a model against this constant.
        /// </summary>
        private const uint VersionNoMinCount = 0x00010002;

        /// <summary>
        /// Builds the version descriptor used when saving and loading the model parameters.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            // History of verWrittenCur:
            //   0x00010001 - Initial
            //   0x00010002 - Removed Min in KeyType
            var info = new VersionInfo(
                modelSignature: "FAFAMAPD",
                verWrittenCur: 0x00010002,
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(MatrixFactorizationModelParameters).Assembly.FullName);
            return info;
        }
 
        // Host obtained from env.Register(RegistrationName); used for the checks and assertions in this class.
        private readonly IHost _host;
 
        ///<summary> The number of rows of the approximated matrix.</summary>
        public readonly int NumberOfRows;
 
        ///<summary> The number of columns of the approximated matrix.</summary>
        public readonly int NumberOfColumns;
 
        ///<summary> The rank of the factor matrices, i.e., the length of each latent vector.</summary>
        public readonly int ApproximationRank;
 
        /// <summary>
        /// Left approximation matrix
        /// </summary>
        /// <remarks>
        /// This is a two-dimensional matrix of size <see cref="NumberOfRows"/> * <see cref="ApproximationRank"/> flattened into a one-dimensional
        /// array, row by row: the latent vector of the u-th row (0-based) occupies the indices from u * <see cref="ApproximationRank"/>
        /// (inclusive) to (u + 1) * <see cref="ApproximationRank"/> (exclusive).
        /// </remarks>
        public IReadOnlyList<float> LeftFactorMatrix => _leftFactorMatrix;
 
        private readonly float[] _leftFactorMatrix;
        /// <summary>
        /// Right approximation matrix
        /// </summary>
        /// <remarks>
        /// The array has <see cref="NumberOfColumns"/> * <see cref="ApproximationRank"/> elements and stores one latent vector per matrix column,
        /// one after another: the latent vector of the v-th column (0-based) occupies the indices from v * <see cref="ApproximationRank"/>
        /// (inclusive) to (v + 1) * <see cref="ApproximationRank"/> (exclusive). In other words, it is the
        /// <see cref="NumberOfColumns"/> x <see cref="ApproximationRank"/> matrix flattened row by row.
        /// </remarks>
        public IReadOnlyList<float> RightFactorMatrix => _rightFactorMatrix;
 
        private readonly float[] _rightFactorMatrix;
 
        // This predictor reports itself as a recommendation predictor.
        PredictionKind IPredictor.PredictionKind => PredictionKind.Recommendation;
 
        // Type of the produced score column: a single-precision scalar.
        private DataViewType OutputType => NumberDataViewType.Single;
 
        // Key types the input index columns are compared against when a schema is bound to this model.
        internal DataViewType MatrixColumnIndexType { get; }
        internal DataViewType MatrixRowIndexType { get; }
 
        /// <summary>
        /// Creates model parameters from the content of a training buffer.
        /// </summary>
        /// <param name="env">Environment used to register the host of this model.</param>
        /// <param name="buffer">Buffer from which the matrix dimensions, the rank and the two factor matrices are read.</param>
        /// <param name="matrixColumnIndexType">Key type of the matrix-column-index input; its count must equal the number of columns.</param>
        /// <param name="matrixRowIndexType">Key type of the matrix-row-index input; its count must equal the number of rows.</param>
        internal MatrixFactorizationModelParameters(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyDataViewType matrixColumnIndexType, KeyDataViewType matrixRowIndexType)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);

            // Null checks come first: the assertions below dereference the key types.
            _host.CheckValue(buffer, nameof(buffer));
            _host.CheckValue(matrixColumnIndexType, nameof(matrixColumnIndexType));
            _host.CheckValue(matrixRowIndexType, nameof(matrixRowIndexType));
            _host.Assert(matrixColumnIndexType.RawType == typeof(uint));
            _host.Assert(matrixRowIndexType.RawType == typeof(uint));

            buffer.Get(out NumberOfRows, out NumberOfColumns, out ApproximationRank, out var leftFactorMatrix, out var rightFactorMatrix);
            _leftFactorMatrix = leftFactorMatrix;
            _rightFactorMatrix = rightFactorMatrix;

            // The buffer content has to agree with the declared key types and with the flattened matrix sizes.
            _host.Assert(NumberOfColumns == matrixColumnIndexType.GetCountAsInt32(_host));
            _host.Assert(NumberOfRows == matrixRowIndexType.GetCountAsInt32(_host));
            _host.Assert(_leftFactorMatrix.Length == NumberOfRows * ApproximationRank);
            _host.Assert(_rightFactorMatrix.Length == ApproximationRank * NumberOfColumns);

            MatrixColumnIndexType = matrixColumnIndexType;
            MatrixRowIndexType = matrixRowIndexType;
        }
 
        /// <summary>
        /// Deserializing constructor: reads the dimensions, the rank and both factor matrices from <paramref name="ctx"/>.
        /// </summary>
        private MatrixFactorizationModelParameters(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);

            // *** Binary format ***
            // int: number of rows (m), the limit on row
            // int: number of columns (n), the limit on column
            // int: rank of factor matrices (k)
            // float[m * k]: the left factor matrix
            // float[k * n]: the right factor matrix
            //
            // Models written before VersionNoMinCount additionally store a ulong Min right after each of m and n.

            // Consumes and validates the Min value that follows a dimension in models written before VersionNoMinCount.
            void ReadLegacyMin(int count)
            {
                if (ctx.Header.ModelVerWritten >= VersionNoMinCount)
                    return;

                ulong min = ctx.Reader.ReadUInt64();
                // We no longer support non zero Min for KeyType.
                _host.CheckDecode(min == 0);
                _host.CheckDecode((ulong)count <= ulong.MaxValue - min);
            }

            NumberOfRows = ctx.Reader.ReadInt32();
            _host.CheckDecode(NumberOfRows > 0);
            ReadLegacyMin(NumberOfRows);

            NumberOfColumns = ctx.Reader.ReadInt32();
            _host.CheckDecode(NumberOfColumns > 0);
            ReadLegacyMin(NumberOfColumns);

            ApproximationRank = ctx.Reader.ReadInt32();
            _host.CheckDecode(ApproximationRank > 0);

            // The element counts are computed in a checked context, so an overflowing product throws.
            int leftLength = checked(NumberOfRows * ApproximationRank);
            _leftFactorMatrix = Utils.ReadSingleArray(ctx.Reader, leftLength);
            int rightLength = checked(NumberOfColumns * ApproximationRank);
            _rightFactorMatrix = Utils.ReadSingleArray(ctx.Reader, rightLength);

            // The key types of the expected inputs are rebuilt from the stored dimensions.
            MatrixColumnIndexType = new KeyDataViewType(typeof(uint), NumberOfColumns);
            MatrixRowIndexType = new KeyDataViewType(typeof(uint), NumberOfRows);
        }
 
        /// <summary>
        /// Factory used when loading: validates the arguments and the version information found at
        /// <paramref name="ctx"/>, then deserializes the model parameters from it.
        /// </summary>
        internal static MatrixFactorizationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));

            var expectedVersion = GetVersionInfo();
            ctx.CheckAtModel(expectedVersion);

            var loaded = new MatrixFactorizationModelParameters(env, ctx);
            return loaded;
        }
 
        /// <summary>
        /// Save model to the given context
        /// </summary>
        /// <param name="ctx">Context the dimensions, the rank and both factor matrices are written to; must not be null.</param>
        void ICanSaveModel.Save(ModelSaveContext ctx)
        {
            _host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: number of rows (m), the limit on row
            // int: number of columns (n), the limit on column
            // int: rank of factor matrices (k)
            // float[m * k]: the left factor matrix
            // float[k * n]: the right factor matrix

            // Every consistency check runs before the first value is written, so a failing check cannot leave a
            // partially written payload behind. The expected sizes are computed in 64-bit arithmetic so that an
            // overflowing product cannot accidentally match an array length.
            _host.Check(NumberOfRows > 0, "Number of rows must be positive");
            _host.Check(NumberOfColumns > 0, "Number of columns must be positive");
            _host.Check(ApproximationRank > 0, "Number of latent factors must be positive");
            _host.Check(Utils.Size(_leftFactorMatrix) == (long)NumberOfRows * ApproximationRank, "Unexpected matrix size of a factor matrix (matrix P in LIBMF paper)");
            _host.Check(Utils.Size(_rightFactorMatrix) == (long)NumberOfColumns * ApproximationRank, "Unexpected matrix size of a factor matrix (matrix Q in LIBMF paper)");

            ctx.Writer.Write(NumberOfRows);
            ctx.Writer.Write(NumberOfColumns);
            ctx.Writer.Write(ApproximationRank);
            Utils.WriteSinglesNoCount(ctx.Writer, _leftFactorMatrix);
            Utils.WriteSinglesNoCount(ctx.Writer, _rightFactorMatrix);
        }
 
        /// <summary>
        /// Writes both factor matrices to <paramref name="writer"/> as text. Each matrix is preceded by a
        /// '#' header line and is emitted as tab-separated lines holding <see cref="ApproximationRank"/> values each.
        /// </summary>
        void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
        {
            // Emits one flattened factor matrix, breaking the line after every ApproximationRank-th value.
            void WriteFactors(float[] factors)
            {
                for (int pos = 0; pos < factors.Length; ++pos)
                {
                    writer.Write(factors[pos].ToString("G"));

                    bool endsLine = pos % ApproximationRank == ApproximationRank - 1;
                    if (!endsLine)
                        writer.Write('\t');
                    else
                        writer.WriteLine();
                }
            }

            writer.WriteLine("# Imputed matrix is P * Q'");

            writer.WriteLine("# P in R^({0} x {1}), rows correspond to Y item", NumberOfRows, ApproximationRank);
            WriteFactors(_leftFactorMatrix);

            writer.WriteLine("# Q in R^({0} x {1}), rows correspond to X item", NumberOfColumns, ApproximationRank);
            WriteFactors(_rightFactorMatrix);
        }
 
        /// <summary>
        /// Combines the getters of the two index inputs into a getter of the predicted value.
        /// </summary>
        /// <param name="matrixColumnIndexGetter">Getter producing the matrix column index of the current row.</param>
        /// <param name="matrixRowIndexGetter">Getter producing the matrix row index of the current row.</param>
        /// <returns>A getter that fetches both indexes and maps them to a score.</returns>
        private ValueGetter<float> GetGetter(ValueGetter<uint> matrixColumnIndexGetter, ValueGetter<uint> matrixRowIndexGetter)
        {
            _host.AssertValue(matrixColumnIndexGetter);
            _host.AssertValue(matrixRowIndexGetter);

            var scoreMapper = GetMapper<uint, uint, float>();

            // Index buffers captured by the returned delegate and overwritten on every invocation.
            uint columnKey = 0;
            uint rowKey = 0;

            return (ref float value) =>
            {
                matrixColumnIndexGetter(ref columnKey);
                matrixRowIndexGetter(ref rowKey);
                scoreMapper(in columnKey, ref rowKey, ref value);
            };
        }
 
        /// <summary>
        /// Returns the mapper of this predictor: a delegate that turns a (matrix column index, matrix row index)
        /// pair into the approximated value at that location of the training matrix. In a recommender system
        /// whose training matrix stores ratings from users to items, it maps a user ID and an item ID to the
        /// rating of that item given by the user. Only the combination (uint, uint, float) is accepted; any other
        /// type argument fails the corresponding check.
        /// </summary>
        private ValueMapper<TMatrixColumnIndexIn, TMatrixRowIndexIn, TOut> GetMapper<TMatrixColumnIndexIn, TMatrixRowIndexIn, TOut>()
        {
            _host.Check(typeof(TMatrixColumnIndexIn) == typeof(uint),
                $"Invalid {nameof(TMatrixColumnIndexIn)} in GetMapper: {typeof(TMatrixColumnIndexIn)}");

            _host.Check(typeof(TMatrixRowIndexIn) == typeof(uint),
                $"Invalid {nameof(TMatrixRowIndexIn)} in GetMapper: {typeof(TMatrixRowIndexIn)}");

            _host.Check(typeof(TOut) == typeof(float),
                $"Invalid {nameof(TOut)} in GetMapper: {typeof(TOut)}");

            // All three type arguments were verified above, so the conversion below always succeeds.
            ValueMapper<uint, uint, float> core = MapperCore;
            return (ValueMapper<TMatrixColumnIndexIn, TMatrixRowIndexIn, TOut>)(Delegate)core;
        }
 
        /// <summary>
        /// Computes the (approximated) value at the <paramref name="srcCol"/>-th column and the
        /// <paramref name="srcRow"/>-th row. Both indexes are 1-based: key-valued getters in ML.NET return
        /// 1 for their first value and use 0 to denote a missing value. A missing or out-of-range index
        /// yields <see cref="float.NaN"/>.
        /// </summary>
        /// <param name="srcCol">1-based column index.</param>
        /// <param name="srcRow">1-based row index.</param>
        /// <param name="dst">value at the <paramref name="srcCol"/>-th column and the <paramref name="srcRow"/>-th row.</param>
        private void MapperCore(in uint srcCol, ref uint srcRow, ref float dst)
        {
            // REVIEW: The key-type version a bit more "strict" than the predictor
            // version, since the predictor version can't know the maximum bound during
            // training. For higher-than-expected values, the predictor version would return
            // 0, rather than NaN as we do here. It is in my mind an open question as to what
            // is actually correct.
            bool rowInRange = srcRow != 0 && srcRow <= NumberOfRows;
            bool colInRange = srcCol != 0 && srcCol <= NumberOfColumns;
            if (!(rowInRange && colInRange))
            {
                // A missing (0) or too large index is not meaningful to the trained model.
                dst = float.NaN;
                return;
            }

            // The index system in the LIBMF (the library trains the model) is 0-based, so one is deducted
            // from the 1-based indexes returned by ML.NET's key-valued getters.
            int zeroBasedColumn = (int)(srcCol - 1);
            int zeroBasedRow = (int)(srcRow - 1);
            dst = Score(zeroBasedColumn, zeroBasedRow);
        }
 
        /// <summary>
        /// Computes the (approximated) value at the <paramref name="columnIndex"/>-th column and the
        /// <paramref name="rowIndex"/>-th row as the inner product of the two corresponding latent vectors.
        /// In contrast to <see cref="MapperCore"/>, both indexes are 0-based here.
        /// </summary>
        /// <param name="columnIndex">0-based column index.</param>
        /// <param name="rowIndex">0-based row index.</param>
        /// <returns>The inner product of the selected latent vectors.</returns>
        private float Score(int columnIndex, int rowIndex)
        {
            _host.Assert(0 <= rowIndex && rowIndex < NumberOfRows);
            _host.Assert(0 <= columnIndex && columnIndex < NumberOfColumns);

            // Cursors over the latent vector of the row (left factor) and of the column (right factor).
            int leftPos = rowIndex * ApproximationRank;
            int rightPos = columnIndex * ApproximationRank;
            int leftEnd = leftPos + ApproximationRank;

            float dot = 0;
            while (leftPos < leftEnd)
                dot += _leftFactorMatrix[leftPos++] * _rightFactorMatrix[rightPos++];
            return dot;
        }
 
        /// <summary>
        /// Binds this predictor to <paramref name="schema"/> by creating a row mapper whose output schema is a
        /// regression-kind score column. A matrix factorization predictor maps a row ID (u) and a column ID (v)
        /// to the expected numerical value at the u-th row and the v-th column of the considered matrix, so it is
        /// essentially a regressor.
        /// </summary>
        ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
        {
            Contracts.AssertValue(env);
            env.AssertValue(schema);

            var scoreSchema = ScoreSchemaFactory.Create(OutputType, AnnotationUtils.Const.ScoreColumnKind.Regression);
            return new RowMapper(env, this, schema, scoreSchema);
        }
 
        private sealed class RowMapper : ISchemaBoundRowMapper
        {
            // The model that performs the actual scoring; also exposed through Bindable.
            private readonly MatrixFactorizationModelParameters _parent;
            // The tail "ColumnIndex" means the column index in IDataView
            private readonly int _matrixColumnIndexColumnIndex;
            private readonly int _matrixRowIndexCololumnIndex;
            // The tail "ColumnName" means the column name in IDataView
            private readonly string _matrixColumnIndexColumnName;
            private readonly string _matrixRowIndexColumnName;
            // Environment used for the checks performed by this mapper.
            private readonly IHostEnvironment _env;
            // Plain schema of the bound input, taken from InputRoleMappedSchema.
            public DataViewSchema InputSchema => InputRoleMappedSchema.Schema;
            // Schema of the columns produced by this mapper, as passed to the constructor.
            public DataViewSchema OutputSchema { get; }
 
            // The role-mapped schema this mapper was bound to.
            public RoleMappedSchema InputRoleMappedSchema { get; }
 
            /// <summary>
            /// Binds <paramref name="parent"/> to <paramref name="schema"/>. The schema must contain exactly one
            /// column for each of the two matrix index roles, and the types of those columns must equal the
            /// index types expected by <paramref name="parent"/>.
            /// </summary>
            public RowMapper(IHostEnvironment env, MatrixFactorizationModelParameters parent, RoleMappedSchema schema, DataViewSchema outputSchema)
            {
                Contracts.AssertValue(parent);
                _env = env;
                _parent = parent;

                // Exactly one column has to play the matrix-column-index role ...
                var columnRoleColumns = schema.GetColumns(RecommenderUtils.MatrixColumnIndexKind);
                _env.Check(Utils.Size(columnRoleColumns) == 1,
                    $"'{RecommenderUtils.MatrixColumnIndexKind}' column doesn't exist or not unique");

                // ... and exactly one the matrix-row-index role.
                var rowRoleColumns = schema.GetColumns(RecommenderUtils.MatrixRowIndexKind);
                _env.Check(Utils.Size(rowRoleColumns) == 1,
                    $"'{RecommenderUtils.MatrixRowIndexKind}' column doesn't exist or not unique");

                var columnIndexColumn = columnRoleColumns[0];
                _matrixColumnIndexColumnName = columnIndexColumn.Name;
                _matrixColumnIndexColumnIndex = columnIndexColumn.Index;

                var rowIndexColumn = rowRoleColumns[0];
                _matrixRowIndexColumnName = rowIndexColumn.Name;
                _matrixRowIndexCololumnIndex = rowIndexColumn.Index;

                // The types of the two columns are verified before the schemas are published.
                CheckInputSchema(schema.Schema, _matrixColumnIndexColumnIndex, _matrixRowIndexCololumnIndex);
                InputRoleMappedSchema = schema;
                OutputSchema = outputSchema;
            }
 
            /// <summary>
            /// Given a set of columns, return the input columns that are needed to generate those output columns.
            /// </summary>
            /// <param name="dependingColumns">The requested output columns.</param>
            /// <returns>
            /// Nothing when no output column is requested; otherwise the matrix-column-index and
            /// matrix-row-index input columns.
            /// </returns>
            public IEnumerable<DataViewSchema.Column> GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns)
            {
                // Any() stops at the first element, whereas Count() would walk the whole sequence.
                if (!dependingColumns.Any())
                    return Enumerable.Empty<DataViewSchema.Column>();

                return InputSchema.Where(col => col.Index == _matrixColumnIndexColumnIndex || col.Index == _matrixRowIndexCololumnIndex);
            }
 
            /// <summary>
            /// Lazily enumerates the two (role, column name) pairs this mapper consumes: first the
            /// matrix-column-index role, then the matrix-row-index role.
            /// </summary>
            public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
            {
                var columnIndexRole = RecommenderUtils.MatrixColumnIndexKind.Bind(_matrixColumnIndexColumnName);
                yield return columnIndexRole;

                var rowIndexRole = RecommenderUtils.MatrixRowIndexKind.Bind(_matrixRowIndexColumnName);
                yield return rowIndexRole;
            }
 
            /// <summary>
            /// Verifies that the two index columns of <paramref name="schema"/> have exactly the types expected by
            /// the trained predictor; a mismatch fails the corresponding parameter check.
            /// </summary>
            /// <param name="schema">Schema to verify.</param>
            /// <param name="matrixColumnIndexCol">Position of the matrix-column-index column in <paramref name="schema"/>.</param>
            /// <param name="matrixRowIndexCol">Position of the matrix-row-index column in <paramref name="schema"/>.</param>
            private void CheckInputSchema(DataViewSchema schema, int matrixColumnIndexCol, int matrixRowIndexCol)
            {
                // See if matrix-column-index role's type matches the one expected in the trained predictor
                var type = schema[matrixColumnIndexCol].Type;
                string msg = string.Format("Input column index type '{0}' incompatible with predictor's column index type '{1}'", type, _parent.MatrixColumnIndexType);
                _env.CheckParam(type.Equals(_parent.MatrixColumnIndexType), nameof(schema), msg);

                // See if matrix-row-index role's type matches the one expected in the trained predictor
                type = schema[matrixRowIndexCol].Type;
                msg = string.Format("Input row index type '{0}' incompatible with predictor's row index type '{1}'", type, _parent.MatrixRowIndexType);
                _env.CheckParam(type.Equals(_parent.MatrixRowIndexType), nameof(schema), msg);
            }
 
            /// <summary>
            /// Creates the getters of the output columns for <paramref name="input"/>. The returned array has a
            /// single slot, which stays null when the score column is not active.
            /// </summary>
            private Delegate[] CreateGetter(DataViewRow input, bool[] active)
            {
                _env.CheckValue(input, nameof(input));
                _env.Assert(Utils.Size(active) == OutputSchema.Count);

                var getters = new Delegate[1];
                if (!active[0])
                    return getters;

                // First check if expected columns are ok and then create getters to access those columns' values.
                CheckInputSchema(input.Schema, _matrixColumnIndexColumnIndex, _matrixRowIndexCololumnIndex);
                ValueGetter<uint> columnIndexGetter = RowCursorUtils.GetGetterAs<uint>(NumberDataViewType.UInt32, input, _matrixColumnIndexColumnIndex);
                ValueGetter<uint> rowIndexGetter = RowCursorUtils.GetGetterAs<uint>(NumberDataViewType.UInt32, input, _matrixRowIndexCololumnIndex);

                // The score getter maps a pair of matrix column index and matrix row index to a scalar.
                getters[0] = _parent.GetGetter(columnIndexGetter, rowIndexGetter);
                return getters;
            }
 
            /// <summary>
            /// Wraps <paramref name="input"/> into a row that exposes the output columns of this mapper; getters
            /// are only created for the columns listed in <paramref name="activeColumns"/>.
            /// </summary>
            DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
            {
                bool[] activeFlags = Utils.BuildArray(OutputSchema.Count, activeColumns);
                Delegate[] outputGetters = CreateGetter(input, activeFlags);
                return new SimpleRow(OutputSchema, input, outputGetters);
            }

            /// <summary>
            /// The model this row mapper was created from.
            /// </summary>
            public ISchemaBindableMapper Bindable
            {
                get { return _parent; }
            }
        }
    }
 
    /// <summary>
    /// Prediction transformer that scores data with a trained <see cref="MatrixFactorizationModelParameters"/>, i.e., a model that
    /// approximates the training matrix by the product of two low-rank matrices.
    /// </summary>
    public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase<MatrixFactorizationModelParameters>
    {
        // Signature under which this transformer is registered as a loadable class and versioned.
        internal const string LoaderSignature = "MaFactPredXf";
        // Names of the input columns holding the matrix column index and the matrix row index.
        internal string MatrixColumnIndexColumnName { get; }
        internal string MatrixRowIndexColumnName { get; }
        // Types of those two columns as found in the train schema.
        internal DataViewType MatrixColumnIndexColumnType { get; }
        internal DataViewType MatrixRowIndexColumnType { get; }
 
        /// <summary>
        /// Build a transformer based on matrix factorization predictor (model) and the input schema (trainSchema). The created
        /// transformer can only transform IDataView objects compatible to the input schema; that is, that IDataView must contain
        /// columns specified by <see cref="MatrixColumnIndexColumnName"/>, <see cref="MatrixColumnIndexColumnType"/>, <see cref="MatrixRowIndexColumnName"/>, and <see cref="MatrixRowIndexColumnType"></see>.
        /// The output column is "Score" by default but user can append a string to it.
        /// </summary>
        /// <param name="env">Environment object for showing information</param>
        /// <param name="model">The model trained by one of the training functions in <see cref="MatrixFactorizationTrainer"/></param>
        /// <param name="trainSchema">Schema that contains the two index columns named by the following parameters</param>
        /// <param name="matrixColumnIndexColumnName">The name of the column used as role <see cref="RecommenderUtils.MatrixColumnIndexKind"/> in matrix factorization world</param>
        /// <param name="matrixRowIndexColumnName">The name of the column used as role <see cref="RecommenderUtils.MatrixRowIndexKind"/> in matrix factorization world</param>
        /// <param name="scoreColumnNameSuffix">A string attached to the output column name of this transformer</param>
        internal MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationModelParameters model, DataViewSchema trainSchema,
            string matrixColumnIndexColumnName, string matrixRowIndexColumnName, string scoreColumnNameSuffix = "")
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MatrixFactorizationPredictionTransformer)), model, trainSchema)
        {
            // Each name is validated on its own and reported under its own parameter name.
            Host.CheckNonEmpty(matrixColumnIndexColumnName, nameof(matrixColumnIndexColumnName));
            Host.CheckNonEmpty(matrixRowIndexColumnName, nameof(matrixRowIndexColumnName));

            MatrixColumnIndexColumnName = matrixColumnIndexColumnName;
            MatrixRowIndexColumnName = matrixRowIndexColumnName;

            // Both index columns must exist in the train schema; their types are remembered.
            if (!trainSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int xCol))
                throw Host.ExceptSchemaMismatch(nameof(MatrixColumnIndexColumnName), "matrixColumnIndex", MatrixColumnIndexColumnName);
            MatrixColumnIndexColumnType = trainSchema[xCol].Type;
            if (!trainSchema.TryGetColumnIndex(MatrixRowIndexColumnName, out int yCol))
                throw Host.ExceptSchemaMismatch(nameof(MatrixRowIndexColumnName), "matrixRowIndex", MatrixRowIndexColumnName);
            MatrixRowIndexColumnType = trainSchema[yCol].Type;

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);

            // The scorer is bound to the role-mapped train schema; the suffix is appended to the score column name.
            var schema = GetSchema();
            var args = new GenericScorer.Arguments { Suffix = scoreColumnNameSuffix };
            Scorer = new GenericScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema);
        }
 
        /// <summary>
        /// Builds the role-mapped view of the train schema in which the two configured columns play the
        /// matrix-column-index and the matrix-row-index roles.
        /// </summary>
        private RoleMappedSchema GetSchema()
        {
            var roleAssignments = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>
            {
                new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixColumnIndexKind, MatrixColumnIndexColumnName),
                new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixRowIndexKind, MatrixRowIndexColumnName),
            };
            return new RoleMappedSchema(TrainSchema, roleAssignments);
        }
 
        /// <summary>
        /// Loading constructor: re-creates a <see cref="MatrixFactorizationPredictionTransformer"/> from the context
        /// the original transformer was saved to.
        /// </summary>
        private MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx)
        {
            // *** Binary format ***
            // <base info>
            // string: the column name of matrix's column ids.
            // string: the column name of matrix's row ids.

            MatrixColumnIndexColumnName = ctx.LoadString();
            MatrixRowIndexColumnName = ctx.LoadString();

            // Both names must resolve against the loaded train schema; the column types are remembered.
            bool hasColumnIndexColumn = TrainSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int columnIndexPosition);
            if (!hasColumnIndexColumn)
                throw Host.ExceptSchemaMismatch(nameof(MatrixColumnIndexColumnName), "matrixColumnIndex", MatrixColumnIndexColumnName);
            MatrixColumnIndexColumnType = TrainSchema[columnIndexPosition].Type;

            bool hasRowIndexColumn = TrainSchema.TryGetColumnIndex(MatrixRowIndexColumnName, out int rowIndexPosition);
            if (!hasRowIndexColumn)
                throw Host.ExceptSchemaMismatch(nameof(MatrixRowIndexColumnName), "matrixRowIndex", MatrixRowIndexColumnName);
            MatrixRowIndexColumnType = TrainSchema[rowIndexPosition].Type;

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);

            // A loaded transformer always uses an empty score column suffix.
            var roleMappedSchema = GetSchema();
            var scorerArgs = new GenericScorer.Arguments { Suffix = "" };
            var emptyTrainView = new EmptyDataView(Host, TrainSchema);
            Scorer = new GenericScorer(Host, scorerArgs, emptyTrainView, BindableMapper.Bind(Host, roleMappedSchema), roleMappedSchema);
        }
 
        /// <summary>
        /// Schema propagation for transformers.
        /// Returns the output schema of the data, if the input schema is like the one provided.
        /// </summary>
        /// <param name="inputSchema">Schema of the data to transform; must not be null and must contain both index columns.</param>
        /// <returns>The schema obtained by transforming an empty data view that has <paramref name="inputSchema"/>.</returns>
        public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
        {
            // Fail with a descriptive argument error instead of a null dereference.
            Host.CheckValue(inputSchema, nameof(inputSchema));

            // Only the presence of the two columns matters here, so their positions are discarded.
            if (!inputSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int _))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "matrixColumnIndex", MatrixColumnIndexColumnName);
            if (!inputSchema.TryGetColumnIndex(MatrixRowIndexColumnName, out int _))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "matrixRowIndex", MatrixRowIndexColumnName);

            return Transform(new EmptyDataView(Host, inputSchema)).Schema;
        }
 
        /// <summary>
        /// Saves the prediction model, the train schema and the names of the two index columns to <paramref name="ctx"/>.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // string: the column name of matrix's column ids.
            // string: the column name of matrix's row ids.

            ctx.SaveModel(Model, DirModel);

            // The train schema is persisted by saving an empty data view that carries it.
            ctx.SaveBinaryStream(DirTransSchema, schemaWriter =>
            {
                using (var channel = Host.Start("Saving train schema"))
                {
                    var saverArgs = new BinarySaver.Arguments { Silent = true };
                    var binarySaver = new BinarySaver(Host, saverArgs);
                    DataSaverUtils.SaveDataView(channel, binarySaver, new EmptyDataView(Host, TrainSchema), schemaWriter.BaseStream);
                }
            });

            ctx.SaveString(MatrixColumnIndexColumnName);
            ctx.SaveString(MatrixRowIndexColumnName);
        }
 
        /// <summary>
        /// Builds the version descriptor used when saving this transformer.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            var info = new VersionInfo(
                modelSignature: "MAFAPRED", // "MA"trix "FA"torization "PRED"iction
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(MatrixFactorizationPredictionTransformer).Assembly.FullName);
            return info;
        }
        /// <summary>
        /// Factory used when loading: re-creates the transformer from <paramref name="ctx"/>.
        /// </summary>
        internal static MatrixFactorizationPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            return new MatrixFactorizationPredictionTransformer(env, ctx);
        }
 
    }
}