File: AvxIntrinsics.cs
Web Access
Project: src\src\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj (Microsoft.ML.CpuMath)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
// The exported function names need to be unique (can't be disambiguated based on signature), hence
// we introduce suffix letters to indicate the general patterns used.
// * A suffix means aligned and padded for SSE operations.
// * U suffix means unaligned and unpadded.
// * P suffix means sparse (unaligned) partial vector - the vector is only part of a larger sparse vector.
// * Tran means the matrix is transposed.
 
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using Microsoft.ML.Internal.CpuMath.Core;
#pragma warning disable CS8981 // The type name only contains lower-cased ascii characters. Such names may become reserved for the language.
using nuint = System.UInt64;
#pragma warning restore CS8981 // The type name only contains lower-cased ascii characters. Such names may become reserved for the language.
 
namespace Microsoft.ML.Internal.CpuMath
{
    internal static class AvxIntrinsics
    {
        // Lane-mask lookup table for 256-bit vectors, laid out as 8 rows of 8 uint lanes.
        // Row k (k = 0..7) has its first k lanes set to 0xFFFFFFFF and the remaining lanes cleared.
        // Scale loads row k as one vector (element offset k * 8) to select the leading k elements
        // of an 8-float block.
        public static readonly uint[] LeadingAlignmentMask = new uint[64]
        {
            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
            0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
            0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000,
            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000,
            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000,
        };
 
        // Lane-mask lookup table for 256-bit vectors, laid out as 8 rows of 8 uint lanes.
        // Row k (k = 0..7) has its last k lanes set to 0xFFFFFFFF and the remaining lanes cleared.
        // Scale loads row k as one vector (element offset k * 8) to select the trailing k elements
        // of an 8-float block.
        public static readonly uint[] TrailingAlignmentMask = new uint[64]
        {
            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
            0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
            0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
            0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
        };
 
        // Every lane holds the bit pattern 0x7FFFFFFF (all bits except the sign bit), reinterpreted as float.
        private static readonly Vector256<float> _absMask256 = Vector256.Create(0x7FFFFFFF).AsSingle();

        // Byte alignment checked by HasCompatibleAlignment and GetAlignedBase (32 bytes = one 256-bit vector).
        private const int Vector256Alignment = 32;
 
        // Returns true when the array's declared byte alignment (CbAlign) is a whole multiple of
        // Vector256Alignment. Asserts that the array is non-null and has a positive Size.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static bool HasCompatibleAlignment(AlignedArray alignedArray)
        {
            Contracts.AssertValue(alignedArray);
            Contracts.Assert(alignedArray.Size > 0);

            var leftover = alignedArray.CbAlign % Vector256Alignment;
            return leftover == 0;
        }
 
        // Advances the pinned array pointer by the element offset reported by AlignedArray.GetBase and
        // asserts that the resulting address is a multiple of Vector256Alignment bytes.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
        {
            Contracts.AssertValue(alignedArray);

            var elementOffset = alignedArray.GetBase((long)unalignedBase);
            float* result = unalignedBase + elementOffset;

            Contracts.Assert(((long)result % Vector256Alignment) == 0);
            return result;
        }
 
        // Extracts the upper 128 bits (elements 4..7) of a 256-bit vector.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector128<float> GetHigh(in Vector256<float> x)
        {
            // Index 1 selects the upper 128-bit half.
            const byte upperHalf = 1;
            return Avx.ExtractVector128(x, upperHalf);
        }
 
        // Gathers src[idx[0]] .. src[idx[7]] into a single 256-bit vector.
        // Uses the AVX2 gather instruction when available, otherwise eight scalar reads.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static unsafe Vector256<float> Load8(float* src, int* idx)
        {
            if (!Avx2.IsSupported)
            {
                // Software fallback: read each indexed element individually.
                return Vector256.Create(
                    src[idx[0]], src[idx[1]], src[idx[2]], src[idx[3]],
                    src[idx[4]], src[idx[5]], src[idx[6]], src[idx[7]]);
            }

            // Scale of 4 = sizeof(float): the indices are element offsets, not byte offsets.
            Vector256<int> indices = Avx.LoadVector256(idx);
            return Avx2.GatherVector256(src, indices, 4);
        }
 
        // Scatters the eight elements of x to dst[idx[0]] .. dst[idx[7]].
        // Each 128-bit half is written one element at a time: the lowest element is stored, then the
        // half is passed through SseIntrinsics.Rotate before the next store.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static unsafe void Store8(in Vector256<float> x, float* dst, int* idx)
        {
            // Elements 0..3 come from the lower half.
            Vector128<float> half = x.GetLower();
            Sse.StoreScalar(dst + idx[0], half);
            half = SseIntrinsics.Rotate(in half);
            Sse.StoreScalar(dst + idx[1], half);
            half = SseIntrinsics.Rotate(in half);
            Sse.StoreScalar(dst + idx[2], half);
            half = SseIntrinsics.Rotate(in half);
            Sse.StoreScalar(dst + idx[3], half);

            // Elements 4..7 come from the upper half.
            half = GetHigh(in x);
            Sse.StoreScalar(dst + idx[4], half);
            half = SseIntrinsics.Rotate(in half);
            Sse.StoreScalar(dst + idx[5], half);
            half = SseIntrinsics.Rotate(in half);
            Sse.StoreScalar(dst + idx[6], half);
            half = SseIntrinsics.Rotate(in half);
            Sse.StoreScalar(dst + idx[7], half);
        }
 
        // Sums the four floats inside each 128-bit half of the vector. After two horizontal adds every
        // element of a half holds that half's total; the two halves are NOT combined with each other.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> VectorSum256(in Vector256<float> vector)
        {
            // First pass: adjacent pairs within each half.
            Vector256<float> pairSums = Avx.HorizontalAdd(vector, vector);

            // Second pass: pairs of pairs, i.e. the total of each half.
            Vector256<float> halfSums = Avx.HorizontalAdd(pairSums, pairSums);
            return halfSums;
        }
 
        // Computes the maximum of each 128-bit half of the vector. On return, element 0 holds the
        // maximum of elements 0..3 and element 4 holds the maximum of elements 4..7; the other
        // elements hold intermediate values.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> VectorMax256(in Vector256<float> vector)
        {
            // Swap neighbours inside each half of the input: ABCD|EFGH -> BADC|FEHG.
            Vector256<float> swappedPairs = Avx.Shuffle(vector, vector, 0xB1);

            // Element-wise max: slots 1, 3, 5 and 7 (1-based) now hold
            // max(A, B), max(C, D), max(E, F) and max(G, H).
            Vector256<float> pairMax = Avx.Max(vector, swappedPairs);

            // Bring the second pair's maximum into the first slot of each half: ABCD|EFGH -> CAAA|GEEE.
            Vector256<float> crossed = Avx.Shuffle(pairMax, pairMax, 0x02);

            // Element-wise max: slots 1 and 5 (1-based) become max(A, B, C, D) and max(E, F, G, H).
            return Avx.Max(pairMax, crossed);
        }
 
        // Per element: if |xDst1| > xThreshold, returns xDst1 moved toward zero by xThreshold
        // (xDst1 - sign(xDst1) * xThreshold); otherwise returns 0.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> GetNewDst256(in Vector256<float> xDst1, in Vector256<float> xThreshold)
        {
            // -0.0f has only the sign bit set (0x8000 0000).
            Vector256<float> signBit = Vector256.Create(-0.0f);

            // 0x8000 0000 where xDst1 has its sign bit set, 0x0000 0000 elsewhere.
            Vector256<float> sign = Avx.And(xDst1, signBit);

            // Clearing the sign bit yields |xDst1|.
            Vector256<float> magnitude = Avx.Xor(xDst1, sign);

            // All ones (0xFFFF FFFF) where |xDst1| > xThreshold.
            Vector256<float> exceeds = Avx.Compare(magnitude, xThreshold, FloatComparisonMode.OrderedGreaterThanNonSignaling);

            // -xThreshold where xDst1 is negative, +xThreshold otherwise.
            Vector256<float> signedThreshold = Avx.Xor(sign, xThreshold);

            Vector256<float> shrunk = Avx.Subtract(xDst1, signedThreshold);
            return Avx.And(shrunk, exceeds);
        }
 
