// File: FactorizationMachine\FactorizationMachineInterface.netcoreapp.cs
// Project: src\src\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj (Microsoft.ML.CpuMath)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Runtime.Intrinsics.X86;
using Microsoft.ML.Internal.CpuMath.Core;
 
namespace Microsoft.ML.Internal.CpuMath.FactorizationMachine
{
    /// <summary>
    /// Field-aware factorization machine kernel entry points for this target. Each method pins its managed
    /// buffers and dispatches to the AVX implementation when <see cref="Avx.IsSupported"/> is true, and to
    /// the native implementation otherwise.
    /// </summary>
    [BestFriend]
    internal static unsafe partial class FieldAwareFactorizationMachineInterface
    {
        /// <summary>
        /// Computes the per-example intermediate variables by forwarding to the AVX or native kernel.
        /// </summary>
        /// <param name="fieldCount">Number of fields, forwarded unchanged to the kernel.</param>
        /// <param name="latentDim">Latent vector length, forwarded unchanged to the kernel.</param>
        /// <param name="count">
        /// Number of leading entries of <paramref name="fieldIndices"/>, <paramref name="featureIndices"/>
        /// and <paramref name="featureValues"/> that the kernel reads.
        /// </param>
        /// <param name="fieldIndices">Field index of each active feature.</param>
        /// <param name="featureIndices">Feature index of each active feature.</param>
        /// <param name="featureValues">Value of each active feature.</param>
        /// <param name="linearWeights">Linear weights read by the kernel.</param>
        /// <param name="latentWeights">Latent weights; the kernel receives the pointer obtained through <c>Ptr</c>.</param>
        /// <param name="latentSum">
        /// Buffer passed to the kernel as a writable pointer. NOTE(review): presumably receives the per-field
        /// latent sums later consumed by <see cref="CalculateGradientAndUpdate"/> — confirm against the kernel.
        /// </param>
        /// <param name="response">Passed by reference to the kernel, which presumably writes the model output.</param>
        public static void CalculateIntermediateVariables(int fieldCount, int latentDim, int count, int[] fieldIndices, int[] featureIndices, float[] featureValues,
            float[] linearWeights, AlignedArray latentWeights, AlignedArray latentSum, ref float response)
        {
            Contracts.AssertNonEmpty(fieldIndices);
            Contracts.AssertNonEmpty(featureValues);
            Contracts.AssertNonEmpty(featureIndices);
            Contracts.AssertNonEmpty(linearWeights);
            Contracts.Assert(Compat(latentWeights));
            Contracts.Assert(Compat(latentSum));

            // The kernels walk these buffers through raw pointers without bounds checks, so make sure the
            // first `count` entries actually exist (checked in debug builds only, like the contracts above).
            Contracts.Assert(0 <= count && count <= fieldIndices.Length);
            Contracts.Assert(count <= featureIndices.Length);
            Contracts.Assert(count <= featureValues.Length);

            // The enclosing class is already `unsafe`, so no nested unsafe block is needed around `fixed`.
            fixed (int* pf = &fieldIndices[0])
            fixed (int* pi = &featureIndices[0])
            fixed (float* px = &featureValues[0])
            fixed (float* pw = &linearWeights[0])
            fixed (float* pv = &latentWeights.Items[0])
            fixed (float* pq = &latentSum.Items[0])
            fixed (float* pr = &response)
            {
                if (Avx.IsSupported)
                    AvxIntrinsics.CalculateIntermediateVariables(pf, pi, px, pw, Ptr(latentWeights, pv), Ptr(latentSum, pq), pr, fieldCount, latentDim, count);
                else
                    CalculateIntermediateVariablesNative(fieldCount, latentDim, count, pf, pi, px, pw, Ptr(latentWeights, pv), Ptr(latentSum, pq), pr);
            }
        }

