// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

namespace System.Numerics.Tensors
{
    public static unsafe partial class TensorPrimitives
    {
        /// <summary>Operator that takes two input values and returns a single value.</summary>
        /// <typeparam name="T">The element type the operator works on.</typeparam>
        private interface IBinaryOperator<T>
        {
            /// <summary>
            /// Gets whether the operator may be used on the vector code paths. The span/span driver in this
            /// file only takes a 128-, 256- or 512-bit path when this is <see langword="true"/>; otherwise it
            /// uses the scalar overload for every element.
            /// </summary>
            static abstract bool Vectorizable { get; }

            /// <summary>Applies the operator to a single pair of values.</summary>
            static abstract T Invoke(T x, T y);

            /// <summary>Applies the operator to a pair of 128-bit vectors.</summary>
            static abstract Vector128<T> Invoke(Vector128<T> x, Vector128<T> y);

            /// <summary>Applies the operator to a pair of 256-bit vectors.</summary>
            static abstract Vector256<T> Invoke(Vector256<T> x, Vector256<T> y);

            /// <summary>Applies the operator to a pair of 512-bit vectors.</summary>
            static abstract Vector512<T> Invoke(Vector512<T> x, Vector512<T> y);
        }

        /// <summary>
        /// Performs an element-wise operation on the scalar <paramref name="x"/> and each element of
        /// <paramref name="y"/>, and writes the results to <paramref name="destination"/>.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        /// <typeparam name="TBinaryOperator">
        /// Specifies the operation to perform on <paramref name="x"/> with each element loaded from <paramref name="y"/>.
        /// </typeparam>
        /// <param name="x">The scalar operand.</param>
        /// <param name="y">The span operand.</param>
        /// <param name="destination">The span that receives the results.</param>
        /// <remarks>
        /// Forwards to the span/scalar implementation with the two operands exchanged and the operator
        /// wrapped in SwappedBinaryOperator. NOTE(review): SwappedBinaryOperator is declared outside this
        /// section; presumably it exchanges the operands back so that <paramref name="x"/> stays the
        /// left-hand operand - confirm against its definition.
        /// </remarks>
        private static void InvokeScalarSpanIntoSpan<T, TBinaryOperator>(
            T x, ReadOnlySpan<T> y, Span<T> destination)
            where TBinaryOperator : struct, IBinaryOperator<T> =>
            InvokeSpanScalarIntoSpan<T, IdentityOperator<T>, SwappedBinaryOperator<TBinaryOperator, T>>(y, x, destination);

        /// <summary>
        /// Performs an element-wise operation on <paramref name="x"/> and <paramref name="y"/>,
        /// and writes the results to <paramref name="destination"/>.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        /// <typeparam name="TBinaryOperator">
        /// Specifies the operation to perform on each element loaded from <paramref name="x"/> with <paramref name="y"/>.
        /// </typeparam>
        /// <param name="x">The span operand.</param>
        /// <param name="y">The scalar operand.</param>
        /// <param name="destination">The span that receives the results.</param>
        /// <remarks>
        /// Forwards to the three-type-argument overload, supplying IdentityOperator as the extra operator.
        /// NOTE(review): that overload is declared outside this section; presumably the extra operator is a
        /// transform applied to each element of <paramref name="x"/> first - confirm.
        /// </remarks>
        private static void InvokeSpanScalarIntoSpan<T, TBinaryOperator>(
            ReadOnlySpan<T> x, T y, Span<T> destination)
            where TBinaryOperator : struct, IBinaryOperator<T> =>
            InvokeSpanScalarIntoSpan<T, IdentityOperator<T>, TBinaryOperator>(x, y, destination);

        /// <summary>
        /// Performs an element-wise operation on <paramref name="x"/> and <paramref name="y"/>,
        /// and writes the results to <paramref name="destination"/>.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        /// <typeparam name="TBinaryOperator">
        /// Specifies the operation to perform on the pair-wise elements loaded from <paramref name="x"/> and <paramref name="y"/>.
        /// </typeparam>
        /// <param name="x">The first input span.</param>
        /// <param name="y">The second input span; its length must equal that of <paramref name="x"/>.</param>
        /// <param name="destination">
        /// The span that receives the results; it must be at least as long as <paramref name="x"/>.
        /// </param>
        /// <remarks>
        /// A length mismatch between the inputs is reported through ThrowHelper.ThrowArgument_SpansMustHaveSameLength
        /// and a too-short destination through ThrowHelper.ThrowArgument_DestinationTooShort. Each input is also
        /// passed, together with the destination, to ValidateInputOutputSpanNonOverlapping before any element
        /// is processed.
        /// </remarks>
        private static void InvokeSpanSpanIntoSpan<T, TBinaryOperator>(
            ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination)
            where TBinaryOperator : struct, IBinaryOperator<T>
        {
            if (x.Length != y.Length)
            {
                ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
            }

            if (x.Length > destination.Length)
            {
                ThrowHelper.ThrowArgument_DestinationTooShort();
            }

            ValidateInputOutputSpanNonOverlapping(x, destination);
            ValidateInputOutputSpanNonOverlapping(y, destination);

            // Since every branch has a cost and since that cost is
            // essentially lost for larger inputs, we do branches
            // in a way that allows us to have the minimum possible
            // for small sizes

            ref T xRef = ref MemoryMarshal.GetReference(x);
            ref T yRef = ref MemoryMarshal.GetReference(y);
            ref T dRef = ref MemoryMarshal.GetReference(destination);

            // Only x.Length elements are processed, even when destination is longer.
            nuint remainder = (uint)x.Length;

            // The widest available vector size is tried first; each branch returns once it has handled the input.
            if (Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && TBinaryOperator.Vectorizable)
            {
                if (remainder >= (uint)Vector512<T>.Count)
                {
                    Vectorized512(ref xRef, ref yRef, ref dRef, remainder);
                }
                else
                {
                    // We have less than a vector and so we can only handle this as scalar. To do this
                    // efficiently, we simply have a small jump table and fallthrough. So we get a simple
                    // length check, single jump, and then linear execution.

VectorizedSmall(ref xRef, ref yRef, ref dRef, remainder);
                }

                return;
            }

            // Same shape as the 512-bit branch, one vector width down.
            if (Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && TBinaryOperator.Vectorizable)
            {
                if (remainder >= (uint)Vector256<T>.Count)
                {
                    Vectorized256(ref xRef, ref yRef, ref dRef, remainder);
                }
                else
                {
                    // We have less than a vector and so we can only handle this as scalar. To do this
                    // efficiently, we simply have a small jump table and fallthrough. So we get a simple
                    // length check, single jump, and then linear execution.

                    VectorizedSmall(ref xRef, ref yRef, ref dRef, remainder);
                }

                return;
            }

            // Narrowest vector width that is tried before falling back to the scalar loop.
            if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && TBinaryOperator.Vectorizable)
            {
                if (remainder >= (uint)Vector128<T>.Count)
                {
                    Vectorized128(ref xRef, ref yRef, ref dRef, remainder);
                }
                else
                {
                    // We have less than a vector and so we can only handle this as scalar. To do this
                    // efficiently, we simply have a small jump table and fallthrough. So we get a simple
                    // length check, single jump, and then linear execution.

                    VectorizedSmall(ref xRef, ref yRef, ref dRef, remainder);
                }

                return;
            }

            // This is the software fallback when no acceleration is available
            // It requires no branches to hit

            SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder);