        // Returns (*psrc1 * src2) + src3 element-wise, reading eight floats from psrc1.
        // Uses a fused multiply-add when FMA is supported, otherwise a separate multiply and add.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static unsafe Vector256<float> MultiplyAdd(float* psrc1, Vector256<float> src2, Vector256<float> src3)
        {
            Vector256<float> loaded = Avx.LoadVector256(psrc1);

            if (!Fma.IsSupported)
            {
                // Operand order (src2 first) is kept as in the non-fused path.
                return Avx.Add(Avx.Multiply(src2, loaded), src3);
            }

            return Fma.MultiplyAdd(loaded, src2, src3);
        }
 
        // Returns (src1 * src2) + src3 element-wise.
        // Uses a fused multiply-add when FMA is supported, otherwise a separate multiply and add.
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<float> src2, Vector256<float> src3)
        {
            return Fma.IsSupported
                ? Fma.MultiplyAdd(src1, src2, src3)
                : Avx.Add(Avx.Multiply(src1, src2), src3);
        }
 
        // Multiply matrix times vector into vector.
        // The matrix is read as crow rows of ccol floats. Four rows are processed per outer iteration
        // and eight columns per inner iteration, so each outer iteration writes four dst elements.
        public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
        {
            Contracts.Assert(HasCompatibleAlignment(mat));
            Contracts.Assert(HasCompatibleAlignment(src));
            Contracts.Assert(HasCompatibleAlignment(dst));

            fixed (float* pSrcStart = &src.Items[0])
            fixed (float* pDstStart = &dst.Items[0])
            fixed (float* pMatStart = &mat.Items[0])
            {
                float* psrc = GetAlignedBase(src, pSrcStart);
                float* pdst = GetAlignedBase(dst, pDstStart);
                float* pmat = GetAlignedBase(mat, pMatStart);

                float* pSrcLimit = psrc + ccol;
                float* pDstLimit = pdst + crow;

                // Points at the current 8-column slice of the first of the four rows being processed.
                float* pRowBlock = pmat;

                for (float* pOut = pdst; pOut < pDstLimit; pOut += 4)
                {
                    // One accumulator per matrix row in the current group of four.
                    Vector256<float> acc0 = Vector256<float>.Zero;
                    Vector256<float> acc1 = acc0;
                    Vector256<float> acc2 = acc0;
                    Vector256<float> acc3 = acc0;

                    for (float* pIn = psrc; pIn < pSrcLimit; pIn += 8, pRowBlock += 8)
                    {
                        float* pRow0 = pRowBlock;
                        float* pRow1 = pRow0 + ccol;
                        float* pRow2 = pRow1 + ccol;
                        float* pRow3 = pRow2 + ccol;

                        Contracts.Assert(((nuint)(pRow0) % 32) == 0);
                        Contracts.Assert(((nuint)(pIn) % 32) == 0);

                        // Unaligned-load intrinsics are used on purpose: the alignment is only asserted,
                        // which lets the JIT fold the load into the VEX-encoded memory operand of the
                        // consuming instruction. Unaligned loads of aligned data are as fast as aligned
                        // loads on modern hardware.
                        Vector256<float> row0 = Avx.LoadVector256(pRow0);
                        Vector256<float> row1 = Avx.LoadVector256(pRow1);
                        Vector256<float> row2 = Avx.LoadVector256(pRow2);
                        Vector256<float> row3 = Avx.LoadVector256(pRow3);
                        Vector256<float> srcVec = Avx.LoadVector256(pIn);

                        acc0 = MultiplyAdd(row0, srcVec, acc0);
                        acc1 = MultiplyAdd(row1, srcVec, acc1);
                        acc2 = MultiplyAdd(row2, srcVec, acc2);
                        acc3 = MultiplyAdd(row3, srcVec, acc3);
                    }

                    // Reduce the four accumulators so that each 128-bit half holds partial sums of
                    // rows 0..3 in order, then add the two halves to get the four row totals.
                    Vector256<float> sums01 = Avx.HorizontalAdd(acc0, acc1);
                    Vector256<float> sums23 = Avx.HorizontalAdd(acc2, acc3);
                    Vector256<float> sums = Avx.HorizontalAdd(sums01, sums23);

                    Vector128<float> rowTotals = Sse.Add(sums.GetLower(), GetHigh(in sums));
                    Sse.StoreAligned(pOut, rowTotals);

                    // The inner loop advanced by one full row; skip the other three rows of this group.
                    pRowBlock += 3 * ccol;
                }
            }
        }
 
        // Partial sparse source vector.
        // For each group of eight matrix rows, accumulates mat[row, col] * src[col] over the columns
        // listed in rgposSrc[iposMin .. iposEnd) and stores the eight sums to dst.
        // posMin is subtracted from both the matrix and source base pointers, so the positions in
        // rgposSrc are interpreted relative to posMin.
        // An empty rgposSrc (with iposMin == iposEnd) is valid and writes zeros to dst.
        public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
                                int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
        {
            // REVIEW: For extremely sparse inputs, interchanging the loops would
            // likely be more efficient.
            Contracts.Assert(HasCompatibleAlignment(mat));
            Contracts.Assert(HasCompatibleAlignment(src));
            Contracts.Assert(HasCompatibleAlignment(dst));

            fixed (float* pSrcStart = &src.Items[0])
            fixed (float* pDstStart = &dst.Items[0])
            fixed (float* pMatStart = &mat.Items[0])
            // GetReference (rather than &rgposSrc[0]) so that an empty index span does not throw
            // IndexOutOfRangeException; the pointer is never dereferenced when iposMin == iposEnd.
            fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
            {
                float* psrc = GetAlignedBase(src, pSrcStart);
                float* pdst = GetAlignedBase(dst, pDstStart);
                float* pmat = GetAlignedBase(mat, pMatStart);

                int* pposMin = pposSrc + iposMin;
                int* pposEnd = pposSrc + iposEnd;
                float* pDstEnd = pdst + crow;
                float* pm0 = pmat - posMin;
                float* pSrcCurrent = psrc - posMin;
                float* pDstCurrent = pdst;

                while (pDstCurrent < pDstEnd)
                {
                    // pm0..pm3 address rows 0..3 of the current group; rows 4..7 are reached through col2.
                    float* pm1 = pm0 + ccol;
                    float* pm2 = pm1 + ccol;
                    float* pm3 = pm2 + ccol;
                    Vector256<float> result = Vector256<float>.Zero;

                    int* ppos = pposMin;

                    while (ppos < pposEnd)
                    {
                        int col1 = *ppos;
                        // Same column, four rows further down.
                        int col2 = col1 + 4 * ccol;
                        Vector256<float> x1 = Vector256.Create(pm0[col1], pm1[col1], pm2[col1], pm3[col1],
                                                                pm0[col2], pm1[col2], pm2[col2], pm3[col2]);
                        // Broadcast the source element for this column to all eight lanes.
                        Vector256<float> x2 = Vector256.Create(pSrcCurrent[col1]);
                        x2 = Avx.Multiply(x2, x1);
                        result = Avx.Add(result, x2);

                        ppos++;
                    }

                    Avx.StoreAligned(pDstCurrent, result);
                    pDstCurrent += 8;
                    pm0 += 8 * ccol;
                }
            }
        }
 