        /// <summary>
        /// Computes gradients for one example and applies the update by forwarding to the AVX or native kernel.
        /// </summary>
        /// <param name="lambdaLinear">Regularization coefficient for the linear weights, forwarded to the kernel.</param>
        /// <param name="lambdaLatent">Regularization coefficient for the latent weights, forwarded to the kernel.</param>
        /// <param name="learningRate">Learning rate, forwarded to the kernel.</param>
        /// <param name="fieldCount">Number of fields, forwarded unchanged to the kernel.</param>
        /// <param name="latentDim">Latent vector length, forwarded unchanged to the kernel.</param>
        /// <param name="weight">Example weight, forwarded to the kernel.</param>
        /// <param name="count">
        /// Number of leading entries of <paramref name="fieldIndices"/>, <paramref name="featureIndices"/>
        /// and <paramref name="featureValues"/> that the kernel reads.
        /// </param>
        /// <param name="fieldIndices">Field index of each active feature.</param>
        /// <param name="featureIndices">Feature index of each active feature.</param>
        /// <param name="featureValues">Value of each active feature.</param>
        /// <param name="latentSum">
        /// Latent sums. NOTE(review): presumably the buffer filled by
        /// <see cref="CalculateIntermediateVariables"/> for the same example — confirm.
        /// </param>
        /// <param name="slope">Loss slope for this example, forwarded to the kernel.</param>
        /// <param name="linearWeights">Linear weights, passed as a writable pointer to the kernel.</param>
        /// <param name="latentWeights">Latent weights, passed as a writable pointer to the kernel.</param>
        /// <param name="linearAccumulatedSquaredGrads">Accumulated squared gradients for the linear weights.</param>
        /// <param name="latentAccumulatedSquaredGrads">Accumulated squared gradients for the latent weights.</param>
        public static void CalculateGradientAndUpdate(float lambdaLinear, float lambdaLatent, float learningRate, int fieldCount, int latentDim,
            float weight, int count, int[] fieldIndices, int[] featureIndices, float[] featureValues, AlignedArray latentSum, float slope,
            float[] linearWeights, AlignedArray latentWeights, float[] linearAccumulatedSquaredGrads, AlignedArray latentAccumulatedSquaredGrads)
        {
            Contracts.AssertNonEmpty(fieldIndices);
            Contracts.AssertNonEmpty(featureIndices);
            Contracts.AssertNonEmpty(featureValues);
            Contracts.Assert(Compat(latentSum));
            Contracts.AssertNonEmpty(linearWeights);
            Contracts.Assert(Compat(latentWeights));
            Contracts.AssertNonEmpty(linearAccumulatedSquaredGrads);
            Contracts.Assert(Compat(latentAccumulatedSquaredGrads));

            // The kernels walk these buffers through raw pointers without bounds checks, so make sure the
            // first `count` entries actually exist (checked in debug builds only, like the contracts above).
            Contracts.Assert(0 <= count && count <= fieldIndices.Length);
            Contracts.Assert(count <= featureIndices.Length);
            Contracts.Assert(count <= featureValues.Length);

            // The enclosing class is already `unsafe`, so no nested unsafe block is needed around `fixed`.
            fixed (int* pf = &fieldIndices[0])
            fixed (int* pi = &featureIndices[0])
            fixed (float* px = &featureValues[0])
            fixed (float* pq = &latentSum.Items[0])
            fixed (float* pw = &linearWeights[0])
            fixed (float* pv = &latentWeights.Items[0])
            fixed (float* phw = &linearAccumulatedSquaredGrads[0])
            fixed (float* phv = &latentAccumulatedSquaredGrads.Items[0])
            {
                if (Avx.IsSupported)
                    AvxIntrinsics.CalculateGradientAndUpdate(pf, pi, px, Ptr(latentSum, pq), pw, Ptr(latentWeights, pv),
                        phw, Ptr(latentAccumulatedSquaredGrads, phv), lambdaLinear, lambdaLatent, learningRate, fieldCount, latentDim, weight, count, slope);
                else
                    CalculateGradientAndUpdateNative(lambdaLinear, lambdaLatent, learningRate, fieldCount, latentDim, weight, count, pf, pi, px,
                        Ptr(latentSum, pq), slope, pw, Ptr(latentWeights, pv), phw, Ptr(latentAccumulatedSquaredGrads, phv));
            }
        }
    }
}