            // Scalar loop: applies the operator to each pair of elements in index order and writes
            // the result to the same index of the destination.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static void SoftwareFallback(ref T xRef, ref T yRef, ref T dRef, nuint length)
            {
                for (nuint i = 0; i < length; i++)
                {
                    Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i),
                                                                     Unsafe.Add(ref yRef, i));
                }
            }

            // 128-bit vector path. The caller only invokes this when remainder is at least one full
            // vector, so the "remainder - Count" offsets below cannot underflow.
            static void Vectorized128(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
            {
                // Remember the original start of the destination; the preloaded first vector is
                // stored there at the very end, after dRef may have been advanced.
                ref T dRefBeg = ref dRef;

                // Preload the beginning and end so that overlapping accesses don't negatively impact the data

                Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                          Vector128.LoadUnsafe(ref yRef));
                Vector128<T> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count),
                                                          Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128<T>.Count));

                // Only inputs larger than eight vectors take the pinned, unrolled loop; anything
                // smaller goes straight to the jump table that follows this block.
                if (remainder > (uint)(Vector128<T>.Count * 8))
                {
                    // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

                    fixed (T* px = &xRef)
                    fixed (T* py = &yRef)
                    fixed (T* pd = &dRef)
                    {
                        T* xPtr = px;
                        T* yPtr = py;
                        T* dPtr = pd;

                        // We need to ensure the underlying data can be aligned and only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.

                        // Stepping by whole elements can only reach a vector boundary when the
                        // address is already a multiple of the element size.
                        bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                        if (canAlign)
                        {
                            // Compute by how many elements we're misaligned and adjust the pointers accordingly
                            //
                            // Noting that we are only actually aligning dPtr. This is because unaligned stores
                            // are more expensive than unaligned loads and aligning both is significantly more
                            // complex.

nuint misalignment = ((uint)sizeof(Vector128<T>) - ((nuint)dPtr % (uint)sizeof(Vector128<T>))) / (uint)sizeof(T); xPtr += misalignment; yPtr += misalignment; dPtr += misalignment; Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128<T>)) == 0); remainder -= misalignment; } Vector128<T> vector1; Vector128<T> vector2; Vector128<T> vector3; Vector128<T> vector4; if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. while (remainder >= (uint)(Vector128<T>.Count * 8)) { // We load, process, and store the first four vectors vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 0)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 0))); vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 1)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 1))); vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 2)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 2))); vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 3)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 3))); vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 0)); vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 1)); vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 2)); vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 3)); // We load, process, and store the next four vectors vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 4)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 4))); vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 5)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 5))); vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 6)), 
Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 6))); vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 7)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 7))); vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 4)); vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 5)); vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 6)); vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. xPtr += (uint)(Vector128<T>.Count * 8); yPtr += (uint)(Vector128<T>.Count * 8); dPtr += (uint)(Vector128<T>.Count * 8); remainder -= (uint)(Vector128<T>.Count * 8); } } else { while (remainder >= (uint)(Vector128<T>.Count * 8)) { // We load, process, and store the first four vectors vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 0)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 0))); vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 1)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 1))); vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 2)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 2))); vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 3)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 3))); vector1.Store(dPtr + (uint)(Vector128<T>.Count * 0)); vector2.Store(dPtr + (uint)(Vector128<T>.Count * 1)); vector3.Store(dPtr + (uint)(Vector128<T>.Count * 2)); vector4.Store(dPtr + (uint)(Vector128<T>.Count * 3)); // We load, process, and store the next four vectors vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 4)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 4))); vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 5)), Vector128.Load(yPtr + 
(uint)(Vector128<T>.Count * 5))); vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 6)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 6))); vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 7)), Vector128.Load(yPtr + (uint)(Vector128<T>.Count * 7))); vector1.Store(dPtr + (uint)(Vector128<T>.Count * 4)); vector2.Store(dPtr + (uint)(Vector128<T>.Count * 5)); vector3.Store(dPtr + (uint)(Vector128<T>.Count * 6)); vector4.Store(dPtr + (uint)(Vector128<T>.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. xPtr += (uint)(Vector128<T>.Count * 8); yPtr += (uint)(Vector128<T>.Count * 8); dPtr += (uint)(Vector128<T>.Count * 8); remainder -= (uint)(Vector128<T>.Count * 8); } } // Adjusting the refs here allows us to avoid pinning for very small inputs xRef = ref *xPtr; yRef = ref *yPtr; dRef = ref *dPtr; } } // Process the remaining [Count, Count * 8] elements via a jump table // // Unless the original length was an exact multiple of Count, then we'll // end up reprocessing a couple elements in case 1 for end. We'll also // potentially reprocess a few elements in case 0 for beg, to handle any // data before the first aligned address. 
nuint endIndex = remainder; remainder = (remainder + (uint)(Vector128<T>.Count - 1)) & (nuint)(-Vector128<T>.Count); switch (remainder / (uint)Vector128<T>.Count) { case 8: { Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 8)), Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 8))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 8)); goto case 7; } case 7: { Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 7)), Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 7))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 7)); goto case 6; } case 6: { Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 6)), Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 6))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 6)); goto case 5; } case 5: { Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 5)), Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 5))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 5)); goto case 4; } case 4: { Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 4)), Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 4))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 4)); goto case 3; } case 3: { Vector128<T> vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 3)), Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 3))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 3)); goto case 2; } case 2: { Vector128<T> vector = 
TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 2)),
                                                                         Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<T>.Count * 2)));
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 2));
                        goto case 1;
                    }

                    case 1:
                    {
                        // Store the last block, which includes any elements that wouldn't fill a full vector
                        // (endIndex is the element count from before it was rounded up to a vector multiple,
                        // so this store finishes exactly on the final element).
                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128<T>.Count);
                        goto case 0;
                    }

                    case 0:
                    {
                        // Store the first block, which includes any elements preceding the first aligned block
                        beg.StoreUnsafe(ref dRefBeg);
                        break;
                    }
                }
            }

            // 256-bit vector path. The caller only invokes this when remainder is at least one full
            // vector, so the "remainder - Count" offsets below cannot underflow.
            static void Vectorized256(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
            {
                // Remember the original start of the destination; the preloaded first vector is
                // stored there at the very end, after dRef may have been advanced.
                ref T dRefBeg = ref dRef;

                // Preload the beginning and end so that overlapping accesses don't negatively impact the data

                Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                          Vector256.LoadUnsafe(ref yRef));
                Vector256<T> end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count),
                                                          Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256<T>.Count));

                // Only inputs larger than eight vectors take the pinned, unrolled loop; anything
                // smaller goes straight to the jump table that follows this block.
                if (remainder > (uint)(Vector256<T>.Count * 8))
                {
                    // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

                    fixed (T* px = &xRef)
                    fixed (T* py = &yRef)
                    fixed (T* pd = &dRef)
                    {
                        T* xPtr = px;
                        T* yPtr = py;
                        T* dPtr = pd;

                        // We need to ensure the underlying data can be aligned and only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.

                        // Stepping by whole elements can only reach a vector boundary when the
                        // address is already a multiple of the element size.
                        bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                        if (canAlign)
                        {
                            // Compute by how many elements we're misaligned and adjust the pointers accordingly
                            //
                            // Noting that we are only actually aligning dPtr. This is because unaligned stores
                            // are more expensive than unaligned loads and aligning both is significantly more
                            // complex.

nuint misalignment = ((uint)sizeof(Vector256<T>) - ((nuint)dPtr % (uint)sizeof(Vector256<T>))) / (uint)sizeof(T); xPtr += misalignment; yPtr += misalignment; dPtr += misalignment; Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256<T>)) == 0); remainder -= misalignment; } Vector256<T> vector1; Vector256<T> vector2; Vector256<T> vector3; Vector256<T> vector4; if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. while (remainder >= (uint)(Vector256<T>.Count * 8)) { // We load, process, and store the first four vectors vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 0)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 0))); vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 1)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 1))); vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 2)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 2))); vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 3)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 3))); vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 0)); vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 1)); vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 2)); vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 3)); // We load, process, and store the next four vectors vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 4)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 4))); vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 5)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 5))); vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 6)), 
Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 6))); vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 7)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 7))); vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 4)); vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 5)); vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 6)); vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. xPtr += (uint)(Vector256<T>.Count * 8); yPtr += (uint)(Vector256<T>.Count * 8); dPtr += (uint)(Vector256<T>.Count * 8); remainder -= (uint)(Vector256<T>.Count * 8); } } else { while (remainder >= (uint)(Vector256<T>.Count * 8)) { // We load, process, and store the first four vectors vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 0)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 0))); vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 1)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 1))); vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 2)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 2))); vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 3)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 3))); vector1.Store(dPtr + (uint)(Vector256<T>.Count * 0)); vector2.Store(dPtr + (uint)(Vector256<T>.Count * 1)); vector3.Store(dPtr + (uint)(Vector256<T>.Count * 2)); vector4.Store(dPtr + (uint)(Vector256<T>.Count * 3)); // We load, process, and store the next four vectors vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 4)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 4))); vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 5)), Vector256.Load(yPtr + 
(uint)(Vector256<T>.Count * 5))); vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 6)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 6))); vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 7)), Vector256.Load(yPtr + (uint)(Vector256<T>.Count * 7))); vector1.Store(dPtr + (uint)(Vector256<T>.Count * 4)); vector2.Store(dPtr + (uint)(Vector256<T>.Count * 5)); vector3.Store(dPtr + (uint)(Vector256<T>.Count * 6)); vector4.Store(dPtr + (uint)(Vector256<T>.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. xPtr += (uint)(Vector256<T>.Count * 8); yPtr += (uint)(Vector256<T>.Count * 8); dPtr += (uint)(Vector256<T>.Count * 8); remainder -= (uint)(Vector256<T>.Count * 8); } } // Adjusting the refs here allows us to avoid pinning for very small inputs xRef = ref *xPtr; yRef = ref *yPtr; dRef = ref *dPtr; } } // Process the remaining [Count, Count * 8] elements via a jump table // // Unless the original length was an exact multiple of Count, then we'll // end up reprocessing a couple elements in case 1 for end. We'll also // potentially reprocess a few elements in case 0 for beg, to handle any // data before the first aligned address. 
nuint endIndex = remainder; remainder = (remainder + (uint)(Vector256<T>.Count - 1)) & (nuint)(-Vector256<T>.Count); switch (remainder / (uint)Vector256<T>.Count) { case 8: { Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 8)), Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 8))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 8)); goto case 7; } case 7: { Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 7)), Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 7))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 7)); goto case 6; } case 6: { Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 6)), Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 6))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 6)); goto case 5; } case 5: { Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 5)), Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 5))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 5)); goto case 4; } case 4: { Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 4)), Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 4))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 4)); goto case 3; } case 3: { Vector256<T> vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 3)), Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 3))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 3)); goto case 2; } case 2: { Vector256<T> vector = 
TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 2)),
                                                                         Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<T>.Count * 2)));
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 2));
                        goto case 1;
                    }

                    case 1:
                    {
                        // Store the last block, which includes any elements that wouldn't fill a full vector
                        // (endIndex is the element count from before it was rounded up to a vector multiple,
                        // so this store finishes exactly on the final element).
                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256<T>.Count);
                        goto case 0;
                    }

                    case 0:
                    {
                        // Store the first block, which includes any elements preceding the first aligned block
                        beg.StoreUnsafe(ref dRefBeg);
                        break;
                    }
                }
            }

            // 512-bit vector path. The caller only invokes this when remainder is at least one full
            // vector, so the "remainder - Count" offsets below cannot underflow.
            static void Vectorized512(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
            {
                // Remember the original start of the destination; the preloaded first vector is
                // stored there at the very end, after dRef may have been advanced.
                ref T dRefBeg = ref dRef;

                // Preload the beginning and end so that overlapping accesses don't negatively impact the data

                Vector512<T> beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef),
                                                          Vector512.LoadUnsafe(ref yRef));
                Vector512<T> end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512<T>.Count),
                                                          Vector512.LoadUnsafe(ref yRef, remainder - (uint)Vector512<T>.Count));

                // Only inputs larger than eight vectors take the pinned, unrolled loop; anything
                // smaller goes straight to the jump table that follows this block.
                if (remainder > (uint)(Vector512<T>.Count * 8))
                {
                    // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

                    fixed (T* px = &xRef)
                    fixed (T* py = &yRef)
                    fixed (T* pd = &dRef)
                    {
                        T* xPtr = px;
                        T* yPtr = py;
                        T* dPtr = pd;

                        // We need to ensure the underlying data can be aligned and only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.

                        // Stepping by whole elements can only reach a vector boundary when the
                        // address is already a multiple of the element size.
                        bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                        if (canAlign)
                        {
                            // Compute by how many elements we're misaligned and adjust the pointers accordingly
                            //
                            // Noting that we are only actually aligning dPtr. This is because unaligned stores
                            // are more expensive than unaligned loads and aligning both is significantly more
                            // complex.

nuint misalignment = ((uint)sizeof(Vector512<T>) - ((nuint)dPtr % (uint)sizeof(Vector512<T>))) / (uint)sizeof(T); xPtr += misalignment; yPtr += misalignment; dPtr += misalignment; Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512<T>)) == 0); remainder -= misalignment; } Vector512<T> vector1; Vector512<T> vector2; Vector512<T> vector3; Vector512<T> vector4; if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. while (remainder >= (uint)(Vector512<T>.Count * 8)) { // We load, process, and store the first four vectors vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 0)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 0))); vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 1)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 1))); vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 2)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 2))); vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 3)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 3))); vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 0)); vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 1)); vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 2)); vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 3)); // We load, process, and store the next four vectors vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 4)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 4))); vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 5)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 5))); vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 6)), 
Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 6))); vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 7)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 7))); vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 4)); vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 5)); vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 6)); vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. xPtr += (uint)(Vector512<T>.Count * 8); yPtr += (uint)(Vector512<T>.Count * 8); dPtr += (uint)(Vector512<T>.Count * 8); remainder -= (uint)(Vector512<T>.Count * 8); } } else { while (remainder >= (uint)(Vector512<T>.Count * 8)) { // We load, process, and store the first four vectors vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 0)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 0))); vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 1)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 1))); vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 2)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 2))); vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 3)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 3))); vector1.Store(dPtr + (uint)(Vector512<T>.Count * 0)); vector2.Store(dPtr + (uint)(Vector512<T>.Count * 1)); vector3.Store(dPtr + (uint)(Vector512<T>.Count * 2)); vector4.Store(dPtr + (uint)(Vector512<T>.Count * 3)); // We load, process, and store the next four vectors vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 4)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 4))); vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 5)), Vector512.Load(yPtr + 
(uint)(Vector512<T>.Count * 5))); vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 6)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 6))); vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 7)), Vector512.Load(yPtr + (uint)(Vector512<T>.Count * 7))); vector1.Store(dPtr + (uint)(Vector512<T>.Count * 4)); vector2.Store(dPtr + (uint)(Vector512<T>.Count * 5)); vector3.Store(dPtr + (uint)(Vector512<T>.Count * 6)); vector4.Store(dPtr + (uint)(Vector512<T>.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. xPtr += (uint)(Vector512<T>.Count * 8); yPtr += (uint)(Vector512<T>.Count * 8); dPtr += (uint)(Vector512<T>.Count * 8); remainder -= (uint)(Vector512<T>.Count * 8); } } // Adjusting the refs here allows us to avoid pinning for very small inputs xRef = ref *xPtr; yRef = ref *yPtr; dRef = ref *dPtr; } } // Process the remaining [Count, Count * 8] elements via a jump table // // Unless the original length was an exact multiple of Count, then we'll // end up reprocessing a couple elements in case 1 for end. We'll also // potentially reprocess a few elements in case 0 for beg, to handle any // data before the first aligned address. 
nuint endIndex = remainder; remainder = (remainder + (uint)(Vector512<T>.Count - 1)) & (nuint)(-Vector512<T>.Count); switch (remainder / (uint)Vector512<T>.Count) { case 8: { Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 8)), Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 8))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 8)); goto case 7; } case 7: { Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 7)), Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 7))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 7)); goto case 6; } case 6: { Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 6)), Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 6))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 6)); goto case 5; } case 5: { Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 5)), Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 5))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 5)); goto case 4; } case 4: { Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 4)), Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 4))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 4)); goto case 3; } case 3: { Vector512<T> vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 3)), Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 3))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 3)); goto case 2; } case 2: { Vector512<T> vector = 
TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 2)),
                                                                     Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512<T>.Count * 2)));
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 2));

                        goto case 1;
                    }

                    case 1:
                    {
                        // Store the last block, which includes any elements that wouldn't fill a full vector
                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512<T>.Count);
                        goto case 0;
                    }

                    case 0:
                    {
                        // Store the first block, which includes any elements preceding the first aligned block
                        beg.StoreUnsafe(ref dRefBeg);
                        break;
                    }
                }
            }

            // Entry point for the small-input path. It only selects the element-size-specific
            // routine: sizeof(T) is compared against 1, 2 and 4, and any other size is asserted
            // (debug builds only) to be 8 before the 8-byte routine is used.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static void VectorizedSmall(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
            {
                if (sizeof(T) == 1)
                {
                    VectorizedSmall1(ref xRef, ref yRef, ref dRef, remainder);
                }
                else if (sizeof(T) == 2)
                {
                    VectorizedSmall2(ref xRef, ref yRef, ref dRef, remainder);
                }
                else if (sizeof(T) == 4)
                {
                    VectorizedSmall4(ref xRef, ref yRef, ref dRef, remainder);
                }
                else
                {
                    Debug.Assert(sizeof(T) == 8);
                    VectorizedSmall8(ref xRef, ref yRef, ref dRef, remainder);
                }
            }

            // Small-input path for 1-byte elements. The element count selects a case directly:
            //   33..63 -> two Vector256 operations whose ranges overlap
            //   32     -> one Vector256 operation
            //   17..31 -> two Vector128 operations whose ranges overlap
            //   16     -> one Vector128 operation
            //   1..15  -> scalar chain that walks from the highest index down to index 0
            //   0      -> nothing to do
            // In the two-vector cases both results are computed before either is stored, and the
            // second vector is positioned so that it ends exactly at the last element.
            // There is no default case, so a remainder outside 0..63 performs no work.
            // NOTE(review): presumably callers only pass counts smaller than one vector of the
            // widest accelerated size - confirm against the call sites.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static void VectorizedSmall1(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
            {
                Debug.Assert(sizeof(T) == 1);

                switch (remainder)
                {
                    // Two Vector256's worth of data, with at least one element overlapping.
                    case 63: case 62: case 61: case 60: case 59: case 58: case 57: case 56:
                    case 55: case 54: case 53: case 52: case 51: case 50: case 49: case 48:
                    case 47: case 46: case 45: case 44: case 43: case 42: case 41: case 40:
                    case 39: case 38: case 37: case 36: case 35: case 34: case 33:
                    {
                        Debug.Assert(Vector256.IsHardwareAccelerated);

                        Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                                  Vector256.LoadUnsafe(ref yRef));
                        Vector256<T> end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count),
                                                                  Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256<T>.Count));

                        beg.StoreUnsafe(ref dRef);
                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count);

                        break;
                    }

                    // One Vector256's worth of data.
                    case 32:
                    {
                        Debug.Assert(Vector256.IsHardwareAccelerated);

                        Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                                  Vector256.LoadUnsafe(ref yRef));
                        beg.StoreUnsafe(ref dRef);

                        break;
                    }

                    // Two Vector128's worth of data, with at least one element overlapping.
                    case 31: case 30: case 29: case 28: case 27: case 26: case 25: case 24:
                    case 23: case 22: case 21: case 20: case 19: case 18: case 17:
                    {
                        Debug.Assert(Vector128.IsHardwareAccelerated);

                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                                  Vector128.LoadUnsafe(ref yRef));
                        Vector128<T> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count),
                                                                  Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128<T>.Count));

                        beg.StoreUnsafe(ref dRef);
                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count);

                        break;
                    }

                    // One Vector128's worth of data.
                    case 16:
                    {
                        Debug.Assert(Vector128.IsHardwareAccelerated);

                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                                  Vector128.LoadUnsafe(ref yRef));
                        beg.StoreUnsafe(ref dRef);

                        break;
                    }

                    // Cases that are smaller than a single vector. No SIMD; just jump to the length and fall through each
                    // case to unroll the whole processing.
                    case 15:
                        Unsafe.Add(ref dRef, 14) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 14),
                                                                          Unsafe.Add(ref yRef, 14));
                        goto case 14;

                    case 14:
                        Unsafe.Add(ref dRef, 13) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 13),
                                                                          Unsafe.Add(ref yRef, 13));
                        goto case 13;

                    case 13:
                        Unsafe.Add(ref dRef, 12) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 12),
                                                                          Unsafe.Add(ref yRef, 12));
                        goto case 12;

                    case 12:
                        Unsafe.Add(ref dRef, 11) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 11),
                                                                          Unsafe.Add(ref yRef, 11));
                        goto case 11;

                    case 11:
                        Unsafe.Add(ref dRef, 10) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 10),
                                                                          Unsafe.Add(ref yRef, 10));
                        goto case 10;

                    case 10:
                        Unsafe.Add(ref dRef, 9) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 9),
                                                                         Unsafe.Add(ref yRef, 9));
                        goto case 9;

                    case 9:
                        Unsafe.Add(ref dRef, 8) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 8),
                                                                         Unsafe.Add(ref yRef, 8));
                        goto case 8;

                    case 8:
                        Unsafe.Add(ref dRef, 7) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 7),
                                                                         Unsafe.Add(ref yRef, 7));
                        goto case 7;

                    case 7:
                        Unsafe.Add(ref dRef, 6) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 6),
                                                                         Unsafe.Add(ref yRef, 6));
                        goto case 6;

                    case 6:
                        Unsafe.Add(ref dRef, 5) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 5),
                                                                         Unsafe.Add(ref yRef, 5));
                        goto case 5;

                    case 5:
                        Unsafe.Add(ref dRef, 4) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 4),
                                                                         Unsafe.Add(ref yRef, 4));
                        goto case 4;

                    case 4:
                        Unsafe.Add(ref dRef, 3) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 3),
                                                                         Unsafe.Add(ref yRef, 3));
                        goto case 3;

                    case 3:
                        Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2),
                                                                         Unsafe.Add(ref yRef, 2));
                        goto case 2;

                    case 2:
                        Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1),
                                                                         Unsafe.Add(ref yRef, 1));
                        goto case 1;

                    case 1:
                        dRef = TBinaryOperator.Invoke(xRef, yRef);
                        goto case 0;

                    case 0:
                        break;
                }
            }

            // Small-input path for 2-byte elements. Same layout as the 1-byte routine with the
            // case boundaries halved:
            //   17..31 -> two overlapping Vector256 operations
            //   16     -> one Vector256 operation
            //   9..15  -> two overlapping Vector128 operations
            //   8      -> one Vector128 operation
            //   1..7   -> scalar chain from the highest index down to index 0
            //   0      -> nothing to do
            // There is no default case, so a remainder outside 0..31 performs no work.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static void VectorizedSmall2(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
            {
                Debug.Assert(sizeof(T) == 2);

                switch (remainder)
                {
                    // Two Vector256's worth of data, with at least one element overlapping.
                    case 31: case 30: case 29: case 28: case 27: case 26: case 25: case 24:
                    case 23: case 22: case 21: case 20: case 19: case 18: case 17:
                    {
                        Debug.Assert(Vector256.IsHardwareAccelerated);

                        Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                                  Vector256.LoadUnsafe(ref yRef));
                        Vector256<T> end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count),
                                                                  Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256<T>.Count));

                        beg.StoreUnsafe(ref dRef);
                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count);

                        break;
                    }

                    // One Vector256's worth of data.
                    case 16:
                    {
                        Debug.Assert(Vector256.IsHardwareAccelerated);

                        Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                                  Vector256.LoadUnsafe(ref yRef));
                        beg.StoreUnsafe(ref dRef);

                        break;
                    }

                    // Two Vector128's worth of data, with at least one element overlapping.
                    case 15: case 14: case 13: case 12: case 11: case 10: case 9:
                    {
                        Debug.Assert(Vector128.IsHardwareAccelerated);

                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                                  Vector128.LoadUnsafe(ref yRef));
                        Vector128<T> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count),
                                                                  Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128<T>.Count));

                        beg.StoreUnsafe(ref dRef);
                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count);

                        break;
                    }

                    // One Vector128's worth of data.
                    case 8:
                    {
                        Debug.Assert(Vector128.IsHardwareAccelerated);

                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                                  Vector128.LoadUnsafe(ref yRef));
                        beg.StoreUnsafe(ref dRef);

                        break;
                    }

                    // Cases that are smaller than a single vector. No SIMD; just jump to the length and fall through each
                    // case to unroll the whole processing.
                    case 7:
                        Unsafe.Add(ref dRef, 6) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 6),
                                                                         Unsafe.Add(ref yRef, 6));
                        goto case 6;

                    case 6:
                        Unsafe.Add(ref dRef, 5) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 5),
                                                                         Unsafe.Add(ref yRef, 5));
                        goto case 5;

                    case 5:
                        Unsafe.Add(ref dRef, 4) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 4),
                                                                         Unsafe.Add(ref yRef, 4));
                        goto case 4;

                    case 4:
                        Unsafe.Add(ref dRef, 3) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 3),
                                                                         Unsafe.Add(ref yRef, 3));
                        goto case 3;

                    case 3:
                        Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2),
                                                                         Unsafe.Add(ref yRef, 2));
                        goto case 2;

                    case 2:
                        Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1),
                                                                         Unsafe.Add(ref yRef, 1));
                        goto case 1;

                    case 1:
                        dRef = TBinaryOperator.Invoke(xRef, yRef);
                        goto case 0;

                    case 0:
                        break;
                }
            }

            // Small-input path for 4-byte elements:
            //   9..15 -> two overlapping Vector256 operations
            //   8     -> one Vector256 operation
            //   5..7  -> two overlapping Vector128 operations
            //   4     -> one Vector128 operation
            //   1..3  -> scalar chain from the highest index down to index 0
            //   0     -> nothing to do
            // There is no default case, so a remainder outside 0..15 performs no work.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static void VectorizedSmall4(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
            {
                Debug.Assert(sizeof(T) == 4);

                switch (remainder)
                {
                    // Two Vector256 operations; the second one ends at the last element and
                    // overlaps the first.
                    case 15: case 14: case 13: case 12: case 11: case 10: case 9:
                    {
                        Debug.Assert(Vector256.IsHardwareAccelerated);

                        Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                                  Vector256.LoadUnsafe(ref yRef));
                        Vector256<T> end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count),
                                                                  Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256<T>.Count));

                        beg.StoreUnsafe(ref dRef);
                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count);

                        break;
                    }

                    // Exactly one Vector256.
                    case 8:
                    {
                        Debug.Assert(Vector256.IsHardwareAccelerated);

                        Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                                  Vector256.LoadUnsafe(ref yRef));
                        beg.StoreUnsafe(ref dRef);

                        break;
                    }

                    // Two Vector128 operations; the second one ends at the last element and
                    // overlaps the first.
                    case 7: case 6: case 5:
                    {
                        Debug.Assert(Vector128.IsHardwareAccelerated);

                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                                  Vector128.LoadUnsafe(ref yRef));
                        Vector128<T> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count),
                                                                  Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128<T>.Count));

                        beg.StoreUnsafe(ref dRef);
                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count);

                        break;
                    }

                    // Exactly one Vector128.
                    case 4:
                    {
                        Debug.Assert(Vector128.IsHardwareAccelerated);

                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                                  Vector128.LoadUnsafe(ref yRef));
                        beg.StoreUnsafe(ref dRef);

                        break;
                    }

                    // Fewer elements than a Vector128 holds: handle one element per case and
                    // chain downwards to index 0.
                    case 3:
                    {
                        Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2),
                                                                         Unsafe.Add(ref yRef, 2));
                        goto case 2;
                    }

                    case 2:
                    {
                        Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1),
                                                                         Unsafe.Add(ref yRef, 1));
                        goto case 1;
                    }

                    case 1:
                    {
                        dRef = TBinaryOperator.Invoke(xRef, yRef);
                        goto case 0;
                    }

                    case 0:
                    {
                        break;
                    }
                }
            }

            // Small-input path for 8-byte elements:
            //   5..7 -> two overlapping Vector256 operations
            //   4    -> one Vector256 operation
            //   3    -> two overlapping Vector128 operations
            //   2    -> one Vector128 operation
            //   1    -> single scalar operation
            //   0    -> nothing to do
            // There is no default case, so a remainder outside 0..7 performs no work.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static void VectorizedSmall8(ref T xRef, ref T yRef, ref T dRef, nuint remainder)
            {
                Debug.Assert(sizeof(T) == 8);

                switch (remainder)
                {
                    // Two Vector256 operations; the second one ends at the last element and
                    // overlaps the first.
                    case 7: case 6: case 5:
                    {
                        Debug.Assert(Vector256.IsHardwareAccelerated);

                        Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                                  Vector256.LoadUnsafe(ref yRef));
                        Vector256<T> end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count),
                                                                  Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256<T>.Count));

                        beg.StoreUnsafe(ref dRef);
                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count);

                        break;
                    }

                    // Exactly one Vector256.
                    case 4:
                    {
                        Debug.Assert(Vector256.IsHardwareAccelerated);

                        Vector256<T> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
                                                                  Vector256.LoadUnsafe(ref yRef));
                        beg.StoreUnsafe(ref dRef);

                        break;
                    }

                    // Two Vector128 operations; the second one ends at the last element and
                    // overlaps the first.
                    case 3:
                    {
                        Debug.Assert(Vector128.IsHardwareAccelerated);

                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                                  Vector128.LoadUnsafe(ref yRef));
                        Vector128<T> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count),
                                                                  Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128<T>.Count));

                        beg.StoreUnsafe(ref dRef);
                        end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count);

                        break;
                    }

                    // Exactly one Vector128.
                    case 2:
                    {
                        Debug.Assert(Vector128.IsHardwareAccelerated);

                        Vector128<T> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
                                                                  Vector128.LoadUnsafe(ref yRef));
                        beg.StoreUnsafe(ref dRef);

                        break;
                    }

                    // A single element: scalar operation.
                    case 1:
                    {
                        dRef = TBinaryOperator.Invoke(xRef, yRef);
                        goto case 0;
                    }

                    case 0:
                    {
                        break;
                    }
                }
            }
        }

        /// <summary>
        /// Performs an element-wise operation on <paramref name="x"/> and <paramref name="y"/>,
        /// and writes the results to <paramref name="destination"/>.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        /// <typeparam name="TTransformOperator">
        /// Specifies the operation to perform on each element loaded from <paramref name="x"/>.
        /// It is not used with <paramref name="y"/>.
        /// </typeparam>
        /// <typeparam name="TBinaryOperator">
        /// Specifies the operation to perform on the transformed value from <paramref name="x"/> with <paramref name="y"/>.
        /// </typeparam>
        private static void InvokeSpanScalarIntoSpan<T, TTransformOperator, TBinaryOperator>(
            ReadOnlySpan<T> x, T y, Span<T> destination)
            where TTransformOperator : struct, IUnaryOperator<T, T>
            where TBinaryOperator : struct, IBinaryOperator<T>
        {
            if (x.Length > destination.Length)
            {
                ThrowHelper.ThrowArgument_DestinationTooShort();
            }

            ValidateInputOutputSpanNonOverlapping(x, destination);

            // Since every branch has a cost and since that cost is
            // essentially lost for larger inputs, we do branches
            // in a way that allows us to have the minimum possible
            // for small sizes

            ref T xRef = ref MemoryMarshal.GetReference(x);
            ref T dRef = ref MemoryMarshal.GetReference(destination);

            nuint remainder = (uint)x.Length;

            if (Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && TTransformOperator.Vectorizable && TBinaryOperator.Vectorizable)
            {
                if (remainder >= (uint)Vector512<T>.Count)
                {
                    Vectorized512(ref xRef, y, ref dRef, remainder);
                }
                else
                {
                    // We have less than a vector and so we can only handle this as scalar.
To do this // efficiently, we simply have a small jump table and fallthrough. So we get a simple // length check, single jump, and then linear execution. VectorizedSmall(ref xRef, y, ref dRef, remainder); } return; } if (Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && TTransformOperator.Vectorizable && TBinaryOperator.Vectorizable) { if (remainder >= (uint)Vector256<T>.Count) { Vectorized256(ref xRef, y, ref dRef, remainder); } else { // We have less than a vector and so we can only handle this as scalar. To do this // efficiently, we simply have a small jump table and fallthrough. So we get a simple // length check, single jump, and then linear execution. VectorizedSmall(ref xRef, y, ref dRef, remainder); } return; } if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && TTransformOperator.Vectorizable && TBinaryOperator.Vectorizable) { if (remainder >= (uint)Vector128<T>.Count) { Vectorized128(ref xRef, y, ref dRef, remainder); } else { // We have less than a vector and so we can only handle this as scalar. To do this // efficiently, we simply have a small jump table and fallthrough. So we get a simple // length check, single jump, and then linear execution. 
VectorizedSmall(ref xRef, y, ref dRef, remainder);
                }

                return;
            }

            // This is the software fallback when no acceleration is available
            // It requires no branches to hit

            SoftwareFallback(ref xRef, y, ref dRef, remainder);

            // Scalar loop: for each of the `length` elements, applies TTransformOperator to the
            // element of x and then TBinaryOperator to that result and y, writing the outcome to
            // the same index of the destination.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static void SoftwareFallback(ref T xRef, T y, ref T dRef, nuint length)
            {
                for (nuint i = 0; i < length; i++)
                {
                    Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, i)),
                                                                     y);
                }
            }

            // Vector128 implementation. It reads the first and the last full vector up front, so
            // `remainder` must be at least Vector128<T>.Count on entry (the visible call site
            // checks this). The work is split in three parts:
            //   1. `beg` and `end` (results for the first and the last full vector) are computed
            //      before anything is written.
            //   2. When more than eight vectors remain, the data is pinned and processed eight
            //      vectors per iteration, optionally after aligning the destination pointer.
            //   3. A jump table handles the remaining full vectors, then stores `end` and `beg`.
            static void Vectorized128(ref T xRef, T y, ref T dRef, nuint remainder)
            {
                // Remember where the destination starts; dRef may be advanced below.
                ref T dRefBeg = ref dRef;

                // Preload the beginning and end so that overlapping accesses don't negatively impact the data

                // The scalar operand broadcast to every lane.
                Vector128<T> yVec = Vector128.Create(y);

                Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)),
                                                          yVec);
                Vector128<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count)),
                                                          yVec);

                if (remainder > (uint)(Vector128<T>.Count * 8))
                {
                    // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

                    fixed (T* px = &xRef)
                    fixed (T* pd = &dRef)
                    {
                        T* xPtr = px;
                        T* dPtr = pd;

                        // We need to ensure the underlying data can be aligned and only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.

                        bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                        if (canAlign)
                        {
                            // Compute by how many elements we're misaligned and adjust the pointers accordingly
                            //
                            // Noting that we are only actually aligning dPtr. This is because unaligned stores
                            // are more expensive than unaligned loads and aligning both is significantly more
                            // complex.

                            // When dPtr is already vector-aligned this evaluates to a full Count.
                            // The skipped elements are still written, by the final store of `beg`
                            // at dRefBeg.
                            nuint misalignment = ((uint)sizeof(Vector128<T>) - ((nuint)dPtr % (uint)sizeof(Vector128<T>))) / (uint)sizeof(T);

                            xPtr += misalignment;
                            dPtr += misalignment;

                            Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128<T>)) == 0);

                            remainder -= misalignment;
                        }

                        Vector128<T> vector1;
                        Vector128<T> vector2;
                        Vector128<T> vector3;
                        Vector128<T> vector4;

                        if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                        {
                            // This loop stores the data non-temporally, which benefits us when there
                            // is a large amount of data involved as it avoids polluting the cache.

                            while (remainder >= (uint)(Vector128<T>.Count * 8))
                            {
                                // We load, process, and store the first four vectors

                                vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 0))),
                                                                 yVec);
                                vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 1))),
                                                                 yVec);
                                vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 2))),
                                                                 yVec);
                                vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 3))),
                                                                 yVec);

                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 0));
                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 1));
                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 2));
                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 3));

                                // We load, process, and store the next four vectors

                                vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 4))),
                                                                 yVec);
                                vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 5))),
                                                                 yVec);
                                vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 6))),
                                                                 yVec);
                                vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 7))),
                                                                 yVec);

                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 4));
                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 5));
                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 6));
                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128<T>.Count * 7));

                                // We adjust the source and destination references, then update
                                // the count of remaining elements to process.

                                xPtr += (uint)(Vector128<T>.Count * 8);
                                dPtr += (uint)(Vector128<T>.Count * 8);

                                remainder -= (uint)(Vector128<T>.Count * 8);
                            }
                        }
                        else
                        {
                            // Same eight-vector loop using regular stores; this is also the path
                            // taken when the destination could not be aligned.
                            while (remainder >= (uint)(Vector128<T>.Count * 8))
                            {
                                // We load, process, and store the first four vectors

                                vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 0))),
                                                                 yVec);
                                vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 1))),
                                                                 yVec);
                                vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 2))),
                                                                 yVec);
                                vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 3))),
                                                                 yVec);

                                vector1.Store(dPtr + (uint)(Vector128<T>.Count * 0));
                                vector2.Store(dPtr + (uint)(Vector128<T>.Count * 1));
                                vector3.Store(dPtr + (uint)(Vector128<T>.Count * 2));
                                vector4.Store(dPtr + (uint)(Vector128<T>.Count * 3));

                                // We load, process, and store the next four vectors

                                vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 4))),
                                                                 yVec);
                                vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 5))),
                                                                 yVec);
                                vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 6))),
                                                                 yVec);
                                vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128<T>.Count * 7))),
                                                                 yVec);

                                vector1.Store(dPtr + (uint)(Vector128<T>.Count * 4));
                                vector2.Store(dPtr + (uint)(Vector128<T>.Count * 5));
                                vector3.Store(dPtr + (uint)(Vector128<T>.Count * 6));
                                vector4.Store(dPtr + (uint)(Vector128<T>.Count * 7));

                                // We adjust the source and destination references, then update
                                // the count of remaining elements to process.

                                xPtr += (uint)(Vector128<T>.Count * 8);
                                dPtr += (uint)(Vector128<T>.Count * 8);

                                remainder -= (uint)(Vector128<T>.Count * 8);
                            }
                        }

                        // Adjusting the refs here allows us to avoid pinning for very small inputs

                        xRef = ref *xPtr;
                        dRef = ref *dPtr;
                    }
                }

                // Process the remaining [Count, Count * 8] elements via a jump table
                //
                // Unless the original length was an exact multiple of Count, then we'll
                // end up reprocessing a couple elements in case 1 for end. We'll also
                // potentially reprocess a few elements in case 0 for beg, to handle any
                // data before the first aligned address.

                // endIndex keeps the exact number of elements left; remainder is rounded up to
                // the next multiple of Count so that it can index the jump table.
                nuint endIndex = remainder;
                remainder = (remainder + (uint)(Vector128<T>.Count - 1)) & (nuint)(-Vector128<T>.Count);

                switch (remainder / (uint)Vector128<T>.Count)
                {
                    case 8:
                    {
                        Vector128<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 8))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 8));
                        goto case 7;
                    }

                    case 7:
                    {
                        Vector128<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 7))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 7));
                        goto case 6;
                    }

                    case 6:
                    {
                        Vector128<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 6))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 6));
                        goto case 5;
                    }

                    case 5:
                    {
                        Vector128<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 5))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 5));
                        goto case 4;
                    }

                    case 4:
                    {
                        Vector128<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 4))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 4));
                        goto case 3;
                    }

                    case 3:
                    {
                        Vector128<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 3))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 3));
                        goto case 2;
                    }

                    case 2:
                    {
                        Vector128<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<T>.Count * 2))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<T>.Count * 2));
                        goto case 1;
                    }

                    case 1:
                    {
                        // Store the last block, which includes any elements that wouldn't fill a full vector
                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128<T>.Count);
                        goto case 0;
                    }

                    case 0:
                    {
                        // Store the first block, which includes any elements preceding the first aligned block
                        beg.StoreUnsafe(ref dRefBeg);
                        break;
                    }
                }
            }

            // Vector256 implementation. It reads the first and the last full vector up front, so
            // `remainder` must be at least Vector256<T>.Count on entry (the visible call site
            // checks this). The work is split in three parts:
            //   1. `beg` and `end` (results for the first and the last full vector) are computed
            //      before anything is written.
            //   2. When more than eight vectors remain, the data is pinned and processed eight
            //      vectors per iteration, optionally after aligning the destination pointer.
            //   3. A jump table handles the remaining full vectors, then stores `end` and `beg`.
            static void Vectorized256(ref T xRef, T y, ref T dRef, nuint remainder)
            {
                // Remember where the destination starts; dRef may be advanced below.
                ref T dRefBeg = ref dRef;

                // Preload the beginning and end so that overlapping accesses don't negatively impact the data

                // The scalar operand broadcast to every lane.
                Vector256<T> yVec = Vector256.Create(y);

                Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)),
                                                          yVec);
                Vector256<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count)),
                                                          yVec);

                if (remainder > (uint)(Vector256<T>.Count * 8))
                {
                    // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

                    fixed (T* px = &xRef)
                    fixed (T* pd = &dRef)
                    {
                        T* xPtr = px;
                        T* dPtr = pd;

                        // We need to ensure the underlying data can be aligned and only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.

                        bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0;

                        if (canAlign)
                        {
                            // Compute by how many elements we're misaligned and adjust the pointers accordingly
                            //
                            // Noting that we are only actually aligning dPtr. This is because unaligned stores
                            // are more expensive than unaligned loads and aligning both is significantly more
                            // complex.

                            // When dPtr is already vector-aligned this evaluates to a full Count.
                            // The skipped elements are still written, by the final store of `beg`
                            // at dRefBeg.
                            nuint misalignment = ((uint)sizeof(Vector256<T>) - ((nuint)dPtr % (uint)sizeof(Vector256<T>))) / (uint)sizeof(T);

                            xPtr += misalignment;
                            dPtr += misalignment;

                            Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256<T>)) == 0);

                            remainder -= misalignment;
                        }

                        Vector256<T> vector1;
                        Vector256<T> vector2;
                        Vector256<T> vector3;
                        Vector256<T> vector4;

                        if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign)
                        {
                            // This loop stores the data non-temporally, which benefits us when there
                            // is a large amount of data involved as it avoids polluting the cache.

                            while (remainder >= (uint)(Vector256<T>.Count * 8))
                            {
                                // We load, process, and store the first four vectors

                                vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 0))),
                                                                 yVec);
                                vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 1))),
                                                                 yVec);
                                vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 2))),
                                                                 yVec);
                                vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 3))),
                                                                 yVec);

                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 0));
                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 1));
                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 2));
                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 3));

                                // We load, process, and store the next four vectors

                                vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 4))),
                                                                 yVec);
                                vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 5))),
                                                                 yVec);
                                vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 6))),
                                                                 yVec);
                                vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 7))),
                                                                 yVec);

                                vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 4));
                                vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 5));
                                vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 6));
                                vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256<T>.Count * 7));

                                // We adjust the source and destination references, then update
                                // the count of remaining elements to process.

                                xPtr += (uint)(Vector256<T>.Count * 8);
                                dPtr += (uint)(Vector256<T>.Count * 8);

                                remainder -= (uint)(Vector256<T>.Count * 8);
                            }
                        }
                        else
                        {
                            // Same eight-vector loop using regular stores; this is also the path
                            // taken when the destination could not be aligned.
                            while (remainder >= (uint)(Vector256<T>.Count * 8))
                            {
                                // We load, process, and store the first four vectors

                                vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 0))),
                                                                 yVec);
                                vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 1))),
                                                                 yVec);
                                vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 2))),
                                                                 yVec);
                                vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 3))),
                                                                 yVec);

                                vector1.Store(dPtr + (uint)(Vector256<T>.Count * 0));
                                vector2.Store(dPtr + (uint)(Vector256<T>.Count * 1));
                                vector3.Store(dPtr + (uint)(Vector256<T>.Count * 2));
                                vector4.Store(dPtr + (uint)(Vector256<T>.Count * 3));

                                // We load, process, and store the next four vectors

                                vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 4))),
                                                                 yVec);
                                vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 5))),
                                                                 yVec);
                                vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 6))),
                                                                 yVec);
                                vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256<T>.Count * 7))),
                                                                 yVec);

                                vector1.Store(dPtr + (uint)(Vector256<T>.Count * 4));
                                vector2.Store(dPtr + (uint)(Vector256<T>.Count * 5));
                                vector3.Store(dPtr + (uint)(Vector256<T>.Count * 6));
                                vector4.Store(dPtr + (uint)(Vector256<T>.Count * 7));

                                // We adjust the source and destination references, then update
                                // the count of remaining elements to process.

                                xPtr += (uint)(Vector256<T>.Count * 8);
                                dPtr += (uint)(Vector256<T>.Count * 8);

                                remainder -= (uint)(Vector256<T>.Count * 8);
                            }
                        }

                        // Adjusting the refs here allows us to avoid pinning for very small inputs

                        xRef = ref *xPtr;
                        dRef = ref *dPtr;
                    }
                }

                // Process the remaining [Count, Count * 8] elements via a jump table
                //
                // Unless the original length was an exact multiple of Count, then we'll
                // end up reprocessing a couple elements in case 1 for end. We'll also
                // potentially reprocess a few elements in case 0 for beg, to handle any
                // data before the first aligned address.

                // endIndex keeps the exact number of elements left; remainder is rounded up to
                // the next multiple of Count so that it can index the jump table.
                nuint endIndex = remainder;
                remainder = (remainder + (uint)(Vector256<T>.Count - 1)) & (nuint)(-Vector256<T>.Count);

                switch (remainder / (uint)Vector256<T>.Count)
                {
                    case 8:
                    {
                        Vector256<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 8))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 8));
                        goto case 7;
                    }

                    case 7:
                    {
                        Vector256<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 7))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 7));
                        goto case 6;
                    }

                    case 6:
                    {
                        Vector256<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 6))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 6));
                        goto case 5;
                    }

                    case 5:
                    {
                        Vector256<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 5))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 5));
                        goto case 4;
                    }

                    case 4:
                    {
                        Vector256<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 4))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 4));
                        goto case 3;
                    }

                    case 3:
                    {
                        Vector256<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 3))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 3));
                        goto case 2;
                    }

                    case 2:
                    {
                        Vector256<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<T>.Count * 2))),
                                                                     yVec);
                        vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<T>.Count * 2));
                        goto case 1;
                    }

                    case 1:
                    {
                        // Store the last block, which includes any elements that wouldn't fill a full vector
                        end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256<T>.Count);
                        goto case 0;
                    }

                    case 0:
                    {
                        // Store the first block, which includes any elements preceding the first aligned block
                        beg.StoreUnsafe(ref dRefBeg);
                        break;
                    }
                }
            }

            static void Vectorized512(ref T xRef, T y, ref T dRef, nuint remainder)
            {
                ref T dRefBeg = ref dRef;

                // Preload the beginning and end so that overlapping accesses don't negatively impact the data

                Vector512<T> yVec = Vector512.Create(y);

                Vector512<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)),
                                                          yVec);
                Vector512<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512<T>.Count)),
                                                          yVec);

                if (remainder > (uint)(Vector512<T>.Count * 8))
                {
                    // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
                    // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

                    fixed (T* px = &xRef)
                    fixed (T* pd = &dRef)
                    {
                        T* xPtr = px;
                        T* dPtr = pd;

                        // We need to the ensure the underlying data can be aligned and only align
                        // it if it can. It is possible we have an unaligned ref, in which case we
                        // can never achieve the required SIMD alignment.
bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; if (canAlign) { // Compute by how many elements we're misaligned and adjust the pointers accordingly // // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. nuint misalignment = ((uint)sizeof(Vector512<T>) - ((nuint)dPtr % (uint)sizeof(Vector512<T>))) / (uint)sizeof(T); xPtr += misalignment; dPtr += misalignment; Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512<T>)) == 0); remainder -= misalignment; } Vector512<T> vector1; Vector512<T> vector2; Vector512<T> vector3; Vector512<T> vector4; if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. while (remainder >= (uint)(Vector512<T>.Count * 8)) { // We load, process, and store the first four vectors vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 0))), yVec); vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 1))), yVec); vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 2))), yVec); vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 3))), yVec); vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 0)); vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 1)); vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 2)); vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 3)); // We load, process, and store the next four vectors vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 4))), yVec); vector2 = 
TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 5))), yVec); vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 6))), yVec); vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 7))), yVec); vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 4)); vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 5)); vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 6)); vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512<T>.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. xPtr += (uint)(Vector512<T>.Count * 8); dPtr += (uint)(Vector512<T>.Count * 8); remainder -= (uint)(Vector512<T>.Count * 8); } } else { while (remainder >= (uint)(Vector512<T>.Count * 8)) { // We load, process, and store the first four vectors vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 0))), yVec); vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 1))), yVec); vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 2))), yVec); vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 3))), yVec); vector1.Store(dPtr + (uint)(Vector512<T>.Count * 0)); vector2.Store(dPtr + (uint)(Vector512<T>.Count * 1)); vector3.Store(dPtr + (uint)(Vector512<T>.Count * 2)); vector4.Store(dPtr + (uint)(Vector512<T>.Count * 3)); // We load, process, and store the next four vectors vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 4))), yVec); vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 5))), yVec); 
vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 6))), yVec); vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512<T>.Count * 7))), yVec); vector1.Store(dPtr + (uint)(Vector512<T>.Count * 4)); vector2.Store(dPtr + (uint)(Vector512<T>.Count * 5)); vector3.Store(dPtr + (uint)(Vector512<T>.Count * 6)); vector4.Store(dPtr + (uint)(Vector512<T>.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. xPtr += (uint)(Vector512<T>.Count * 8); dPtr += (uint)(Vector512<T>.Count * 8); remainder -= (uint)(Vector512<T>.Count * 8); } } // Adjusting the refs here allows us to avoid pinning for very small inputs xRef = ref *xPtr; dRef = ref *dPtr; } } // Process the remaining [Count, Count * 8] elements via a jump table // // Unless the original length was an exact multiple of Count, then we'll // end up reprocessing a couple elements in case 1 for end. We'll also // potentially reprocess a few elements in case 0 for beg, to handle any // data before the first aligned address. 
nuint endIndex = remainder; remainder = (remainder + (uint)(Vector512<T>.Count - 1)) & (nuint)(-Vector512<T>.Count); switch (remainder / (uint)Vector512<T>.Count) { case 8: { Vector512<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 8))), yVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 8)); goto case 7; } case 7: { Vector512<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 7))), yVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 7)); goto case 6; } case 6: { Vector512<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 6))), yVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 6)); goto case 5; } case 5: { Vector512<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 5))), yVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 5)); goto case 4; } case 4: { Vector512<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 4))), yVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 4)); goto case 3; } case 3: { Vector512<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 3))), yVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 3)); goto case 2; } case 2: { Vector512<T> vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512<T>.Count * 2))), yVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512<T>.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any 
elements that wouldn't fill a full vector end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512<T>.Count); goto case 0; } case 0: { // Store the first block, which includes any elements preceding the first aligned block beg.StoreUnsafe(ref dRefBeg); break; } } } [MethodImpl(MethodImplOptions.AggressiveInlining)] static void VectorizedSmall(ref T xRef, T y, ref T dRef, nuint remainder) { if (sizeof(T) == 1) { VectorizedSmall1(ref xRef, y, ref dRef, remainder); } else if (sizeof(T) == 2) { VectorizedSmall2(ref xRef, y, ref dRef, remainder); } else if (sizeof(T) == 4) { VectorizedSmall4(ref xRef, y, ref dRef, remainder); } else { Debug.Assert(sizeof(T) == 8); VectorizedSmall8(ref xRef, y, ref dRef, remainder); } } [MethodImpl(MethodImplOptions.AggressiveInlining)] static void VectorizedSmall1(ref T xRef, T y, ref T dRef, nuint remainder) { Debug.Assert(sizeof(T) == 1); switch (remainder) { // Two Vector256's worth of data, with at least one element overlapping. case 63: case 62: case 61: case 60: case 59: case 58: case 57: case 56: case 55: case 54: case 53: case 52: case 51: case 50: case 49: case 48: case 47: case 46: case 45: case 44: case 43: case 42: case 41: case 40: case 39: case 38: case 37: case 36: case 35: case 34: case 33: { Debug.Assert(Vector256.IsHardwareAccelerated); Vector256<T> yVec = Vector256.Create(y); Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), yVec); Vector256<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count)), yVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count); break; } // One Vector256's worth of data. 
case 32: { Debug.Assert(Vector256.IsHardwareAccelerated); Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), Vector256.Create(y)); beg.StoreUnsafe(ref dRef); break; } // Two Vector128's worth of data, with at least one element overlapping. case 31: case 30: case 29: case 28: case 27: case 26: case 25: case 24: case 23: case 22: case 21: case 20: case 19: case 18: case 17: { Debug.Assert(Vector128.IsHardwareAccelerated); Vector128<T> yVec = Vector128.Create(y); Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), yVec); Vector128<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count)), yVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count); break; } // One Vector128's worth of data. case 16: { Debug.Assert(Vector128.IsHardwareAccelerated); Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), Vector128.Create(y)); beg.StoreUnsafe(ref dRef); break; } // Cases that are smaller than a single vector. No SIMD; just jump to the length and fall through each // case to unroll the whole processing. 
case 15: Unsafe.Add(ref dRef, 14) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 14)), y); goto case 14; case 14: Unsafe.Add(ref dRef, 13) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 13)), y); goto case 13; case 13: Unsafe.Add(ref dRef, 12) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 12)), y); goto case 12; case 12: Unsafe.Add(ref dRef, 11) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 11)), y); goto case 11; case 11: Unsafe.Add(ref dRef, 10) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 10)), y); goto case 10; case 10: Unsafe.Add(ref dRef, 9) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 9)), y); goto case 9; case 9: Unsafe.Add(ref dRef, 8) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 8)), y); goto case 8; case 8: Unsafe.Add(ref dRef, 7) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 7)), y); goto case 7; case 7: Unsafe.Add(ref dRef, 6) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 6)), y); goto case 6; case 6: Unsafe.Add(ref dRef, 5) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 5)), y); goto case 5; case 5: Unsafe.Add(ref dRef, 4) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 4)), y); goto case 4; case 4: Unsafe.Add(ref dRef, 3) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 3)), y); goto case 3; case 3: Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), y); goto case 2; case 2: Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), y); goto case 1; case 1: dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); goto case 0; case 0: break; } } [MethodImpl(MethodImplOptions.AggressiveInlining)] static void VectorizedSmall2(ref T xRef, T y, ref T 
dRef, nuint remainder) { Debug.Assert(sizeof(T) == 2); switch (remainder) { // Two Vector256's worth of data, with at least one element overlapping. case 31: case 30: case 29: case 28: case 27: case 26: case 25: case 24: case 23: case 22: case 21: case 20: case 19: case 18: case 17: { Debug.Assert(Vector256.IsHardwareAccelerated); Vector256<T> yVec = Vector256.Create(y); Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), yVec); Vector256<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count)), yVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count); break; } // One Vector256's worth of data. case 16: { Debug.Assert(Vector256.IsHardwareAccelerated); Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), Vector256.Create(y)); beg.StoreUnsafe(ref dRef); break; } // Two Vector128's worth of data, with at least one element overlapping. case 15: case 14: case 13: case 12: case 11: case 10: case 9: { Debug.Assert(Vector128.IsHardwareAccelerated); Vector128<T> yVec = Vector128.Create(y); Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), yVec); Vector128<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count)), yVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count); break; } // One Vector128's worth of data. case 8: { Debug.Assert(Vector128.IsHardwareAccelerated); Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), Vector128.Create(y)); beg.StoreUnsafe(ref dRef); break; } // Cases that are smaller than a single vector. No SIMD; just jump to the length and fall through each // case to unroll the whole processing. 
case 7: Unsafe.Add(ref dRef, 6) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 6)), y); goto case 6; case 6: Unsafe.Add(ref dRef, 5) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 5)), y); goto case 5; case 5: Unsafe.Add(ref dRef, 4) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 4)), y); goto case 4; case 4: Unsafe.Add(ref dRef, 3) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 3)), y); goto case 3; case 3: Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), y); goto case 2; case 2: Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), y); goto case 1; case 1: dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); goto case 0; case 0: break; } } [MethodImpl(MethodImplOptions.AggressiveInlining)] static void VectorizedSmall4(ref T xRef, T y, ref T dRef, nuint remainder) { Debug.Assert(sizeof(T) == 4); switch (remainder) { case 15: case 14: case 13: case 12: case 11: case 10: case 9: { Debug.Assert(Vector256.IsHardwareAccelerated); Vector256<T> yVec = Vector256.Create(y); Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), yVec); Vector256<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count)), yVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count); break; } case 8: { Debug.Assert(Vector256.IsHardwareAccelerated); Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), Vector256.Create(y)); beg.StoreUnsafe(ref dRef); break; } case 7: case 6: case 5: { Debug.Assert(Vector128.IsHardwareAccelerated); Vector128<T> yVec = Vector128.Create(y); Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), yVec); Vector128<T> 
end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128<T>.Count)), yVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count); break; } case 4: { Debug.Assert(Vector128.IsHardwareAccelerated); Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), Vector128.Create(y)); beg.StoreUnsafe(ref dRef); break; } case 3: { Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), y); goto case 2; } case 2: { Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), y); goto case 1; } case 1: { dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); goto case 0; } case 0: { break; } } } [MethodImpl(MethodImplOptions.AggressiveInlining)] static void VectorizedSmall8(ref T xRef, T y, ref T dRef, nuint remainder) { Debug.Assert(sizeof(T) == 8); switch (remainder) { case 7: case 6: case 5: { Debug.Assert(Vector256.IsHardwareAccelerated); Vector256<T> yVec = Vector256.Create(y); Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), yVec); Vector256<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256<T>.Count)), yVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector256<T>.Count); break; } case 4: { Debug.Assert(Vector256.IsHardwareAccelerated); Vector256<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), Vector256.Create(y)); beg.StoreUnsafe(ref dRef); break; } case 3: { Debug.Assert(Vector128.IsHardwareAccelerated); Vector128<T> yVec = Vector128.Create(y); Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), yVec); Vector128<T> end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, 
remainder - (uint)Vector128<T>.Count)), yVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128<T>.Count); break; } case 2: { Debug.Assert(Vector128.IsHardwareAccelerated); Vector128<T> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), Vector128.Create(y)); beg.StoreUnsafe(ref dRef); break; } case 1: { dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); goto case 0; } case 0: { break; } } } } /// <summary>Aggregates all of the elements in the <paramref name="x"/> into a single value.</summary> /// <typeparam name="T">The element type.</typeparam> /// <typeparam name="TAggregate">Specifies the operation to be performed on each pair of values.</typeparam> [MethodImpl(MethodImplOptions.AggressiveInlining)] private static T HorizontalAggregate<T, TAggregate>(Vector256<T> x) where TAggregate : struct, IBinaryOperator<T> => HorizontalAggregate<T, TAggregate>(TAggregate.Invoke(x.GetLower(), x.GetUpper())); /// <summary>Aggregates all of the elements in the <paramref name="x"/> into a single value.</summary> /// <typeparam name="T">The element type.</typeparam> /// <typeparam name="TAggregate">Specifies the operation to be performed on each pair of values.</typeparam> [MethodImpl(MethodImplOptions.AggressiveInlining)] private static T HorizontalAggregate<T, TAggregate>(Vector512<T> x) where TAggregate : struct, IBinaryOperator<T> => HorizontalAggregate<T, TAggregate>(TAggregate.Invoke(x.GetLower(), x.GetUpper())); /// <summary>Aggregates all of the elements in the <paramref name="x"/> into a single value.</summary> /// <typeparam name="T">The element type.</typeparam> /// <typeparam name="TAggregate">Specifies the operation to be performed on each pair of values.</typeparam> [MethodImpl(MethodImplOptions.AggressiveInlining)] private static T HorizontalAggregate<T, TAggregate>(Vector128<T> x) where TAggregate : struct, IBinaryOperator<T> { // We need to do log2(count) operations to compute 
the total sum if (sizeof(T) == 1) { x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsByte(), Vector128.Create((byte)8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7)).As<byte, T>()); x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsByte(), Vector128.Create((byte)4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>()); x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsByte(), Vector128.Create((byte)2, 3, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>()); x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsByte(), Vector128.Create((byte)1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>()); } else if (sizeof(T) == 2) { x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt16(), Vector128.Create(4, 5, 6, 7, 0, 1, 2, 3)).As<short, T>()); x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt16(), Vector128.Create(2, 3, 0, 1, 4, 5, 6, 7)).As<short, T>()); x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt16(), Vector128.Create(1, 0, 2, 3, 4, 5, 6, 7)).As<short, T>()); } else if (sizeof(T) == 4) { x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt32(), Vector128.Create(2, 3, 0, 1)).As<int, T>()); x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt32(), Vector128.Create(1, 0, 3, 2)).As<int, T>()); } else { Debug.Assert(sizeof(T) == 8); x = TAggregate.Invoke(x, Vector128.Shuffle(x.AsInt64(), Vector128.Create(1, 0)).As<long, T>()); } return x.ToScalar(); } private readonly struct SwappedBinaryOperator<TOperator, T> : IBinaryOperator<T> where TOperator : IBinaryOperator<T> { public static bool Vectorizable => TOperator.Vectorizable; public static T Invoke(T x, T y) => TOperator.Invoke(y, x); public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y) => TOperator.Invoke(y, x); public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y) => TOperator.Invoke(y, x); public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y) => TOperator.Invoke(y, x); } } } |