File: FactorizationMachine\AvxIntrinsics.cs
Web Access
Project: src\src\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj (Microsoft.ML.CpuMath)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using Microsoft.ML.Internal.CpuMath.Core;
 
namespace Microsoft.ML.Internal.CpuMath.FactorizationMachine
{
    internal static class AvxIntrinsics
    {
        // 0.5f in every lane; halves the intra-field term before the final horizontal reduction.
        private static readonly Vector256<float> _point5 = Vector256.Create(0.5f);

        /// <summary>
        /// Lane-wise <c>src1 * src2 + src3</c>. Uses a fused multiply-add when FMA is available;
        /// otherwise the product is rounded before the addition, so the two paths may differ in the last bit.
        /// </summary>
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<float> src2, Vector256<float> src3)
        {
            if (Fma.IsSupported)
            {
                return Fma.MultiplyAdd(src1, src2, src3);
            }
            else
            {
                Vector256<float> product = Avx.Multiply(src1, src2);
                return Avx.Add(product, src3);
            }
        }

        /// <summary>
        /// Lane-wise <c>src3 - src1 * src2</c>. Uses a fused negated multiply-add when FMA is available;
        /// otherwise the product is rounded before the subtraction.
        /// </summary>
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> MultiplyAddNegated(Vector256<float> src1, Vector256<float> src2, Vector256<float> src3)
        {
            if (Fma.IsSupported)
            {
                return Fma.MultiplyAddNegated(src1, src2, src3);
            }
            else
            {
                Vector256<float> product = Avx.Multiply(src1, src2);
                return Avx.Subtract(src3, product);
            }
        }