        // Multiplies a transposed matrix by a vector: dst[j] = sum over i of src[i] * mat[i * crow + j].
        // src holds ccol elements and is consumed four at a time; each source element is broadcast and
        // multiplied against one matrix row of crow floats, eight floats per inner iteration.
        // The first group of four source elements overwrites dst; every later group accumulates into it.
        // NOTE(review): the first group is processed unconditionally, so this presumably expects ccol to
        // be a nonzero multiple of 4 and crow a multiple of 8 - neither is checked here; confirm with callers.
        public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
        {
            Contracts.Assert(HasCompatibleAlignment(mat));
            Contracts.Assert(HasCompatibleAlignment(src));
            Contracts.Assert(HasCompatibleAlignment(dst));

            fixed (float* pSrcStart = &src.Items[0])
            fixed (float* pDstStart = &dst.Items[0])
            fixed (float* pMatStart = &mat.Items[0])
            {
                float* psrc = GetAlignedBase(src, pSrcStart);
                float* pdst = GetAlignedBase(dst, pDstStart);
                float* pmat = GetAlignedBase(mat, pMatStart);

                float* pSrcEnd = psrc + ccol;
                float* pDstEnd = pdst + crow;
                float* pSrcCurrent = psrc;
                float* pMatCurrent = pmat;

                // We do 4-way unrolling
                Vector128<float> h01 = Sse.LoadAlignedVector128(pSrcCurrent);
                // Replicate each slot of h01 (ABCD) into its own register.
                Vector128<float> h11 = Sse.Shuffle(h01, h01, 0x55); // B
                Vector128<float> h21 = Sse.Shuffle(h01, h01, 0xAA); // C
                Vector128<float> h31 = Sse.Shuffle(h01, h01, 0xFF); // D
                h01 = Sse.Shuffle(h01, h01, 0x00); // A

                // Widen each broadcast value to 256 bits so it can be multiplied against eight matrix floats.
                Vector256<float> x01 = Vector256.Create(h01, h01);
                Vector256<float> x11 = Vector256.Create(h11, h11);
                Vector256<float> x21 = Vector256.Create(h21, h21);
                Vector256<float> x31 = Vector256.Create(h31, h31);

                pSrcCurrent += 4;

                float* pDstCurrent = pdst;

                // First pass: dst is written without being read, so it does not need to be zeroed beforehand.
                while (pDstCurrent < pDstEnd)
                {
                    float* pMatTemp = pMatCurrent;
                    Contracts.Assert(((nuint)(pMatTemp) % 32) == 0);

                    // The JIT will only fold away unaligned loads due to the semantics behind
                    // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
                    // modern hardware has unaligned loads that are as fast as aligned loads,
                    // when it doesn't cross a cache-line/page boundary, we will just assert
                    // that the alignment is correct and allow for the more-efficient codegen.
                    Vector256<float> x02 = Avx.LoadVector256(pMatTemp);
                    Vector256<float> x12 = Avx.LoadVector256(pMatTemp += crow);
                    Vector256<float> x22 = Avx.LoadVector256(pMatTemp += crow);
                    Vector256<float> x32 = Avx.LoadVector256(pMatTemp += crow);

                    // A * row0 + B * row1
                    x02 = Avx.Multiply(x01, x02);
                    x02 = MultiplyAdd(x11, x12, x02);

                    // C * row2 + D * row3
                    x22 = Avx.Multiply(x21, x22);
                    x22 = MultiplyAdd(x31, x32, x22);

                    x02 = Avx.Add(x02, x22);
                    Avx.StoreAligned(pDstCurrent, x02);

                    pDstCurrent += 8;
                    pMatCurrent += 8;
                }

                // The loop above advanced by one full row (crow floats); skip the other three rows of the group.
                pMatCurrent += 3 * crow;

                // Remaining passes: same computation, but the result is added to what dst already holds.
                while (pSrcCurrent < pSrcEnd)
                {
                    h01 = Sse.LoadAlignedVector128(pSrcCurrent);
                    // Replicate each slot of h01 (ABCD) into its own register.
                    h11 = Sse.Shuffle(h01, h01, 0x55); // B
                    h21 = Sse.Shuffle(h01, h01, 0xAA); // C
                    h31 = Sse.Shuffle(h01, h01, 0xFF); // D
                    h01 = Sse.Shuffle(h01, h01, 0x00); // A

                    x01 = Vector256.Create(h01, h01);
                    x11 = Vector256.Create(h11, h11);
                    x21 = Vector256.Create(h21, h21);
                    x31 = Vector256.Create(h31, h31);

                    pDstCurrent = pdst;

                    while (pDstCurrent < pDstEnd)
                    {
                        float* pMatTemp = pMatCurrent;

                        Contracts.Assert(((nuint)(pMatTemp) % 32) == 0);
                        Contracts.Assert(((nuint)(pDstCurrent) % 32) == 0);

                        // The JIT will only fold away unaligned loads due to the semantics behind
                        // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
                        // modern hardware has unaligned loads that are as fast as aligned loads,
                        // when it doesn't cross a cache-line/page boundary, we will just assert
                        // that the alignment is correct and allow for the more-efficient codegen.
                        Vector256<float> x02 = Avx.LoadVector256(pMatTemp);
                        Vector256<float> x12 = Avx.LoadVector256(pMatTemp += crow);
                        Vector256<float> x22 = Avx.LoadVector256(pMatTemp += crow);
                        Vector256<float> x32 = Avx.LoadVector256(pMatTemp += crow);
                        // Running total accumulated by the earlier passes.
                        Vector256<float> x3 = Avx.LoadVector256(pDstCurrent);

                        x02 = Avx.Multiply(x01, x02);
                        x02 = MultiplyAdd(x11, x12, x02);

                        x22 = Avx.Multiply(x21, x22);
                        x22 = MultiplyAdd(x31, x32, x22);

                        x02 = Avx.Add(x02, x22);
                        x3 = Avx.Add(x02, x3);
                        Avx.StoreAligned(pDstCurrent, x3);

                        pDstCurrent += 8;
                        pMatCurrent += 8;
                    }

                    // Skip the remaining three rows of this group of four.
                    pMatCurrent += 3 * crow;
                    pSrcCurrent += 4;
                }
            }
        }
 
