File: FactorizationMachine\FieldAwareFactorizationMachineUtils.cs
Web Access
Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers
{
    /// <summary>
    /// Static helper routines for the field-aware factorization machine code: vector-length alignment and
    /// flattening of one example's per-field feature vectors into parallel index/value buffers.
    /// </summary>
    internal sealed class FieldAwareFactorizationMachineUtils
    {
        /// <summary>
        /// Pads <paramref name="length"/> so that the result is a multiple of 4. For a non-negative
        /// <paramref name="length"/> this is the smallest multiple of 4 that is not less than it.
        /// </summary>
        /// <param name="length">The unpadded vector length.</param>
        /// <returns><paramref name="length"/> itself when it is already a multiple of 4; otherwise the padded length.</returns>
        internal static int GetAlignedVectorLength(int length)
        {
            int remainder = length % 4;
            return remainder == 0 ? length : length + (4 - remainder);
        }

        /// <summary>
        /// Reads one example by invoking every getter in <paramref name="getters"/> in order and appends the items
        /// enumerated from each resulting vector to the three parallel output buffers.
        /// </summary>
        /// <param name="getters">One getter per field; the position of a getter in this array is recorded as the field index.</param>
        /// <param name="featureBuffer">Scratch vector that each getter fills in turn. It is passed by value, so the caller's copy is not updated.</param>
        /// <param name="norm">When true, every loaded value is divided by the L2 norm of all loaded values.</param>
        /// <param name="count">Reset to 0 on entry; on return holds the number of entries written to the output buffers.</param>
        /// <param name="fieldIndexBuffer">Receives the field index of each loaded entry.</param>
        /// <param name="featureIndexBuffer">Receives each entry's key offset by the summed lengths of the vectors of all preceding fields.</param>
        /// <param name="featureValueBuffer">Receives each entry's value (normalized when <paramref name="norm"/> is true).</param>
        /// <returns>True when every loaded value is finite (no NaN and no infinity); false otherwise.</returns>
        /// <remarks>
        /// No bounds checking is done here: the three output buffers must be large enough to hold every enumerated item.
        /// </remarks>
        internal static bool LoadOneExampleIntoBuffer(ValueGetter<VBuffer<float>>[] getters, VBuffer<float> featureBuffer, bool norm, ref int count,
            int[] fieldIndexBuffer, int[] featureIndexBuffer, float[] featureValueBuffer)
        {
            count = 0;
            float squaredNorm = 0;
            int bias = 0;
            // x - x is 0 for every finite x and NaN for NaN or +/-infinity, so this sum stays finite
            // exactly when all loaded values are finite.
            float annihilation = 0;
            for (int f = 0; f < getters.Length; f++)
            {
                getters[f](ref featureBuffer);
                foreach (var pair in featureBuffer.Items())
                {
                    fieldIndexBuffer[count] = f;
                    featureIndexBuffer[count] = bias + pair.Key;
                    featureValueBuffer[count] = pair.Value;
                    squaredNorm += pair.Value * pair.Value;
                    annihilation += pair.Value - pair.Value;
                    count++;
                }
                // Shift the keys of the next field past the whole index range of this one.
                bias += featureBuffer.Length;
            }

            if (norm)
            {
                float featureNorm = MathUtils.Sqrt(squaredNorm);
                // A norm of exactly zero means every loaded value is zero; dividing would turn them all into
                // NaN (0/0) while this method still reports the example as valid. Leave the zeros untouched.
                // A NaN or infinite norm still goes through the division, as it always did.
                if (featureNorm != 0)
                {
                    for (int i = 0; i < count; i++)
                        featureValueBuffer[i] /= featureNorm;
                }
            }
            return FloatUtils.IsFinite(annihilation);
        }
    }
 
    /// <summary>
    /// Binds a <see cref="FieldAwareFactorizationMachineModelParameters"/> to the feature-role columns of a
    /// <see cref="RoleMappedSchema"/> and produces rows with two numeric output columns: column 0 carries the value
    /// returned by the model's CalculateResponse, and column 1 carries <c>MathUtils.SigmoidSlow</c> of that value.
    /// </summary>
    internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBoundRowMapper
    {
        private readonly FieldAwareFactorizationMachineModelParameters _pred;

        /// <summary>Role-mapped schema restricted to the feature-role columns of the schema given to the constructor.</summary>
        public RoleMappedSchema InputRoleMappedSchema { get; }

        /// <summary>The two-column output schema given to the constructor.</summary>
        public DataViewSchema OutputSchema { get; }

        /// <summary>The underlying data view schema of <see cref="InputRoleMappedSchema"/>.</summary>
        public DataViewSchema InputSchema => InputRoleMappedSchema.Schema;

        /// <summary>The model parameters this mapper was created from.</summary>
        public ISchemaBindableMapper Bindable => _pred;

        // Feature-role columns, in the order returned by RoleMappedSchema.GetColumns.
        private readonly DataViewSchema.Column[] _columns;
        // Input schema indices of the feature columns, in the same order as _columns.
        private readonly List<int> _inputColumnIndexes;
        // Retained from the constructor; not otherwise used by the members of this class.
        private readonly IHostEnvironment _env;

        /// <summary>
        /// Creates the mapper and resolves the input schema index of every feature-role column.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="schema">Role-mapped input schema; its feature-role columns become the model's input fields.</param>
        /// <param name="outputSchema">Output schema; must contain exactly two columns, both of a numeric type.</param>
        /// <param name="pred">The model parameters used to compute the outputs.</param>
        public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleMappedSchema schema,
            DataViewSchema outputSchema, FieldAwareFactorizationMachineModelParameters pred)
        {
            Contracts.AssertValue(env);
            Contracts.AssertValue(schema);
            Contracts.CheckParam(outputSchema.Count == 2, nameof(outputSchema));
            Contracts.CheckParam(outputSchema[0].Type is NumberDataViewType, nameof(outputSchema));
            Contracts.CheckParam(outputSchema[1].Type is NumberDataViewType, nameof(outputSchema));
            Contracts.AssertValue(pred);

            _env = env;
            _columns = schema.GetColumns(RoleMappedSchema.ColumnRole.Feature).ToArray();
            _pred = pred;

            var inputFeatureColumns = _columns.Select(c => new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, c.Name)).ToList();
            InputRoleMappedSchema = new RoleMappedSchema(schema.Schema, inputFeatureColumns);
            OutputSchema = outputSchema;

            // Resolve each feature column name to its index; a name that cannot be resolved is skipped.
            _inputColumnIndexes = new List<int>();
            foreach (var kvp in inputFeatureColumns)
            {
                if (schema.Schema.TryGetColumnIndex(kvp.Value, out int index))
                    _inputColumnIndexes.Add(index);
            }
        }

        /// <summary>
        /// Creates an output row over <paramref name="input"/> with getters for the requested output columns.
        /// </summary>
        /// <param name="input">The input row from which the feature vectors are read.</param>
        /// <param name="activeColumns">The output columns (index 0 and/or 1) that need getters.</param>
        /// <returns>A row over <see cref="OutputSchema"/>; getters of columns that were not requested are left null.</returns>
        DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
        {
            var activeIndices = activeColumns.Select(c => c.Index).ToArray();
            var active0 = activeIndices.Contains(0);
            var active1 = activeIndices.Contains(1);

            var getters = new Delegate[2];

            // Nothing requested: no getters are needed, so skip allocating the scratch buffers altogether.
            if (!active0 && !active1)
                return new SimpleRow(OutputSchema, input, getters);

            // Scratch state captured by the getters below.
            var latentSum = new AlignedArray(_pred.FieldCount * _pred.FieldCount * _pred.LatentDimAligned, 16);
            var featureBuffer = new VBuffer<float>();
            var featureFieldBuffer = new int[_pred.FeatureCount];
            var featureIndexBuffer = new int[_pred.FeatureCount];
            var featureValueBuffer = new float[_pred.FeatureCount];

            // One vector getter per field.
            // NOTE(review): assumes at least _pred.FieldCount feature columns were resolved in the constructor - confirm at the bind site.
            var inputGetters = new ValueGetter<VBuffer<float>>[_pred.FieldCount];
            for (int f = 0; f < _pred.FieldCount; f++)
                inputGetters[f] = input.GetGetter<VBuffer<float>>(input.Schema[_inputColumnIndexes[f]]);

            if (active0)
            {
                ValueGetter<float> responseGetter = (ref float value) =>
                {
                    value = _pred.CalculateResponse(inputGetters, featureBuffer, featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum);
                };
                getters[0] = responseGetter;
            }
            if (active1)
            {
                ValueGetter<float> probGetter = (ref float value) =>
                {
                    value = _pred.CalculateResponse(inputGetters, featureBuffer, featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum);
                    value = MathUtils.SigmoidSlow(value);
                };
                getters[1] = probGetter;
            }

            return new SimpleRow(OutputSchema, input, getters);
        }

        /// <summary>
        /// Given a set of columns, return the input columns that are needed to generate those output columns.
        /// </summary>
        /// <param name="columns">The requested output columns.</param>
        /// <returns>No columns when <paramref name="columns"/> is empty; otherwise every resolved feature column of the input schema.</returns>
        IEnumerable<DataViewSchema.Column> ISchemaBoundRowMapper.GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> columns)
        {
            // Any() stops at the first element instead of counting the whole sequence.
            if (!columns.Any())
                return Enumerable.Empty<DataViewSchema.Column>();

            return InputSchema.Where(col => _inputColumnIndexes.Contains(col.Index));
        }

        /// <summary>
        /// Returns the (role, column name) pairs of <see cref="InputRoleMappedSchema"/>.
        /// </summary>
        public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
        {
            return InputRoleMappedSchema.GetColumnRoles().Select(kvp => new KeyValuePair<RoleMappedSchema.ColumnRole, string>(kvp.Key, kvp.Value.Name));
        }
    }
}