        /// <summary>
        /// Implements Algorithm 1 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf.
        /// Computes the output of the field-aware factorization machine as the sum of a linear part
        /// (the inner product of the selected linear weights and the feature values) and a latent part
        /// (intra-field plus inter-field interactions over all fields).
        /// </summary>
        /// <param name="fieldIndices">Field of each of the <paramref name="count"/> active features.</param>
        /// <param name="featureIndices">Feature index of each of the <paramref name="count"/> active features.</param>
        /// <param name="featureValues">Value of each of the <paramref name="count"/> active features.</param>
        /// <param name="linearWeights">Linear weights, indexed by feature index.</param>
        /// <param name="latentWeights">Latent vectors laid out as [feature][field][latentDim].</param>
        /// <param name="latentSum">Output scratch laid out as [field][field][latentDim]; zeroed here, then filled with
        /// q_f,f' = sum of v_j,f' * x_j over features j in field f.</param>
        /// <param name="response">Receives linear response + latent response.</param>
        /// <param name="fieldCount">Number of possible fields (m).</param>
        /// <param name="latentDim">Length of each latent vector (d). Any value is handled; components past the last
        /// multiple of 8 are processed by scalar tail loops.</param>
        /// <param name="count">Number of active features in this example.</param>
        public static unsafe void CalculateIntermediateVariables(int* fieldIndices, int* featureIndices, float* featureValues,
            float* linearWeights, float* latentWeights, float* latentSum, float* response, int fieldCount, int latentDim, int count)
        {
            Contracts.Assert(Avx.IsSupported);

            // The number of all possible fields.
            int m = fieldCount;
            int d = latentDim;
            int c = count;
            int* pf = fieldIndices;
            int* pi = featureIndices;
            float* px = featureValues;
            float* pw = linearWeights;
            float* pv = latentWeights;
            float* pq = latentSum;
            float linearResponse = 0;
            float latentResponse = 0;

            // Scalar counterparts of tmp / y for the trailing d % 8 components the 8-wide loops cannot reach.
            // Both stay zero when d is a multiple of 8.
            float tmpTail = 0;
            float yTail = 0;

            Unsafe.InitBlock(pq, 0, (uint)(m * m * d * sizeof(float)));

            Vector256<float> y = Vector256<float>.Zero;
            Vector256<float> tmp = Vector256<float>.Zero;

            for (int i = 0; i < c; i++)
            {
                int f = pf[i];
                int j = pi[i];
                float xi = px[i];
                linearResponse += pw[j] * xi;

                Vector256<float> x = Avx.BroadcastScalarToVector256(px + i);
                Vector256<float> xx = Avx.Multiply(x, x);
                float xxi = xi * xi;

                // Start of the j-th feature's latent vectors (one d-length vector per field).
                float* vj = pv + j * m * d;

                // tmp -= <v_j,f, v_j,f> * x * x
                // j-th feature's latent vector in the f-th field hidden space.
                float* vjf = vj + f * d;

                int k;
                for (k = 0; k + 8 <= d; k += 8)
                {
                    Vector256<float> vjfBuffer = Avx.LoadVector256(vjf + k);
                    tmp = MultiplyAddNegated(Avx.Multiply(vjfBuffer, vjfBuffer), xx, tmp);
                }
                for (; k < d; k++)
                {
                    tmpTail -= vjf[k] * vjf[k] * xxi;
                }

                for (int fprime = 0; fprime < m; fprime++)
                {
                    float* vjfprime = vj + fprime * d;
                    float* qffprime = pq + f * m * d + fprime * d;

                    // q_f,f' += v_j,f' * x
                    for (k = 0; k + 8 <= d; k += 8)
                    {
                        Vector256<float> vjfprimeBuffer = Avx.LoadVector256(vjfprime + k);
                        Vector256<float> q = Avx.LoadVector256(qffprime + k);
                        q = MultiplyAdd(vjfprimeBuffer, x, q);
                        Avx.Store(qffprime + k, q);
                    }
                    for (; k < d; k++)
                    {
                        qffprime[k] += vjfprime[k] * xi;
                    }
                }
            }

            for (int f = 0; f < m; f++)
            {
                // Row of q for field f: q_f,0 ... q_f,m-1.
                float* qf = pq + f * m * d;

                // tmp += <q_f,f, q_f,f>
                float* qff = qf + f * d;
                int k;
                for (k = 0; k + 8 <= d; k += 8)
                {
                    Vector256<float> qffBuffer = Avx.LoadVector256(qff + k);

                    // Intra-field interactions.
                    tmp = MultiplyAdd(qffBuffer, qffBuffer, tmp);
                }
                for (; k < d; k++)
                {
                    tmpTail += qff[k] * qff[k];
                }

                // y += <q_f,f', q_f',f>, f != f'
                // This loop handles inter-field interactions; each unordered pair (f, f') is visited once.
                for (int fprime = f + 1; fprime < m; fprime++)
                {
                    float* qffprime = qf + fprime * d;
                    float* qfprimef = pq + fprime * m * d + f * d;
                    for (k = 0; k + 8 <= d; k += 8)
                    {
                        // Inter-field interaction.
                        Vector256<float> qffprimeBuffer = Avx.LoadVector256(qffprime + k);
                        Vector256<float> qfprimefBuffer = Avx.LoadVector256(qfprimef + k);
                        y = MultiplyAdd(qffprimeBuffer, qfprimefBuffer, y);
                    }
                    for (; k < d; k++)
                    {
                        yTail += qffprime[k] * qfprimef[k];
                    }
                }
            }

            y = MultiplyAdd(_point5, tmp, y);

            // Horizontal sum of the 8 lanes: add the swapped 128-bit halves, then two pairwise adds.
            tmp = Avx.Add(y, Avx.Permute2x128(y, y, 1));
            tmp = Avx.HorizontalAdd(tmp, tmp);
            y = Avx.HorizontalAdd(tmp, tmp);
            Sse.StoreScalar(&latentResponse, y.GetLower()); // The lowest slot is the response value.

            // Fold in the scalar remainder (zero when d is a multiple of 8).
            latentResponse += 0.5f * tmpTail + yTail;
            *response = linearResponse + latentResponse;
        }

