// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
// The exported function names need to be unique (can't be disambiguated based on signature), hence
// we introduce suffix letters to indicate the general patterns used.
// * A suffix means aligned and padded for SSE operations.
// * U suffix means unaligned and unpadded.
// * S suffix means sparse (unaligned) vector.
// * P suffix means sparse (unaligned) partial vector - the vector is only part of a larger sparse vector.
// * R suffix means sparse matrix.
// * C suffix means convolution matrix.
// * D suffix means convolution matrix, with implicit source padding.
// * Tran means the matrix is transposed.
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using Microsoft.ML.Internal.CpuMath.Core;
#pragma warning disable CS8981 // The type name only contains lower-cased ascii characters. Such names may become reserved for the language.
using nuint = System.UInt64;
#pragma warning restore CS8981 // The type name only contains lower-cased ascii characters. Such names may become reserved for the language.
namespace Microsoft.ML.Internal.CpuMath
{
internal static class SseIntrinsics
{
public static readonly uint[] LeadingAlignmentMask = new uint[16]
{
0x00000000, 0x00000000, 0x00000000, 0x00000000,
0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000,
0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000,
};
public static readonly uint[] TrailingAlignmentMask = new uint[16]
{
0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
};
// The count of bytes in Vector128<T>, corresponding to _cbAlign in AlignedArray
private const int Vector128Alignment = 16;
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static bool HasCompatibleAlignment(AlignedArray alignedArray)
{
Contracts.AssertValue(alignedArray);
Contracts.Assert(alignedArray.Size > 0);
return (alignedArray.CbAlign % Vector128Alignment) == 0;
}
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
{
Contracts.AssertValue(alignedArray);
float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
Contracts.Assert(((long)alignedBase & (Vector128Alignment - 1)) == 0);
return alignedBase;
}
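// Mask that clears the sign bit of each float; ANDing a vector with it computes the element-wise absolute value.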
internal static readonly Vector128<float> AbsMask128 = Sse2.IsSupported ?
Vector128.Create(0x7FFFFFFF).AsSingle() :
Vector128.Create(BitConverter.Int32BitsToSingle(0x7FFFFFFF));
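// Loads src[idx[0]] into the lowest element of a vector; the upper elements are zero.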
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
internal static unsafe Vector128<float> Load1(float* src, int* idx)
=> Vector128.CreateScalar(src[idx[0]]);
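// Gathers src[idx[0]] through src[idx[3]] into the four elements of a vector.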
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
internal static unsafe Vector128<float> Load4(float* src, int* idx)
=> Vector128.Create(src[idx[0]], src[idx[1]], src[idx[2]], src[idx[3]]);
// The control byte shuffles the four 32-bit floats of x: ABCD -> BCDA.
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
internal static Vector128<float> Rotate(in Vector128<float> x)
=> Sse.Shuffle(x, x, 0x39);
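// Scatters the four elements of x to dst[idx[0]] through dst[idx[3]].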
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
internal static unsafe void Store4(in Vector128<float> x, float* dst, int* idx)
{
Sse.StoreScalar(dst + idx[0], x);
Vector128<float> rotated = Rotate(in x);
Sse.StoreScalar(dst + idx[1], rotated);
rotated = Rotate(in rotated);
Sse.StoreScalar(dst + idx[2], rotated);
rotated = Rotate(in rotated);
Sse.StoreScalar(dst + idx[3], rotated);
}
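// Horizontally sums the four elements of 'vector'; the lowest element of the result holds the total (e.g. [1, 2, 3, 4] -> 10).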
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
internal static Vector128<float> VectorSum128(in Vector128<float> vector)
{
if (Sse3.IsSupported)
{
Vector128<float> partialSum = Sse3.HorizontalAdd(vector, vector);
return Sse3.HorizontalAdd(partialSum, partialSum);
}
else
{
Vector128<float> partialSum = Sse.Add(vector, Sse.MoveHighToLow(vector, vector));
// The control byte shuffles the four 32-bit floats of partialSum: ABCD -> BADC.
return Sse.Add(partialSum, Sse.Shuffle(partialSum, partialSum, 0xB1));
}
}
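// Horizontal maximum: the lowest element of the result holds the largest of the four elements (e.g. [1, 9, 3, 4] -> 9).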
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
internal static Vector128<float> VectorMax128(in Vector128<float> vector)
{
// The control byte shuffles the four 32-bit floats of partialMax: ABCD -> BADC.
Vector128<float> x1 = Sse.Shuffle(vector, vector, 0xB1);
// Performs element-wise maximum operation: The 1st and 3rd 32-bit slots become
// max(A, B) and max(C, D).
Vector128<float> partialMax = Sse.Max(vector, x1);
// The control byte shuffles the four 32-bit floats of partialMax: ABCD -> CAAA.
x1 = Sse.Shuffle(partialMax, partialMax, 0x02);
// Performs element-wise maximum operation: The 1st 32-bit slot becomes
// max(A, B, C, D).
return Sse.MaxScalar(partialMax, x1);
}
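// Soft-thresholding used by the SDCA L1 updates: per element, returns sign(xDst1) * max(|xDst1| - xThreshold, 0).
// For example, with threshold 0.5: 1.2 -> 0.7, -1.2 -> -0.7, 0.3 -> 0.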
[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
internal static Vector128<float> GetNewDst128(in Vector128<float> xDst1, in Vector128<float> xThreshold)
{
Vector128<float> signMask = Vector128.Create(-0.0f); // 0x8000 0000
Vector128<float> xSign = Sse.And(xDst1, signMask); // result = 0x8000 0000 if xDst1 is negative or 0x0000 0000 otherwise
Vector128<float> xDst1Abs = Sse.Xor(xDst1, xSign);
Vector128<float> xCond = Sse.CompareGreaterThan(xDst1Abs, xThreshold); // result = 0xFFFF FFFF if true
Vector128<float> x2 = Sse.Xor(xSign, xThreshold); // -xThreshold if xDst1 is negative and +xThreshold otherwise
return Sse.And(Sse.Subtract(xDst1, x2), xCond);
}
// Multiply matrix times vector into vector.
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(HasCompatibleAlignment(mat));
Contracts.Assert(HasCompatibleAlignment(src));
Contracts.Assert(HasCompatibleAlignment(dst));
fixed (float* pSrcStart = &src.Items[0])
fixed (float* pDstStart = &dst.Items[0])
fixed (float* pMatStart = &mat.Items[0])
{
float* psrc = GetAlignedBase(src, pSrcStart);
float* pdst = GetAlignedBase(dst, pDstStart);
float* pmat = GetAlignedBase(mat, pMatStart);
float* pSrcEnd = psrc + ccol;
float* pDstEnd = pdst + crow;
float* pDstCurrent = pdst;
float* pMatCurrent = pmat;
while (pDstCurrent < pDstEnd)
{
Vector128<float> res0 = Vector128<float>.Zero;
Vector128<float> res1 = res0;
Vector128<float> res2 = res0;
Vector128<float> res3 = res0;
float* pSrcCurrent = psrc;
while (pSrcCurrent < pSrcEnd)
{
float* pMatTemp = pMatCurrent;
Vector128<float> x01 = Sse.LoadAlignedVector128(pMatTemp);
Vector128<float> x11 = Sse.LoadAlignedVector128(pMatTemp += ccol);
Vector128<float> x21 = Sse.LoadAlignedVector128(pMatTemp += ccol);
Vector128<float> x31 = Sse.LoadAlignedVector128(pMatTemp += ccol);
Vector128<float> x02 = Sse.LoadAlignedVector128(pSrcCurrent);
res0 = Sse.Add(res0, Sse.Multiply(x01, x02));
res1 = Sse.Add(res1, Sse.Multiply(x11, x02));
res2 = Sse.Add(res2, Sse.Multiply(x21, x02));
res3 = Sse.Add(res3, Sse.Multiply(x31, x02));
pSrcCurrent += 4;
pMatCurrent += 4;
}
// Add up the entries of each, with the 4 results in res0
res0 = Sse3.HorizontalAdd(res0, res1);
res2 = Sse3.HorizontalAdd(res2, res3);
res0 = Sse3.HorizontalAdd(res0, res2);
Sse.StoreAligned(pDstCurrent, res0);
pDstCurrent += 4;
pMatCurrent += 3 * ccol;
}
}
}
// Multiply matrix times a partial sparse source vector into vector.
public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(HasCompatibleAlignment(mat));
Contracts.Assert(HasCompatibleAlignment(src));
Contracts.Assert(HasCompatibleAlignment(dst));
// REVIEW: For extremely sparse inputs, interchanging the loops would
// likely be more efficient.
fixed (float* pSrcStart = &src.Items[0])
fixed (float* pDstStart = &dst.Items[0])
fixed (float* pMatStart = &mat.Items[0])
fixed (int* pposSrc = &rgposSrc[0])
{
float* psrc = GetAlignedBase(src, pSrcStart);
float* pdst = GetAlignedBase(dst, pDstStart);
float* pmat = GetAlignedBase(mat, pMatStart);
int* pposMin = pposSrc + iposMin;
int* pposEnd = pposSrc + iposEnd;
float* pDstEnd = pdst + crow;
float* pm0 = pmat - posMin;
float* pSrcCurrent = psrc - posMin;
float* pDstCurrent = pdst;
while (pDstCurrent < pDstEnd)
{
float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
float* pm3 = pm2 + ccol;
Vector128<float> result = Vector128<float>.Zero;
int* ppos = pposMin;
while (ppos < pposEnd)
{
int col = *ppos;
Vector128<float> x1 = Vector128.Create(pm0[col], pm1[col], pm2[col], pm3[col]);
Vector128<float> x2 = Vector128.Create(pSrcCurrent[col]);
x2 = Sse.Multiply(x2, x1);
result = Sse.Add(result, x2);
ppos++;
}
Sse.StoreAligned(pDstCurrent, result);
pDstCurrent += 4;
pm0 += 4 * ccol;
}
}
}
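// Multiply the transposed matrix times vector into vector: dst = mat' * src, where src has ccol entries and dst has crow entries.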
public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(HasCompatibleAlignment(mat));
Contracts.Assert(HasCompatibleAlignment(src));
Contracts.Assert(HasCompatibleAlignment(dst));
fixed (float* pSrcStart = &src.Items[0])
fixed (float* pDstStart = &dst.Items[0])
fixed (float* pMatStart = &mat.Items[0])
{
float* psrc = GetAlignedBase(src, pSrcStart);
float* pdst = GetAlignedBase(dst, pDstStart);
float* pmat = GetAlignedBase(mat, pMatStart);
float* pSrcEnd = psrc + ccol;
float* pDstEnd = pdst + crow;
float* pSrcCurrent = psrc;
float* pMatCurrent = pmat;
Vector128<float> x01 = Sse.LoadAlignedVector128(pSrcCurrent);
// Replicate each 32-bit slot of x01 (ABCD) into its own register.
Vector128<float> x11 = Sse.Shuffle(x01, x01, 0x55); // B
Vector128<float> x21 = Sse.Shuffle(x01, x01, 0xAA); // C
Vector128<float> x31 = Sse.Shuffle(x01, x01, 0xFF); // D
x01 = Sse.Shuffle(x01, x01, 0x00); // A
pSrcCurrent += 4;
float* pDstCurrent = pdst;
while (pDstCurrent < pDstEnd)
{
float* pMatTemp = pMatCurrent;
Vector128<float> x02 = Sse.LoadAlignedVector128(pMatTemp);
Vector128<float> x12 = Sse.LoadAlignedVector128(pMatTemp += crow);
Vector128<float> x22 = Sse.LoadAlignedVector128(pMatTemp += crow);
Vector128<float> x32 = Sse.LoadAlignedVector128(pMatTemp += crow);
x02 = Sse.Multiply(x01, x02);
x12 = Sse.Multiply(x11, x12);
x22 = Sse.Multiply(x21, x22);
x32 = Sse.Multiply(x31, x32);
x02 = Sse.Add(x02, x12);
x22 = Sse.Add(x22, x32);
x02 = Sse.Add(x02, x22);
Sse.StoreAligned(pDstCurrent, x02);
pDstCurrent += 4;
pMatCurrent += 4;
}
pMatCurrent += 3 * crow;
while (pSrcCurrent < pSrcEnd)
{
x01 = Sse.LoadAlignedVector128(pSrcCurrent);
// Replicate each 32-bit slot of x01 (ABCD) into its own register.
x11 = Sse.Shuffle(x01, x01, 0x55); // B
x21 = Sse.Shuffle(x01, x01, 0xAA); // C
x31 = Sse.Shuffle(x01, x01, 0xFF); // D
x01 = Sse.Shuffle(x01, x01, 0x00); // A
pDstCurrent = pdst;
while (pDstCurrent < pDstEnd)
{
float* pMatTemp = pMatCurrent;
Vector128<float> x02 = Sse.LoadAlignedVector128(pMatTemp);
Vector128<float> x12 = Sse.LoadAlignedVector128(pMatTemp += crow);
Vector128<float> x22 = Sse.LoadAlignedVector128(pMatTemp += crow);
Vector128<float> x32 = Sse.LoadAlignedVector128(pMatTemp += crow);
Vector128<float> x3 = Sse.LoadAlignedVector128(pDstCurrent);
x02 = Sse.Multiply(x01, x02);
x12 = Sse.Multiply(x11, x12);
x22 = Sse.Multiply(x21, x22);
x32 = Sse.Multiply(x31, x32);
x02 = Sse.Add(x02, x12);
x22 = Sse.Add(x22, x32);
x02 = Sse.Add(x02, x22);
x3 = Sse.Add(x02, x3);
Sse.StoreAligned(pDstCurrent, x3);
pDstCurrent += 4;
pMatCurrent += 4;
}
pMatCurrent += 3 * crow;
pSrcCurrent += 4;
}
}
}
// dst[i] += scalar
public static unsafe void AddScalarU(float scalar, Span<float> dst)
{
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pDstEnd = pdst + dst.Length;
float* pDstCurrent = pdst;
float* pVectorizationEnd = pDstEnd - 4;
Vector128<float> scalarVector = Vector128.Create(scalar);
while (pDstCurrent <= pVectorizationEnd)
{
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
dstVector = Sse.Add(dstVector, scalarVector);
Sse.Store(pDstCurrent, dstVector);
pDstCurrent += 4;
}
while (pDstCurrent < pDstEnd)
{
Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
dstVector = Sse.AddScalar(dstVector, scalarVector);
Sse.StoreScalar(pDstCurrent, dstVector);
pDstCurrent++;
}
}
}
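// dst[i] *= scale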
public static unsafe void Scale(float scale, Span<float> dst)
{
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
fixed (float* pd = &MemoryMarshal.GetReference(dst))
{
float* pDstCurrent = pd;
int length = dst.Length;
Vector128<float> scaleVector128 = Vector128.Create(scale);
nuint address = (nuint)(pd);
int misalignment = (int)(address % 16);
int remainder = 0;
if ((misalignment & 3) != 0)
{
// Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
remainder = length % 4;
for (float* pEnd = pd + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 4)
{
Vector128<float> temp = Sse.LoadVector128(pDstCurrent);
temp = Sse.Multiply(scaleVector128, temp);
Sse.Store(pDstCurrent, temp);
}
}
else
{
if (misalignment != 0)
{
// Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
// masking any elements that will be included in the first aligned read
misalignment >>= 2;
misalignment = 4 - misalignment;
Vector128<float> result = Sse.LoadVector128(pDstCurrent);
Vector128<float> leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4));
Vector128<float> trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4));
Vector128<float> temp = Sse.And(result, trailingMask);
result = Sse.Multiply(scaleVector128, result);
// Masking operation is done at the end to avoid doing an Or operation with negative Zero.
result = Sse.And(result, leadingMask);
result = Sse.Or(result, temp);
Sse.Store(pDstCurrent, result);
pDstCurrent += misalignment;
length -= misalignment;
}
if (length > 3)
{
// Handle all the 128-bit blocks that we can now that we have offset to an aligned address
remainder = length % 4;
for (float* pEnd = pDstCurrent + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 4)
{
// If we aren't using the VEX-encoding, the JIT will only fold away aligned loads
// (due to semantics of the legacy encoding).
// We don't need an assert, since the instruction will throw for unaligned inputs.
Vector128<float> temp = Sse.LoadAlignedVector128(pDstCurrent);
temp = Sse.Multiply(scaleVector128, temp);
Sse.Store(pDstCurrent, temp);
}
}
else
{
// Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not
// 128-bit aligned. This means we can't do any aligned loads and will just end up doing two
// unaligned loads where we mask the input each time.
remainder = length;
}
}
if (remainder != 0)
{
// Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
// unaligned load will read to the end of the array and then mask out any elements already processed
pDstCurrent -= (4 - remainder);
Vector128<float> result = Sse.LoadVector128(pDstCurrent);
Vector128<float> trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4));
Vector128<float> leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4));
Vector128<float> temp = Sse.And(result, leadingMask);
result = Sse.Multiply(scaleVector128, result);
// Masking operation is done at the end to avoid doing an Or operation with negative Zero.
result = Sse.And(result, trailingMask);
result = Sse.Or(result, temp);
Sse.Store(pDstCurrent, result);
}
}
}
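// dst[i] = scale * src[i]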
public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pDstEnd = pdst + count;
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pVectorizationEnd = pDstEnd - 4;
Vector128<float> scaleVector = Vector128.Create(scale);
while (pDstCurrent <= pVectorizationEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Multiply(srcVector, scaleVector);
Sse.Store(pDstCurrent, srcVector);
pSrcCurrent += 4;
pDstCurrent += 4;
}
while (pDstCurrent < pDstEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
srcVector = Sse.MultiplyScalar(srcVector, scaleVector);
Sse.StoreScalar(pDstCurrent, srcVector);
pSrcCurrent++;
pDstCurrent++;
}
}
}
// dst[i] = a * (dst[i] + b)
public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
{
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pDstEnd = pdst + dst.Length;
float* pDstCurrent = pdst;
float* pVectorizationEnd = pDstEnd - 4;
Vector128<float> aVector = Vector128.Create(a);
Vector128<float> bVector = Vector128.Create(b);
while (pDstCurrent <= pVectorizationEnd)
{
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
dstVector = Sse.Add(dstVector, bVector);
dstVector = Sse.Multiply(dstVector, aVector);
Sse.Store(pDstCurrent, dstVector);
pDstCurrent += 4;
}
while (pDstCurrent < pDstEnd)
{
Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
dstVector = Sse.AddScalar(dstVector, bVector);
dstVector = Sse.MultiplyScalar(dstVector, aVector);
Sse.StoreScalar(pDstCurrent, dstVector);
pDstCurrent++;
}
}
}
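// dst[i] += scale * src[i]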
public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pEnd = pdst + count;
Vector128<float> scaleVector = Vector128.Create(scale);
while (pDstCurrent + 4 <= pEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
srcVector = Sse.Multiply(srcVector, scaleVector);
dstVector = Sse.Add(dstVector, srcVector);
Sse.Store(pDstCurrent, dstVector);
pDstCurrent += 4;
pSrcCurrent += 4;
}
while (pDstCurrent < pEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
srcVector = Sse.MultiplyScalar(srcVector, scaleVector);
dstVector = Sse.AddScalar(dstVector, srcVector);
Sse.StoreScalar(pDstCurrent, dstVector);
pDstCurrent++;
pSrcCurrent++;
}
}
}
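// result[i] = dst[i] + scale * src[i]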
public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
Contracts.Assert(count <= result.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (float* pres = &MemoryMarshal.GetReference(result))
{
float* pResEnd = pres + count;
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pResCurrent = pres;
Vector128<float> scaleVector = Vector128.Create(scale);
while (pResCurrent + 4 <= pResEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
srcVector = Sse.Multiply(srcVector, scaleVector);
dstVector = Sse.Add(dstVector, srcVector);
Sse.Store(pResCurrent, dstVector);
pSrcCurrent += 4;
pDstCurrent += 4;
pResCurrent += 4;
}
while (pResCurrent < pResEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
srcVector = Sse.MultiplyScalar(srcVector, scaleVector);
dstVector = Sse.AddScalar(dstVector, srcVector);
Sse.StoreScalar(pResCurrent, dstVector);
pSrcCurrent++;
pDstCurrent++;
pResCurrent++;
}
}
}
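// dst[idx[i]] += scale * src[i]  (sparse source with explicit indices)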
public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
Contracts.Assert(count <= idx.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
int* pIdxCurrent = pidx;
float* pDstCurrent = pdst;
int* pEnd = pidx + count;
Vector128<float> scaleVector = Vector128.Create(scale);
while (pIdxCurrent + 4 <= pEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Load4(pDstCurrent, pIdxCurrent);
srcVector = Sse.Multiply(srcVector, scaleVector);
dstVector = Sse.Add(dstVector, srcVector);
Store4(in dstVector, pDstCurrent, pIdxCurrent);
pIdxCurrent += 4;
pSrcCurrent += 4;
}
while (pIdxCurrent < pEnd)
{
pDstCurrent[*pIdxCurrent] += scale * (*pSrcCurrent);
pIdxCurrent++;
pSrcCurrent++;
}
}
}
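// dst[i] += src[i]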
public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pEnd = psrc + count;
while (pSrcCurrent + 4 <= pEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Vector128<float> result = Sse.Add(srcVector, dstVector);
Sse.Store(pDstCurrent, result);
pSrcCurrent += 4;
pDstCurrent += 4;
}
while (pSrcCurrent < pEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
Vector128<float> result = Sse.AddScalar(srcVector, dstVector);
Sse.StoreScalar(pDstCurrent, result);
pSrcCurrent++;
pDstCurrent++;
}
}
}
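// dst[idx[i]] += src[i]  (sparse source with explicit indices)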
public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
Contracts.Assert(count <= idx.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
int* pIdxCurrent = pidx;
float* pDstCurrent = pdst;
int* pEnd = pidx + count;
while (pIdxCurrent + 4 <= pEnd)
{
Vector128<float> dstVector = Load4(pDstCurrent, pIdxCurrent);
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
dstVector = Sse.Add(dstVector, srcVector);
Store4(in dstVector, pDstCurrent, pIdxCurrent);
pIdxCurrent += 4;
pSrcCurrent += 4;
}
while (pIdxCurrent < pEnd)
{
pDstCurrent[*pIdxCurrent] += *pSrcCurrent;
pIdxCurrent++;
pSrcCurrent++;
}
}
}
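// dst[i] = src1[i] * src2[i]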
public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count)
{
Contracts.Assert(count <= src1.Length);
Contracts.Assert(count <= src2.Length);
Contracts.Assert(count <= dst.Length);
fixed (float* psrc1 = &MemoryMarshal.GetReference(src1))
fixed (float* psrc2 = &MemoryMarshal.GetReference(src2))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrc1Current = psrc1;
float* pSrc2Current = psrc2;
float* pDstCurrent = pdst;
float* pEnd = pdst + count;
while (pDstCurrent + 4 <= pEnd)
{
Vector128<float> src1Vector = Sse.LoadVector128(pSrc1Current);
Vector128<float> src2Vector = Sse.LoadVector128(pSrc2Current);
src2Vector = Sse.Multiply(src1Vector, src2Vector);
Sse.Store(pDstCurrent, src2Vector);
pSrc1Current += 4;
pSrc2Current += 4;
pDstCurrent += 4;
}
while (pDstCurrent < pEnd)
{
Vector128<float> src1Vector = Sse.LoadScalarVector128(pSrc1Current);
Vector128<float> src2Vector = Sse.LoadScalarVector128(pSrc2Current);
src2Vector = Sse.MultiplyScalar(src1Vector, src2Vector);
Sse.StoreScalar(pDstCurrent, src2Vector);
pSrc1Current++;
pSrc2Current++;
pDstCurrent++;
}
}
}
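// Returns the sum of the elements of src, using aligned loads once the pointer has been advanced to a 16-byte boundary.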
public static unsafe float Sum(ReadOnlySpan<float> src)
{
fixed (float* pSrc = &MemoryMarshal.GetReference(src))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
float* pValues = pSrc;
int length = src.Length;
Vector128<float> result = Vector128<float>.Zero;
nuint address = (nuint)(pValues);
int misalignment = (int)(address % 16);
int remainder = 0;
if ((misalignment & 3) != 0)
{
// Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
remainder = length % 4;
for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
{
result = Sse.Add(result, Sse.LoadVector128(pValues));
}
}
else
{
if (misalignment != 0)
{
// Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
// masking any elements that will be included in the first aligned read
misalignment >>= 2;
misalignment = 4 - misalignment;
Vector128<float> mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4));
Vector128<float> temp = Sse.And(mask, Sse.LoadVector128(pValues));
result = Sse.Add(result, temp);
pValues += misalignment;
length -= misalignment;
}
if (length > 3)
{
// Handle all the 128-bit blocks that we can now that we have offset to an aligned address
remainder = length % 4;
for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
{
// If we aren't using the VEX-encoding, the JIT will only fold away aligned loads
// (due to semantics of the legacy encoding).
// We don't need an assert, since the instruction will throw for unaligned inputs.
result = Sse.Add(result, Sse.LoadAlignedVector128(pValues));
}
}
else
{
// Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not
// 128-bit aligned. This means we can't do any aligned loads and will just end up doing two
// unaligned loads where we mask the input each time.
remainder = length;
}
}
if (remainder != 0)
{
// Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
// unaligned load will read to the end of the array and then mask out any elements already processed
pValues -= (4 - remainder);
Vector128<float> mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4));
Vector128<float> temp = Sse.And(mask, Sse.LoadVector128(pValues));
result = Sse.Add(result, temp);
}
// Sum all the elements together and return the result
result = VectorSum128(in result);
return result.ToScalar();
}
}
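// Returns the sum of squares: sum(src[i] * src[i]).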
public static unsafe float SumSqU(ReadOnlySpan<float> src)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Vector128<float> result = Vector128<float>.Zero;
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Add(result, Sse.Multiply(srcVector, srcVector));
pSrcCurrent += 4;
}
result = VectorSum128(in result);
while (pSrcCurrent < pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
result = Sse.AddScalar(result, Sse.MultiplyScalar(srcVector, srcVector));
pSrcCurrent++;
}
return result.ToScalar();
}
}
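// Returns the sum of squared deviations from mean: sum((src[i] - mean)^2).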
public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Vector128<float> result = Vector128<float>.Zero;
Vector128<float> meanVector = Vector128.Create(mean);
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
result = Sse.Add(result, Sse.Multiply(srcVector, srcVector));
pSrcCurrent += 4;
}
result = VectorSum128(in result);
while (pSrcCurrent < pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
srcVector = Sse.SubtractScalar(srcVector, meanVector);
result = Sse.AddScalar(result, Sse.MultiplyScalar(srcVector, srcVector));
pSrcCurrent++;
}
return result.ToScalar();
}
}
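// Returns the sum of absolute values: sum(|src[i]|).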
public static unsafe float SumAbsU(ReadOnlySpan<float> src)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Vector128<float> result = Vector128<float>.Zero;
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Add(result, Sse.And(srcVector, AbsMask128));
pSrcCurrent += 4;
}
result = VectorSum128(in result);
while (pSrcCurrent < pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
result = Sse.AddScalar(result, Sse.And(srcVector, AbsMask128));
pSrcCurrent++;
}
return result.ToScalar();
}
}
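// Returns the sum of absolute deviations from mean: sum(|src[i] - mean|).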
public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Vector128<float> result = Vector128<float>.Zero;
Vector128<float> meanVector = Vector128.Create(mean);
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
result = Sse.Add(result, Sse.And(srcVector, AbsMask128));
pSrcCurrent += 4;
}
result = VectorSum128(in result);
while (pSrcCurrent < pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
srcVector = Sse.SubtractScalar(srcVector, meanVector);
result = Sse.AddScalar(result, Sse.And(srcVector, AbsMask128));
pSrcCurrent++;
}
return result.ToScalar();
}
}
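// Returns the maximum absolute value: max(|src[i]|).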
public static unsafe float MaxAbsU(ReadOnlySpan<float> src)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Vector128<float> result = Vector128<float>.Zero;
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Max(result, Sse.And(srcVector, AbsMask128));
pSrcCurrent += 4;
}
result = VectorMax128(in result);
while (pSrcCurrent < pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
result = Sse.MaxScalar(result, Sse.And(srcVector, AbsMask128));
pSrcCurrent++;
}
return result.ToScalar();
}
}
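// Returns the maximum absolute deviation from mean: max(|src[i] - mean|).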
public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Vector128<float> result = Vector128<float>.Zero;
Vector128<float> meanVector = Vector128.Create(mean);
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
result = Sse.Max(result, Sse.And(srcVector, AbsMask128));
pSrcCurrent += 4;
}
result = VectorMax128(in result);
while (pSrcCurrent < pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
srcVector = Sse.SubtractScalar(srcVector, meanVector);
result = Sse.MaxScalar(result, Sse.And(srcVector, AbsMask128));
pSrcCurrent++;
}
return result.ToScalar();
}
}
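// Returns the dot product of the first count elements: sum(src[i] * dst[i]).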
public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pSrcEnd = psrc + count;
Vector128<float> result = Vector128<float>.Zero;
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
result = Sse.Add(result, Sse.Multiply(srcVector, dstVector));
pSrcCurrent += 4;
pDstCurrent += 4;
}
result = VectorSum128(in result);
while (pSrcCurrent < pSrcEnd)
{
Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
result = Sse.AddScalar(result, Sse.MultiplyScalar(srcVector, dstVector));
pSrcCurrent++;
pDstCurrent++;
}
return result.ToScalar();
}
}
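// Returns the sparse dot product: sum(src[idx[i]] * dst[i]) over the first count index entries.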
public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
Contracts.Assert(count <= idx.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
int* pIdxCurrent = pidx;
int* pIdxEnd = pidx + count;
Vector128<float> result = Vector128<float>.Zero;
while (pIdxCurrent + 4 <= pIdxEnd)
{
Vector128<float> srcVector = Load4(pSrcCurrent, pIdxCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
result = Sse.Add(result, Sse.Multiply(srcVector, dstVector));
pIdxCurrent += 4;
pDstCurrent += 4;
}
result = VectorSum128(in result);
while (pIdxCurrent < pIdxEnd)
{
Vector128<float> srcVector = Load1(pSrcCurrent, pIdxCurrent);
Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
result = Sse.AddScalar(result, Sse.MultiplyScalar(srcVector, dstVector));
pIdxCurrent++;
pDstCurrent++;
}
return result.ToScalar();
}
}
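// Returns the squared Euclidean distance between the first count elements: sum((src[i] - dst[i])^2).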
public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
{
Contracts.Assert(count <= src.Length);
Contracts.Assert(count <= dst.Length);
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pSrcEnd = psrc + count;
Vector128<float> sqDistanceVector = Vector128<float>.Zero;
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> distanceVector = Sse.Subtract(Sse.LoadVector128(pSrcCurrent),
Sse.LoadVector128(pDstCurrent));
sqDistanceVector = Sse.Add(sqDistanceVector,
Sse.Multiply(distanceVector, distanceVector));
pSrcCurrent += 4;
pDstCurrent += 4;
}
sqDistanceVector = VectorSum128(in sqDistanceVector);
float norm = sqDistanceVector.ToScalar();
while (pSrcCurrent < pSrcEnd)
{
float distance = (*pSrcCurrent) - (*pDstCurrent);
norm += distance * distance;
pSrcCurrent++;
pDstCurrent++;
}
return norm;
}
}
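// SDCA L1-regularized update (dense): v[i] += primalUpdate * src[i], then w[i] = soft-threshold of v[i] by threshold (see GetNewDst128).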
public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlySpan<float> src, float threshold, Span<float> v, Span<float> w)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst1 = &MemoryMarshal.GetReference(v))
fixed (float* pdst2 = &MemoryMarshal.GetReference(w))
{
float* pSrcEnd = psrc + count;
float* pSrcCurrent = psrc;
float* pDst1Current = pdst1;
float* pDst2Current = pdst2;
Vector128<float> xPrimal = Vector128.Create(primalUpdate);
Vector128<float> signMask = Vector128.Create(-0.0f); // 0x8000 0000
Vector128<float> xThreshold = Vector128.Create(threshold);
while (pSrcCurrent + 4 <= pSrcEnd)
{
Vector128<float> xSrc = Sse.LoadVector128(pSrcCurrent);
Vector128<float> xDst1 = Sse.LoadVector128(pDst1Current);
xDst1 = Sse.Add(xDst1, Sse.Multiply(xSrc, xPrimal));
Vector128<float> xDst2 = GetNewDst128(xDst1, xThreshold);
Sse.Store(pDst1Current, xDst1);
Sse.Store(pDst2Current, xDst2);
pSrcCurrent += 4;
pDst1Current += 4;
pDst2Current += 4;
}
while (pSrcCurrent < pSrcEnd)
{
*pDst1Current += (*pSrcCurrent) * primalUpdate;
float dst1 = *pDst1Current;
*pDst2Current = Math.Abs(dst1) > threshold ? (dst1 > 0 ? dst1 - threshold : dst1 + threshold) : 0;
pSrcCurrent++;
pDst1Current++;
pDst2Current++;
}
}
}
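// SDCA L1-regularized update (sparse): v[indices[i]] += primalUpdate * src[i], then w[indices[i]] = soft-threshold of v[indices[i]] by threshold.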
public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnlySpan<float> src, ReadOnlySpan<int> indices, float threshold, Span<float> v, Span<float> w)
{
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (int* pidx = &MemoryMarshal.GetReference(indices))
fixed (float* pdst1 = &MemoryMarshal.GetReference(v))
fixed (float* pdst2 = &MemoryMarshal.GetReference(w))
{
int* pIdxEnd = pidx + count;
float* pSrcCurrent = psrc;
int* pIdxCurrent = pidx;
Vector128<float> xPrimal = Vector128.Create(primalUpdate);
Vector128<float> signMask = Vector128.Create(-0.0f); // 0x8000 0000
Vector128<float> xThreshold = Vector128.Create(threshold);
while (pIdxCurrent + 4 <= pIdxEnd)
{
Vector128<float> xSrc = Sse.LoadVector128(pSrcCurrent);
Vector128<float> xDst1 = Load4(pdst1, pIdxCurrent);
xDst1 = Sse.Add(xDst1, Sse.Multiply(xSrc, xPrimal));
Vector128<float> xDst2 = GetNewDst128(xDst1, xThreshold);
Store4(in xDst1, pdst1, pIdxCurrent);
Store4(in xDst2, pdst2, pIdxCurrent);
pIdxCurrent += 4;
pSrcCurrent += 4;
}
while (pIdxCurrent < pIdxEnd)
{
int index = *pIdxCurrent;
pdst1[index] += (*pSrcCurrent) * primalUpdate;
float dst1 = pdst1[index];
pdst2[index] = Math.Abs(dst1) > threshold ? (dst1 > 0 ? dst1 - threshold : dst1 + threshold) : 0;
pIdxCurrent++;
pSrcCurrent++;
}
}
}
}
}