        // dst[i] += scalar
        // Adds the same scalar to every element of dst in place. Elements are processed eight at a
        // time, then at most one block of four, then one at a time. An empty span is a no-op.
        public static unsafe void AddScalarU(float scalar, Span<float> dst)
        {
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                float* pDstEnd = pdst + dst.Length;
                float* pDstCurrent = pdst;

                Vector256<float> scalarVector256 = Vector256.Create(scalar);

                while (pDstCurrent + 8 <= pDstEnd)
                {
                    Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
                    dstVector = Avx.Add(dstVector, scalarVector256);
                    Avx.Store(pDstCurrent, dstVector);

                    pDstCurrent += 8;
                }

                Vector128<float> scalarVector128 = Vector128.Create(scalar);

                // The bound is checked by adding to the current pointer rather than subtracting from the
                // end pointer: for a default (null-backed) empty span, "pDstEnd - 4" wraps around to a
                // huge address and the unsigned pointer comparison would wrongly succeed.
                if (pDstCurrent + 4 <= pDstEnd)
                {
                    Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
                    dstVector = Sse.Add(dstVector, scalarVector128);
                    Sse.Store(pDstCurrent, dstVector);

                    pDstCurrent += 4;
                }

                while (pDstCurrent < pDstEnd)
                {
                    Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
                    dstVector = Sse.AddScalar(dstVector, scalarVector128);
                    Sse.StoreScalar(pDstCurrent, dstVector);

                    pDstCurrent++;
                }
            }
        }
 
        // dst[i] *= scale, in place.
        // Strategy: if the buffer is float-aligned but not 32-byte aligned, one masked unaligned block
        // covers the elements before the first 32-byte boundary; whole 8-float blocks follow; a final
        // masked block, positioned to end exactly at the end of the buffer, covers the leftover
        // elements. The masks ensure elements shared between overlapping blocks are scaled only once.
        // NOTE(review): the masked head and tail paths always load and store a full 8-float vector, so
        // this presumably requires dst.Length >= 8 - that is not checked here; confirm with callers.
        public static unsafe void Scale(float scale, Span<float> dst)
        {
            fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
            fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
            fixed (float* pd = &MemoryMarshal.GetReference(dst))
            {
                float* pDstCurrent = pd;
                int length = dst.Length;
                Vector256<float> scaleVector256 = Vector256.Create(scale);

                nuint address = (nuint)(pd);
                // Byte offset of the buffer start from the previous 32-byte boundary.
                int misalignment = (int)(address % 32);
                int remainder = 0;

                if ((misalignment & 3) != 0)
                {
                    // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
                    remainder = length % 8;

                    for (float* pEnd = pd + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 8)
                    {
                        Vector256<float> temp = Avx.LoadVector256(pDstCurrent);
                        temp = Avx.Multiply(scaleVector256, temp);
                        Avx.Store(pDstCurrent, temp);
                    }
                }
                else
                {
                    if (misalignment != 0)
                    {
                        // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
                        // masking any elements that will be included in the first aligned read

                        // Convert the byte offset to a float count, then to the number of floats that lie
                        // before the next 32-byte boundary.
                        misalignment >>= 2;
                        misalignment = 8 - misalignment;

                        Vector256<float> result = Avx.LoadVector256(pDstCurrent);

                        // leadingMask selects the first `misalignment` lanes (to be scaled now);
                        // trailingMask selects the remaining lanes (kept unscaled for the aligned loop).
                        Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
                        Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8));

                        Vector256<float> temp = Avx.And(result, trailingMask);
                        result = Avx.Multiply(scaleVector256, result);

                        // Masking operation is done at the end to avoid doing an Or operation with negative Zero.
                        result = Avx.And(result, leadingMask);
                        result = Avx.Or(result, temp);

                        Avx.Store(pDstCurrent, result);

                        pDstCurrent += misalignment;
                        length -= misalignment;
                    }

                    if (length > 7)
                    {
                        // Handle all the 256-bit blocks that we can now that we have offset to an aligned address

                        remainder = length % 8;

                        for (float* pEnd = pDstCurrent + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 8)
                        {
                            // The JIT will only fold away unaligned loads due to the semantics behind
                            // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
                            // modern hardware has unaligned loads that are as fast as aligned loads,
                            // when it doesn't cross a cache-line/page boundary, we will just assert
                            // that the alignment is correct and allow for the more-efficient codegen.

                            Contracts.Assert(((nuint)(pDstCurrent) % 32) == 0);
                            Vector256<float> temp = Avx.LoadVector256(pDstCurrent);
                            temp = Avx.Multiply(scaleVector256, temp);
                            Avx.Store(pDstCurrent, temp);
                        }
                    }
                    else
                    {
                        // Handle the "worst-case" scenario, which is when fewer than 8 elements are left after
                        // the leading masked block and the input is not 256-bit aligned. This means we can't do
                        // any aligned loads and will just end up doing two unaligned loads where we mask the
                        // input each time.
                        remainder = length;
                    }
                }

                if (remainder != 0)
                {
                    // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
                    // unaligned load will read to the end of the array and then mask out any elements already processed

                    pDstCurrent -= (8 - remainder);

                    Vector256<float> result = Avx.LoadVector256(pDstCurrent);

                    // trailingMask selects the last `remainder` lanes (to be scaled now);
                    // leadingMask selects the earlier lanes, which were already scaled and are kept as loaded.
                    Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
                    Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));

                    Vector256<float> temp = Avx.And(result, leadingMask);
                    result = Avx.Multiply(scaleVector256, result);

                    // Masking operation is done at the end to avoid doing an Or operation with negative Zero.
                    result = Avx.And(result, trailingMask);
                    result = Avx.Or(result, temp);

                    Avx.Store(pDstCurrent, result);
                }
            }
        }
 
        // dst[i] = scale * src[i] for the first count elements.
        // Elements are processed eight at a time, then at most one block of four, then one at a time.
        // count == 0 is a no-op, including when the spans are default (null-backed).
        public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                float* pDstEnd = pdst + count;
                float* pSrcCurrent = psrc;
                float* pDstCurrent = pdst;

                Vector256<float> scaleVector256 = Vector256.Create(scale);

                while (pDstCurrent + 8 <= pDstEnd)
                {
                    Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
                    srcVector = Avx.Multiply(srcVector, scaleVector256);
                    Avx.Store(pDstCurrent, srcVector);

                    pSrcCurrent += 8;
                    pDstCurrent += 8;
                }

                Vector128<float> scaleVector128 = Vector128.Create(scale);

                // The bound is checked by adding to the current pointer rather than subtracting from the
                // end pointer: for a default (null-backed) empty span, "pDstEnd - 4" wraps around to a
                // huge address and the unsigned pointer comparison would wrongly succeed.
                if (pDstCurrent + 4 <= pDstEnd)
                {
                    Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
                    srcVector = Sse.Multiply(srcVector, scaleVector128);
                    Sse.Store(pDstCurrent, srcVector);

                    pSrcCurrent += 4;
                    pDstCurrent += 4;
                }

                while (pDstCurrent < pDstEnd)
                {
                    Vector128<float> srcVector = Sse.LoadScalarVector128(pSrcCurrent);
                    srcVector = Sse.MultiplyScalar(srcVector, scaleVector128);
                    Sse.StoreScalar(pDstCurrent, srcVector);

                    pSrcCurrent++;
                    pDstCurrent++;
                }
            }
        }
 
        // dst[i] = a * (dst[i] + b)
        // Applied in place to every element of dst. Elements are processed eight at a time, then at
        // most one block of four, then one at a time. An empty span is a no-op.
        public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
        {
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                float* pDstEnd = pdst + dst.Length;
                float* pDstCurrent = pdst;

                Vector256<float> a256 = Vector256.Create(a);
                Vector256<float> b256 = Vector256.Create(b);

                while (pDstCurrent + 8 <= pDstEnd)
                {
                    Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
                    dstVector = Avx.Add(dstVector, b256);
                    dstVector = Avx.Multiply(dstVector, a256);
                    Avx.Store(pDstCurrent, dstVector);

                    pDstCurrent += 8;
                }

                Vector128<float> a128 = Vector128.Create(a);
                Vector128<float> b128 = Vector128.Create(b);

                // The bound is checked by adding to the current pointer rather than subtracting from the
                // end pointer: for a default (null-backed) empty span, "pDstEnd - 4" wraps around to a
                // huge address and the unsigned pointer comparison would wrongly succeed.
                if (pDstCurrent + 4 <= pDstEnd)
                {
                    Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
                    dstVector = Sse.Add(dstVector, b128);
                    dstVector = Sse.Multiply(dstVector, a128);
                    Sse.Store(pDstCurrent, dstVector);

                    pDstCurrent += 4;
                }

                while (pDstCurrent < pDstEnd)
                {
                    Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
                    dstVector = Sse.AddScalar(dstVector, b128);
                    dstVector = Sse.MultiplyScalar(dstVector, a128);
                    Sse.StoreScalar(pDstCurrent, dstVector);

                    pDstCurrent++;
                }
            }
        }
 
        // dst[i] += scale * src[i] for the first count elements.
        // Elements are processed eight at a time (through MultiplyAdd), then at most one block of
        // four, then one at a time.
        public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                float* pIn = psrc;
                float* pOut = pdst;
                float* pOutLimit = pdst + count;

                // 256-bit blocks.
                Vector256<float> scale256 = Vector256.Create(scale);

                for (; pOut + 8 <= pOutLimit; pIn += 8, pOut += 8)
                {
                    Vector256<float> current = Avx.LoadVector256(pOut);
                    Avx.Store(pOut, MultiplyAdd(pIn, scale256, current));
                }

                // At most one 128-bit block.
                Vector128<float> scale128 = Vector128.Create(scale);

                if (pOut + 4 <= pOutLimit)
                {
                    Vector128<float> scaled = Sse.Multiply(Sse.LoadVector128(pIn), scale128);
                    Vector128<float> current = Sse.LoadVector128(pOut);
                    Sse.Store(pOut, Sse.Add(current, scaled));

                    pIn += 4;
                    pOut += 4;
                }

                // Scalar tail.
                for (; pOut < pOutLimit; pIn++, pOut++)
                {
                    Vector128<float> scaled = Sse.MultiplyScalar(Sse.LoadScalarVector128(pIn), scale128);
                    Vector128<float> current = Sse.LoadScalarVector128(pOut);
                    Sse.StoreScalar(pOut, Sse.AddScalar(current, scaled));
                }
            }
        }
 
        // Writes dst + (src * scale) into result for the first count elements, leaving src and dst
        // untouched. The 256-bit stage goes through the MultiplyAdd helper (defined elsewhere; presumably
        // the same formula as the explicit multiply-then-add used by the narrower stages).
        public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            Contracts.Assert(count <= result.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            fixed (float* pres = &MemoryMarshal.GetReference(result))
            {
                Vector256<float> scale256 = Vector256.Create(scale);
                Vector128<float> scale128 = Vector128.Create(scale);

                // All three buffers advance in lockstep, so one offset tracks the position in each.
                int offset = 0;

                for (; count - offset >= 8; offset += 8)
                {
                    Vector256<float> combined = MultiplyAdd(psrc + offset, scale256, Avx.LoadVector256(pdst + offset));
                    Avx.Store(pres + offset, combined);
                }

                if (count - offset >= 4)
                {
                    Vector128<float> scaled = Sse.Multiply(Sse.LoadVector128(psrc + offset), scale128);
                    Sse.Store(pres + offset, Sse.Add(Sse.LoadVector128(pdst + offset), scaled));

                    offset += 4;
                }

                for (; offset < count; offset++)
                {
                    Vector128<float> scaled = Sse.MultiplyScalar(Sse.LoadScalarVector128(psrc + offset), scale128);
                    Sse.StoreScalar(pres + offset, Sse.AddScalar(Sse.LoadScalarVector128(pdst + offset), scaled));
                }
            }
        }
 
        // Sparse scaled accumulate: for each of the first count entries, dst[idx[i]] += scale * src[i]
        // (this is exactly what the scalar tail does). The vector stages read and write dst through the
        // Load8/Store8 and SseIntrinsics.Load4/Store4 helpers, which are defined elsewhere and presumably
        // gather/scatter through the index pointer - confirm against those helpers.
        public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            Contracts.Assert(count <= idx.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (int* pidx = &MemoryMarshal.GetReference(idx))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                Vector256<float> scale256 = Vector256.Create(scale);
                Vector128<float> scale128 = Vector128.Create(scale);

                // src and idx advance together; dst is always addressed from its base via the indices.
                int offset = 0;

                for (; count - offset >= 8; offset += 8)
                {
                    int* slots = pidx + offset;

                    Vector256<float> gathered = Load8(pdst, slots);
                    gathered = MultiplyAdd(psrc + offset, scale256, gathered);
                    Store8(in gathered, pdst, slots);
                }

                if (count - offset >= 4)
                {
                    int* slots = pidx + offset;

                    Vector128<float> scaled = Sse.Multiply(Sse.LoadVector128(psrc + offset), scale128);
                    Vector128<float> gathered = SseIntrinsics.Load4(pdst, slots);
                    gathered = Sse.Add(gathered, scaled);
                    SseIntrinsics.Store4(in gathered, pdst, slots);

                    offset += 4;
                }

                for (; offset < count; offset++)
                {
                    pdst[pidx[offset]] += scale * psrc[offset];
                }
            }
        }
 
        // Element-wise in-place addition: dst = src + dst over the first count elements, processed in
        // 256-bit blocks, then at most one 128-bit block, then one float at a time.
        public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                float* source = psrc;
                float* target = pdst;
                float* sourceEnd = psrc + count;

                for (; sourceEnd - source >= 8; source += 8, target += 8)
                {
                    Vector256<float> total = Avx.Add(Avx.LoadVector256(source), Avx.LoadVector256(target));
                    Avx.Store(target, total);
                }

                if (sourceEnd - source >= 4)
                {
                    Vector128<float> total = Sse.Add(Sse.LoadVector128(source), Sse.LoadVector128(target));
                    Sse.Store(target, total);

                    source += 4;
                    target += 4;
                }

                for (; source < sourceEnd; source++, target++)
                {
                    Vector128<float> total = Sse.AddScalar(Sse.LoadScalarVector128(source), Sse.LoadScalarVector128(target));
                    Sse.StoreScalar(target, total);
                }
            }
        }
 
        // Sparse accumulate: for each of the first count entries, dst[idx[i]] += src[i] (exactly what the
        // scalar tail does). The vector stages go through Load8/Store8 and SseIntrinsics.Load4/Store4,
        // defined elsewhere and presumably gathering/scattering dst through the index pointer.
        public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            Contracts.Assert(count <= idx.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (int* pidx = &MemoryMarshal.GetReference(idx))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                float* value = psrc;
                int* slot = pidx;
                int* slotEnd = pidx + count;

                for (; slotEnd - slot >= 8; slot += 8, value += 8)
                {
                    Vector256<float> updated = Avx.Add(Load8(pdst, slot), Avx.LoadVector256(value));
                    Store8(in updated, pdst, slot);
                }

                if (slotEnd - slot >= 4)
                {
                    Vector128<float> updated = Sse.Add(SseIntrinsics.Load4(pdst, slot), Sse.LoadVector128(value));
                    SseIntrinsics.Store4(in updated, pdst, slot);

                    slot += 4;
                    value += 4;
                }

                for (; slot < slotEnd; slot++, value++)
                {
                    pdst[*slot] += *value;
                }
            }
        }
 
        // Element-wise product: dst[i] = src1[i] * src2[i] for the first count elements, processed in
        // 256-bit blocks, then at most one 128-bit block, then one float at a time.
        public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count)
        {
            Contracts.Assert(count <= src1.Length);
            Contracts.Assert(count <= src2.Length);
            Contracts.Assert(count <= dst.Length);
            fixed (float* psrc1 = &MemoryMarshal.GetReference(src1))
            fixed (float* psrc2 = &MemoryMarshal.GetReference(src2))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                // The three buffers are walked in lockstep, so a single offset is enough.
                int offset = 0;

                for (; count - offset >= 8; offset += 8)
                {
                    Vector256<float> product = Avx.Multiply(Avx.LoadVector256(psrc1 + offset), Avx.LoadVector256(psrc2 + offset));
                    Avx.Store(pdst + offset, product);
                }

                if (count - offset >= 4)
                {
                    Vector128<float> product = Sse.Multiply(Sse.LoadVector128(psrc1 + offset), Sse.LoadVector128(psrc2 + offset));
                    Sse.Store(pdst + offset, product);

                    offset += 4;
                }

                for (; offset < count; offset++)
                {
                    Vector128<float> product = Sse.MultiplyScalar(Sse.LoadScalarVector128(psrc1 + offset), Sse.LoadScalarVector128(psrc2 + offset));
                    Sse.StoreScalar(pdst + offset, product);
                }
            }
        }
 
        // Returns the sum of the elements of src. Values are accumulated eight lanes at a time in a
        // Vector256<float> and collapsed at the end through VectorSum256 and GetHigh (both defined
        // elsewhere; presumably a horizontal reduction - confirm against those helpers).
        //
        // Elements that do not fill a whole 256-bit block are handled by loading a full 8-float vector and
        // AND-ing it with a row of LeadingAlignmentMask / TrailingAlignmentMask so that only the wanted
        // lanes contribute to the sum.
        //
        // NOTE(review): the leading and trailing steps always load 8 floats, and the trailing step first
        // moves pValues back by (8 - remainder). For spans shorter than 8 elements those loads would touch
        // memory outside src, so callers are presumably expected to pass at least 8 elements - confirm
        // against the call sites.
        public static unsafe float Sum(ReadOnlySpan<float> src)
        {
            fixed (float* pSrc = &MemoryMarshal.GetReference(src))
            fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
            fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
            {
                float* pValues = pSrc;
                int length = src.Length;
                Vector256<float> result = Vector256<float>.Zero;

                // Byte offset of the first element from the previous 32-byte (256-bit) boundary.
                nuint address = (nuint)(pValues);
                int misalignment = (int)(address % 32);
                int remainder = 0;

                if ((misalignment & 3) != 0)
                {
                    // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations

                    remainder = length % 8;

                    for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 8)
                    {
                        result = Avx.Add(result, Avx.LoadVector256(pValues));
                    }
                }
                else
                {
                    if (misalignment != 0)
                    {
                        // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
                        // masking any elements that will be included in the first aligned read

                        // Convert the byte offset into the number of floats that precede the next
                        // 32-byte boundary; that count also selects the 8-entry row of the leading mask.
                        misalignment >>= 2;
                        misalignment = 8 - misalignment;

                        Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
                        Vector256<float> temp = Avx.And(mask, Avx.LoadVector256(pValues));
                        result = Avx.Add(result, temp);

                        pValues += misalignment;
                        length -= misalignment;
                    }

                    if (length > 7)
                    {
                        // Handle all the 256-bit blocks that we can now that we have offset to an aligned address

                        remainder = length % 8;

                        for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 8)
                        {
                            // The JIT will only fold away unaligned loads due to the semantics behind
                            // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
                            // modern hardware has unaligned loads that are as fast as aligned loads,
                            // when it doesn't cross a cache-line/page boundary, we will just assert
                            // that the alignment is correct and allow for the more-efficient codegen.

                            Contracts.Assert(((nuint)(pValues) % 32) == 0);
                            result = Avx.Add(result, Avx.LoadVector256(pValues));
                        }
                    }
                    else
                    {
                        // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
                        // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
                        // unaligned loads where we mask the input each time.
                        remainder = length;
                    }
                }

                if (remainder != 0)
                {
                    // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
                    // unaligned load will read to the end of the array and then mask out any elements already processed

                    pValues -= (8 - remainder);

                    // The remainder selects the 8-entry row of the trailing mask.
                    Vector256<float> mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
                    Vector256<float> temp = Avx.And(mask, Avx.LoadVector256(pValues));
                    result = Avx.Add(result, temp);
                }

                // Sum all the elements together and return the result
                result = VectorSum256(in result);
                return Sse.AddScalar(result.GetLower(), GetHigh(result)).ToScalar();
            }
        }
 
        // Returns the sum of the squares of the elements of src. A 256-bit partial sum and a 128-bit
        // partial sum are built separately and added together at the end. MultiplyAdd, VectorSum256,
        // GetHigh and SseIntrinsics.VectorSum128 are defined elsewhere (presumably a multiply-accumulate
        // and horizontal reductions - confirm against those helpers).
        public static unsafe float SumSqU(ReadOnlySpan<float> src)
        {
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            {
                int length = src.Length;
                int offset = 0;

                Vector256<float> wide = Vector256<float>.Zero;

                for (; length - offset >= 8; offset += 8)
                {
                    Vector256<float> block = Avx.LoadVector256(psrc + offset);
                    wide = MultiplyAdd(block, block, wide);
                }

                wide = VectorSum256(in wide);
                Vector128<float> wideTotal = Sse.AddScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one block of four floats can remain after the 256-bit loop.
                if (length - offset >= 4)
                {
                    Vector128<float> block = Sse.LoadVector128(psrc + offset);
                    narrow = Sse.Add(narrow, Sse.Multiply(block, block));

                    offset += 4;
                }

                narrow = SseIntrinsics.VectorSum128(in narrow);

                // Leftover floats are folded into the lowest lane one at a time.
                for (; offset < length; offset++)
                {
                    Vector128<float> single = Sse.LoadScalarVector128(psrc + offset);
                    narrow = Sse.AddScalar(narrow, Sse.MultiplyScalar(single, single));
                }

                return Sse.AddScalar(narrow, wideTotal).ToScalar();
            }
        }
 
        // Returns the sum over src of (element - mean) squared. A 256-bit partial sum and a 128-bit partial
        // sum are built separately and added together at the end. MultiplyAdd, VectorSum256, GetHigh and
        // SseIntrinsics.VectorSum128 are defined elsewhere (presumably a multiply-accumulate and
        // horizontal reductions - confirm against those helpers).
        public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
        {
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            {
                float* cursor = psrc;
                float* end = psrc + src.Length;

                Vector256<float> mean256 = Vector256.Create(mean);
                Vector128<float> mean128 = Vector128.Create(mean);

                Vector256<float> wide = Vector256<float>.Zero;

                for (; end - cursor >= 8; cursor += 8)
                {
                    Vector256<float> centered = Avx.Subtract(Avx.LoadVector256(cursor), mean256);
                    wide = MultiplyAdd(centered, centered, wide);
                }

                wide = VectorSum256(in wide);
                Vector128<float> wideTotal = Sse.AddScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one block of four floats can remain after the 256-bit loop.
                if (end - cursor >= 4)
                {
                    Vector128<float> centered = Sse.Subtract(Sse.LoadVector128(cursor), mean128);
                    narrow = Sse.Add(narrow, Sse.Multiply(centered, centered));

                    cursor += 4;
                }

                narrow = SseIntrinsics.VectorSum128(in narrow);

                // Leftover floats are folded into the lowest lane one at a time.
                for (; cursor < end; cursor++)
                {
                    Vector128<float> centered = Sse.SubtractScalar(Sse.LoadScalarVector128(cursor), mean128);
                    narrow = Sse.AddScalar(narrow, Sse.MultiplyScalar(centered, centered));
                }

                return Sse.AddScalar(narrow, wideTotal).ToScalar();
            }
        }
 
        // Accumulates each element of src after AND-ing it with _absMask256 / SseIntrinsics.AbsMask128
        // (defined elsewhere; presumably sign-clearing masks, which would make this the sum of absolute
        // values - confirm against the mask definitions). The 256-bit and 128-bit partial sums are
        // reduced separately and added together at the end.
        public static unsafe float SumAbsU(ReadOnlySpan<float> src)
        {
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            {
                int length = src.Length;
                int offset = 0;

                Vector256<float> wide = Vector256<float>.Zero;

                for (; length - offset >= 8; offset += 8)
                {
                    Vector256<float> block = Avx.LoadVector256(psrc + offset);
                    wide = Avx.Add(wide, Avx.And(block, _absMask256));
                }

                wide = VectorSum256(in wide);
                Vector128<float> wideTotal = Sse.AddScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one block of four floats can remain after the 256-bit loop.
                if (length - offset >= 4)
                {
                    Vector128<float> block = Sse.LoadVector128(psrc + offset);
                    narrow = Sse.Add(narrow, Sse.And(block, SseIntrinsics.AbsMask128));

                    offset += 4;
                }

                narrow = SseIntrinsics.VectorSum128(in narrow);

                // Leftover floats are folded into the lowest lane one at a time.
                for (; offset < length; offset++)
                {
                    Vector128<float> single = Sse.LoadScalarVector128(psrc + offset);
                    narrow = Sse.AddScalar(narrow, Sse.And(single, SseIntrinsics.AbsMask128));
                }

                return Sse.AddScalar(narrow, wideTotal).ToScalar();
            }
        }
 
        // Accumulates (element - mean) for each element of src after AND-ing the difference with
        // _absMask256 / SseIntrinsics.AbsMask128 (defined elsewhere; presumably sign-clearing masks, which
        // would make this the sum of absolute deviations from mean - confirm against the mask
        // definitions). The 256-bit and 128-bit partial sums are reduced separately and added at the end.
        public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src)
        {
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            {
                float* cursor = psrc;
                float* end = psrc + src.Length;

                Vector256<float> mean256 = Vector256.Create(mean);
                Vector128<float> mean128 = Vector128.Create(mean);

                Vector256<float> wide = Vector256<float>.Zero;

                for (; end - cursor >= 8; cursor += 8)
                {
                    Vector256<float> centered = Avx.Subtract(Avx.LoadVector256(cursor), mean256);
                    wide = Avx.Add(wide, Avx.And(centered, _absMask256));
                }

                wide = VectorSum256(in wide);
                Vector128<float> wideTotal = Sse.AddScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one block of four floats can remain after the 256-bit loop.
                if (end - cursor >= 4)
                {
                    Vector128<float> centered = Sse.Subtract(Sse.LoadVector128(cursor), mean128);
                    narrow = Sse.Add(narrow, Sse.And(centered, SseIntrinsics.AbsMask128));

                    cursor += 4;
                }

                narrow = SseIntrinsics.VectorSum128(in narrow);

                // Leftover floats are folded into the lowest lane one at a time.
                for (; cursor < end; cursor++)
                {
                    Vector128<float> centered = Sse.SubtractScalar(Sse.LoadScalarVector128(cursor), mean128);
                    narrow = Sse.AddScalar(narrow, Sse.And(centered, SseIntrinsics.AbsMask128));
                }

                return Sse.AddScalar(narrow, wideTotal).ToScalar();
            }
        }
 
        // Tracks the running maximum, starting from zero, of each element of src AND-ed with
        // _absMask256 / SseIntrinsics.AbsMask128 (defined elsewhere; presumably sign-clearing masks, which
        // would make this the largest absolute value - confirm against the mask definitions). The 256-bit
        // and 128-bit running maxima are reduced separately and combined with a final MaxScalar.
        public static unsafe float MaxAbsU(ReadOnlySpan<float> src)
        {
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            {
                int length = src.Length;
                int offset = 0;

                Vector256<float> wide = Vector256<float>.Zero;

                for (; length - offset >= 8; offset += 8)
                {
                    Vector256<float> block = Avx.LoadVector256(psrc + offset);
                    wide = Avx.Max(wide, Avx.And(block, _absMask256));
                }

                wide = VectorMax256(in wide);
                Vector128<float> wideMax = Sse.MaxScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one block of four floats can remain after the 256-bit loop.
                if (length - offset >= 4)
                {
                    Vector128<float> block = Sse.LoadVector128(psrc + offset);
                    narrow = Sse.Max(narrow, Sse.And(block, SseIntrinsics.AbsMask128));

                    offset += 4;
                }

                narrow = SseIntrinsics.VectorMax128(in narrow);

                // Leftover floats are compared against the lowest lane one at a time.
                for (; offset < length; offset++)
                {
                    Vector128<float> single = Sse.LoadScalarVector128(psrc + offset);
                    narrow = Sse.MaxScalar(narrow, Sse.And(single, SseIntrinsics.AbsMask128));
                }

                return Sse.MaxScalar(narrow, wideMax).ToScalar();
            }
        }
 
        // Tracks the running maximum, starting from zero, of (element - mean) AND-ed with
        // _absMask256 / SseIntrinsics.AbsMask128 (defined elsewhere; presumably sign-clearing masks, which
        // would make this the largest absolute deviation from mean - confirm against the mask
        // definitions). The 256-bit and 128-bit running maxima are reduced separately and combined with a
        // final MaxScalar.
        public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
        {
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            {
                float* cursor = psrc;
                float* end = psrc + src.Length;

                Vector256<float> mean256 = Vector256.Create(mean);
                Vector128<float> mean128 = Vector128.Create(mean);

                Vector256<float> wide = Vector256<float>.Zero;

                for (; end - cursor >= 8; cursor += 8)
                {
                    Vector256<float> centered = Avx.Subtract(Avx.LoadVector256(cursor), mean256);
                    wide = Avx.Max(wide, Avx.And(centered, _absMask256));
                }

                wide = VectorMax256(in wide);
                Vector128<float> wideMax = Sse.MaxScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one block of four floats can remain after the 256-bit loop.
                if (end - cursor >= 4)
                {
                    Vector128<float> centered = Sse.Subtract(Sse.LoadVector128(cursor), mean128);
                    narrow = Sse.Max(narrow, Sse.And(centered, SseIntrinsics.AbsMask128));

                    cursor += 4;
                }

                narrow = SseIntrinsics.VectorMax128(in narrow);

                // Leftover floats are compared against the lowest lane one at a time.
                for (; cursor < end; cursor++)
                {
                    Vector128<float> centered = Sse.SubtractScalar(Sse.LoadScalarVector128(cursor), mean128);
                    narrow = Sse.MaxScalar(narrow, Sse.And(centered, SseIntrinsics.AbsMask128));
                }

                return Sse.MaxScalar(narrow, wideMax).ToScalar();
            }
        }
 
        // Dot product of the first count elements of src and dst. The 128-bit and scalar stages add
        // src * dst products explicitly; the 256-bit stage goes through the MultiplyAdd helper (defined
        // elsewhere; presumably a multiply-accumulate). The two partial sums are reduced separately via
        // VectorSum256/GetHigh and SseIntrinsics.VectorSum128 and added together at the end.
        public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                // Both inputs are walked in lockstep, so one offset tracks the position in each.
                int offset = 0;

                Vector256<float> wide = Vector256<float>.Zero;

                for (; count - offset >= 8; offset += 8)
                {
                    wide = MultiplyAdd(psrc + offset, Avx.LoadVector256(pdst + offset), wide);
                }

                wide = VectorSum256(in wide);
                Vector128<float> wideTotal = Sse.AddScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one block of four floats can remain after the 256-bit loop.
                if (count - offset >= 4)
                {
                    Vector128<float> left = Sse.LoadVector128(psrc + offset);
                    Vector128<float> right = Sse.LoadVector128(pdst + offset);
                    narrow = Sse.Add(narrow, Sse.Multiply(left, right));

                    offset += 4;
                }

                narrow = SseIntrinsics.VectorSum128(in narrow);

                // Leftover pairs are folded into the lowest lane one at a time.
                for (; offset < count; offset++)
                {
                    Vector128<float> left = Sse.LoadScalarVector128(psrc + offset);
                    Vector128<float> right = Sse.LoadScalarVector128(pdst + offset);
                    narrow = Sse.AddScalar(narrow, Sse.MultiplyScalar(left, right));
                }

                return Sse.AddScalar(narrow, wideTotal).ToScalar();
            }
        }
 
        // Sparse dot product over count entries: dst is read sequentially, while src is read through the
        // index pointer by Load8, SseIntrinsics.Load4 and SseIntrinsics.Load1 (defined elsewhere;
        // presumably they fetch src[idx[i]] - confirm against those helpers). The 256-bit and 128-bit
        // partial sums are reduced separately and added together at the end.
        public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            Contracts.Assert(count <= idx.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            fixed (int* pidx = &MemoryMarshal.GetReference(idx))
            {
                // src is always addressed from its base; only the index and dst cursors advance.
                float* dense = pdst;
                int* slot = pidx;
                int* slotEnd = pidx + count;

                Vector256<float> wide = Vector256<float>.Zero;

                for (; slotEnd - slot >= 8; slot += 8, dense += 8)
                {
                    wide = MultiplyAdd(dense, Load8(psrc, slot), wide);
                }

                wide = VectorSum256(in wide);
                Vector128<float> wideTotal = Sse.AddScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one group of four entries can remain after the 256-bit loop.
                if (slotEnd - slot >= 4)
                {
                    Vector128<float> picked = SseIntrinsics.Load4(psrc, slot);
                    Vector128<float> sequential = Sse.LoadVector128(dense);
                    narrow = Sse.Add(narrow, Sse.Multiply(picked, sequential));

                    slot += 4;
                    dense += 4;
                }

                narrow = SseIntrinsics.VectorSum128(in narrow);

                // Leftover entries are folded into the lowest lane one at a time.
                for (; slot < slotEnd; slot++, dense++)
                {
                    Vector128<float> picked = SseIntrinsics.Load1(psrc, slot);
                    Vector128<float> sequential = Sse.LoadScalarVector128(dense);
                    narrow = Sse.AddScalar(narrow, Sse.MultiplyScalar(picked, sequential));
                }

                return Sse.AddScalar(narrow, wideTotal).ToScalar();
            }
        }
 
        // Returns the sum of squared differences (src[i] - dst[i])^2 over the first count elements; no
        // square root is taken. The 256-bit stage accumulates through the MultiplyAdd helper (defined
        // elsewhere; presumably a multiply-accumulate), the 128-bit stage multiplies and adds explicitly,
        // and the leftover elements are added with plain float arithmetic.
        public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
        {
            Contracts.Assert(count <= src.Length);
            Contracts.Assert(count <= dst.Length);
            fixed (float* psrc = &MemoryMarshal.GetReference(src))
            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
            {
                // Both inputs are walked in lockstep, so one offset tracks the position in each.
                int offset = 0;

                Vector256<float> wide = Vector256<float>.Zero;

                for (; count - offset >= 8; offset += 8)
                {
                    Vector256<float> delta = Avx.Subtract(Avx.LoadVector256(psrc + offset), Avx.LoadVector256(pdst + offset));
                    wide = MultiplyAdd(delta, delta, wide);
                }

                wide = VectorSum256(in wide);
                Vector128<float> wideTotal = Sse.AddScalar(wide.GetLower(), GetHigh(wide));

                Vector128<float> narrow = Vector128<float>.Zero;

                // At most one block of four floats can remain after the 256-bit loop.
                if (count - offset >= 4)
                {
                    Vector128<float> delta = Sse.Subtract(Sse.LoadVector128(psrc + offset), Sse.LoadVector128(pdst + offset));
                    narrow = Sse.Add(narrow, Sse.Multiply(delta, delta));

                    offset += 4;
                }

                narrow = SseIntrinsics.VectorSum128(in narrow);

                // Combine the vector partial sums first, then add the leftover elements one by one.
                float total = Sse.AddScalar(narrow, wideTotal).ToScalar();
                for (; offset < count; offset++)
                {
                    float delta = psrc[offset] - pdst[offset];
                    total += delta * delta;
                }

                return total;
            }
        }
 
        /// <summary>
        /// Dense SDCA L1 update over the first <paramref name="count"/> elements:
        /// v[i] += src[i] * primalUpdate, after which w[i] is derived from the new v[i] and
        /// <paramref name="threshold"/>.
        /// </summary>
        /// <remarks>
        /// The scalar tail sets w[i] to v[i] - threshold when v[i] is positive, to v[i] + threshold otherwise,
        /// and to 0 whenever |v[i]| does not exceed the threshold. The vector paths delegate that step to
        /// GetNewDst256 / SseIntrinsics.GetNewDst128, which presumably apply the same rule - confirm against
        /// those helpers.
        /// </remarks>
        /// <param name="primalUpdate">Scale factor applied to every source element.</param>
        /// <param name="count">Number of elements to process.</param>
        /// <param name="src">Source values.</param>
        /// <param name="threshold">Threshold used when deriving <paramref name="w"/> from <paramref name="v"/>.</param>
        /// <param name="v">Accumulator updated in place.</param>
        /// <param name="w">Receives the thresholded values.</param>
        public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlySpan<float> src, float threshold, Span<float> v, Span<float> w)
        {
            fixed (float* pSrcBase = &MemoryMarshal.GetReference(src))
            fixed (float* pVBase = &MemoryMarshal.GetReference(v))
            fixed (float* pWBase = &MemoryMarshal.GetReference(w))
            {
                float* pIn = pSrcBase;
                float* pInEnd = pSrcBase + count;
                float* pV = pVBase;
                float* pW = pWBase;

                // Broadcast the two scalars once for each vector width used below.
                Vector256<float> primal8 = Vector256.Create(primalUpdate);
                Vector256<float> threshold8 = Vector256.Create(threshold);
                Vector128<float> primal4 = Vector128.Create(primalUpdate);
                Vector128<float> threshold4 = Vector128.Create(threshold);

                // Main pass: eight elements per iteration.
                for (; pIn + 8 <= pInEnd; pIn += 8, pV += 8, pW += 8)
                {
                    Vector256<float> updatedV = MultiplyAdd(pIn, primal8, Avx.LoadVector256(pV));
                    Vector256<float> updatedW = GetNewDst256(updatedV, threshold8);

                    Avx.Store(pV, updatedV);
                    Avx.Store(pW, updatedW);
                }

                // At most one 4-element group can remain after the 8-wide loop.
                if (pIn + 4 <= pInEnd)
                {
                    Vector128<float> input4 = Sse.LoadVector128(pIn);
                    Vector128<float> scaled4 = Sse.Multiply(input4, primal4);
                    Vector128<float> updatedV = Sse.Add(Sse.LoadVector128(pV), scaled4);
                    Vector128<float> updatedW = SseIntrinsics.GetNewDst128(updatedV, threshold4);

                    Sse.Store(pV, updatedV);
                    Sse.Store(pW, updatedW);

                    pIn += 4;
                    pV += 4;
                    pW += 4;
                }

                // Scalar tail for the last (fewer than four) elements.
                for (; pIn < pInEnd; pIn++, pV++, pW++)
                {
                    *pV += (*pIn) * primalUpdate;
                    float current = *pV;

                    float thresholded;
                    if (Math.Abs(current) > threshold)
                    {
                        thresholded = current > 0 ? current - threshold : current + threshold;
                    }
                    else
                    {
                        thresholded = 0;
                    }

                    *pW = thresholded;
                }
            }
        }
 
        /// <summary>
        /// Sparse SDCA L1 update: for each of the first <paramref name="count"/> entries k, with
        /// i = indices[k], performs v[i] += src[k] * primalUpdate and then derives w[i] from the new v[i]
        /// and <paramref name="threshold"/>.
        /// </summary>
        /// <remarks>
        /// The scalar tail sets w[i] to v[i] - threshold when v[i] is positive, to v[i] + threshold otherwise,
        /// and to 0 whenever |v[i]| does not exceed the threshold. The vector paths go through
        /// Load8 / Store8 / GetNewDst256 and their SseIntrinsics 4-wide counterparts, which presumably
        /// gather, scatter and threshold equivalently - confirm against those helpers.
        /// </remarks>
        /// <param name="primalUpdate">Scale factor applied to every source element.</param>
        /// <param name="count">Number of (value, index) pairs to process.</param>
        /// <param name="src">Source values, read sequentially.</param>
        /// <param name="indices">Positions in <paramref name="v"/> and <paramref name="w"/> for each source value.</param>
        /// <param name="threshold">Threshold used when deriving <paramref name="w"/> from <paramref name="v"/>.</param>
        /// <param name="v">Accumulator updated in place at the indexed positions.</param>
        /// <param name="w">Receives the thresholded values at the indexed positions.</param>
        public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnlySpan<float> src, ReadOnlySpan<int> indices, float threshold, Span<float> v, Span<float> w)
        {
            fixed (float* pSrcBase = &MemoryMarshal.GetReference(src))
            fixed (int* pIndexBase = &MemoryMarshal.GetReference(indices))
            fixed (float* pV = &MemoryMarshal.GetReference(v))
            fixed (float* pW = &MemoryMarshal.GetReference(w))
            {
                float* pValue = pSrcBase;
                int* pIndex = pIndexBase;
                int* pIndexEnd = pIndexBase + count;

                // Broadcast the two scalars once for each vector width used below.
                Vector256<float> primal8 = Vector256.Create(primalUpdate);
                Vector256<float> threshold8 = Vector256.Create(threshold);
                Vector128<float> primal4 = Vector128.Create(primalUpdate);
                Vector128<float> threshold4 = Vector128.Create(threshold);

                // Main pass: eight (value, index) pairs per iteration.
                for (; pIndex + 8 <= pIndexEnd; pIndex += 8, pValue += 8)
                {
                    Vector256<float> updatedV = MultiplyAdd(pValue, primal8, Load8(pV, pIndex));
                    Vector256<float> updatedW = GetNewDst256(updatedV, threshold8);

                    Store8(in updatedV, pV, pIndex);
                    Store8(in updatedW, pW, pIndex);
                }

                // At most one group of four pairs can remain after the 8-wide loop.
                if (pIndex + 4 <= pIndexEnd)
                {
                    Vector128<float> input4 = Sse.LoadVector128(pValue);
                    Vector128<float> gatheredV = SseIntrinsics.Load4(pV, pIndex);
                    Vector128<float> updatedV = Sse.Add(gatheredV, Sse.Multiply(input4, primal4));
                    Vector128<float> updatedW = SseIntrinsics.GetNewDst128(updatedV, threshold4);

                    SseIntrinsics.Store4(in updatedV, pV, pIndex);
                    SseIntrinsics.Store4(in updatedW, pW, pIndex);

                    pIndex += 4;
                    pValue += 4;
                }

                // Scalar tail for the last (fewer than four) pairs.
                for (; pIndex < pIndexEnd; pIndex++, pValue++)
                {
                    int slot = *pIndex;
                    pV[slot] += (*pValue) * primalUpdate;
                    float current = pV[slot];

                    float thresholded;
                    if (Math.Abs(current) > threshold)
                    {
                        thresholded = current > 0 ? current - threshold : current + threshold;
                    }
                    else
                    {
                        thresholded = 0;
                    }

                    pW[slot] = thresholded;
                }
            }
        }
    }
}