        /// <summary>
        /// Implements Algorithm 2 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf.
        /// Computes the stochastic gradient for one example and applies an ADAGRAD update to the linear
        /// weights and to every latent vector v_j,f' of each active feature j.
        /// </summary>
        /// <param name="fieldIndices">Field of each of the <paramref name="count"/> active features.</param>
        /// <param name="featureIndices">Feature index of each of the <paramref name="count"/> active features.</param>
        /// <param name="featureValues">Value of each of the <paramref name="count"/> active features.</param>
        /// <param name="latentSum">q values laid out as [field][field][latentDim]; presumably those produced by
        /// <see cref="CalculateIntermediateVariables"/> for the same example (NOTE(review): confirm with caller).</param>
        /// <param name="linearWeights">Linear weights, indexed by feature index; updated in place.</param>
        /// <param name="latentWeights">Latent vectors laid out as [feature][field][latentDim]; updated in place.</param>
        /// <param name="linearAccumulatedSquaredGrads">ADAGRAD accumulators for the linear weights; updated in place.</param>
        /// <param name="latentAccumulatedSquaredGrads">ADAGRAD accumulators for the latent weights (same layout); updated in place.</param>
        /// <param name="lambdaLinear">L2 regularization coefficient of the linear weights.</param>
        /// <param name="lambdaLatent">L2 regularization coefficient of the latent weights.</param>
        /// <param name="learningRate">ADAGRAD base learning rate.</param>
        /// <param name="fieldCount">Number of possible fields (m).</param>
        /// <param name="latentDim">Length of each latent vector (d). Any value is handled; components past the last
        /// multiple of 8 are processed by scalar tail loops.</param>
        /// <param name="weight">Multiplier applied to every gradient (presumably the example weight).</param>
        /// <param name="count">Number of active features in this example.</param>
        /// <param name="slope">Factor multiplying the feature value in the loss gradient (presumably dLoss/dResponse).</param>
        /// <remarks>
        /// The vectorized latent update uses <see cref="Avx.ReciprocalSqrt(Vector256{float})"/>, which is an approximation;
        /// the scalar tail and the linear update use the exact <see cref="MathF.Sqrt(float)"/>.
        /// </remarks>
        public static unsafe void CalculateGradientAndUpdate(int* fieldIndices, int* featureIndices, float* featureValues, float* latentSum, float* linearWeights,
            float* latentWeights, float* linearAccumulatedSquaredGrads, float* latentAccumulatedSquaredGrads, float lambdaLinear, float lambdaLatent, float learningRate,
            int fieldCount, int latentDim, float weight, int count, float slope)
        {
            Contracts.Assert(Avx.IsSupported);

            int m = fieldCount;
            int d = latentDim;
            int* pf = fieldIndices;
            int* pi = featureIndices;
            float* px = featureValues;
            float* pq = latentSum;
            float* pw = linearWeights;
            float* pv = latentWeights;
            float* phw = linearAccumulatedSquaredGrads;
            float* phv = latentAccumulatedSquaredGrads;

            Vector256<float> wei = Vector256.Create(weight);
            Vector256<float> s = Vector256.Create(slope);
            Vector256<float> lr = Vector256.Create(learningRate);
            Vector256<float> lambdav = Vector256.Create(lambdaLatent);

            for (int i = 0; i < count; i++)
            {
                int f = pf[i];
                int j = pi[i];
                float xi = px[i];

                // Calculate gradient of linear term w_j.
                float g = weight * (lambdaLinear * pw[j] + slope * xi);

                // Accumulate the gradient of the linear term.
                phw[j] += g * g;

                // Perform ADAGRAD update rule to adjust linear term.
                pw[j] -= learningRate / MathF.Sqrt(phw[j]) * g;

                // Update latent term, v_j,f', f'=1,...,m.
                Vector256<float> x = Avx.BroadcastScalarToVector256(px + i);

                // slope * x depends only on the current feature, so compute it once rather than per field.
                Vector256<float> sx = Avx.Multiply(s, x);
                float sxi = slope * xi;

                // Start of the j-th feature's latent vectors and their accumulators.
                int vjBias = j * m * d;
                float* vj = pv + vjBias;
                float* hvj = phv + vjBias;

                for (int fprime = 0; fprime < m; fprime++)
                {
                    float* vjfprime = vj + fprime * d;
                    float* hvjfprime = hvj + fprime * d;
                    float* qfprimef = pq + fprime * m * d + f * d;

                    int k;
                    for (k = 0; k + 8 <= d; k += 8)
                    {
                        Vector256<float> v = Avx.LoadVector256(vjfprime + k);
                        Vector256<float> q = Avx.LoadVector256(qfprimef + k);

                        // Calculate L2-norm regularization's gradient.
                        Vector256<float> gLatent = Avx.Multiply(lambdav, v);

                        Vector256<float> tmp = q;

                        // Calculate loss function's gradient. For the feature's own field, remove its
                        // self-interaction contribution: q_f,f - v_j,f * x.
                        if (fprime == f)
                            tmp = MultiplyAddNegated(v, x, q);
                        gLatent = MultiplyAdd(sx, tmp, gLatent);
                        gLatent = Avx.Multiply(wei, gLatent);

                        // Accumulate the gradient of latent vectors.
                        Vector256<float> h = MultiplyAdd(gLatent, gLatent, Avx.LoadVector256(hvjfprime + k));

                        // Perform ADAGRAD update rule to adjust latent vector.
                        v = MultiplyAddNegated(lr, Avx.Multiply(Avx.ReciprocalSqrt(h), gLatent), v);
                        Avx.Store(vjfprime + k, v);
                        Avx.Store(hvjfprime + k, h);
                    }

                    // Scalar tail: same update for the components past the last multiple of 8.
                    for (; k < d; k++)
                    {
                        float v = vjfprime[k];
                        float q = qfprimef[k];
                        float diff = fprime == f ? q - v * xi : q;
                        float gLatent = weight * (sxi * diff + lambdaLatent * v);
                        float h = gLatent * gLatent + hvjfprime[k];
                        vjfprime[k] = v - learningRate / MathF.Sqrt(h) * gLatent;
                        hvjfprime[k] = h;
                    }
                }
            }
        }
    }
}