// File: System\Numerics\Tensors\netcore\TensorPrimitives.IndexOfMax.cs
// Web Access
// Project: src\src\libraries\System.Numerics.Tensors\src\System.Numerics.Tensors.csproj (System.Numerics.Tensors)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
 
namespace System.Numerics.Tensors
{
    public static partial class TensorPrimitives
    {
        /// <summary>Searches for the index of the largest number in the specified tensor.</summary>
        /// <param name="x">The tensor, represented as a span.</param>
        /// <returns>The index of the maximum element in <paramref name="x"/>, or -1 if <paramref name="x"/> is empty.</returns>
        /// <remarks>
        /// <para>
        /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to NaN
        /// is present, the index of the first is returned. Positive 0 is considered greater than negative 0.
        /// </para>
        /// <para>
        /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
        /// operating systems or architectures.
        /// </para>
        /// </remarks>
        public static int IndexOfMax<T>(ReadOnlySpan<T> x)
            where T : INumber<T> =>
            IndexOfMinMaxCore<T, IndexOfMaxOperator<T>>(x);

        /// <summary>
        /// Index-of-max operator: merges a running (value, index) candidate with a new one and keeps whichever
        /// represents the maximum.
        /// </summary>
        /// <remarks>
        /// Equal values are resolved in favor of the smaller index. For <see cref="float"/> and <see cref="double"/>, an
        /// equality between values of different sign (i.e. +0 vs -0) is resolved in favor of the non-negative one.
        /// NaN is not handled here: <c>IndexOfMinMaxCore</c> returns before invoking the operator if it sees a NaN.
        /// </remarks>
        internal readonly struct IndexOfMaxOperator<T> : IIndexOfOperator<T> where T : INumber<T>
        {
            /// <summary>
            /// Lane-wise merge: lanes where <paramref name="result"/> wins keep their value and index; all other lanes
            /// take <paramref name="current"/> and <paramref name="currentIndex"/>.
            /// </summary>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static void Invoke(ref Vector128<T> result, Vector128<T> current, ref Vector128<T> resultIndex, Vector128<T> currentIndex)
            {
                // Lanes where the running result is strictly larger always keep it.
                Vector128<T> useResult = Vector128.GreaterThan(result, current);
                Vector128<T> equalMask = Vector128.Equals(result, current);

                // Equal lanes need a tie-break; skip the extra work when there are none.
                if (equalMask != Vector128<T>.Zero)
                {
                    Vector128<T> lessThanIndexMask = IndexLessThan(resultIndex, currentIndex);
                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current));
                        Vector128<T> currentNegative = IsNegative(current);
                        Vector128<T> sameSign = Vector128.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As<int, T>();
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative);
                    }
                    else
                    {
                        useResult |= equalMask & lessThanIndexMask;
                    }
                }

                result = ElementWiseSelect(useResult, result, current);
                resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex);
            }

            /// <summary>
            /// Lane-wise merge: lanes where <paramref name="result"/> wins keep their value and index; all other lanes
            /// take <paramref name="current"/> and <paramref name="currentIndex"/>.
            /// </summary>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static void Invoke(ref Vector256<T> result, Vector256<T> current, ref Vector256<T> resultIndex, Vector256<T> currentIndex)
            {
                // Lanes where the running result is strictly larger always keep it.
                Vector256<T> useResult = Vector256.GreaterThan(result, current);
                Vector256<T> equalMask = Vector256.Equals(result, current);

                // Equal lanes need a tie-break; skip the extra work when there are none.
                if (equalMask != Vector256<T>.Zero)
                {
                    Vector256<T> lessThanIndexMask = IndexLessThan(resultIndex, currentIndex);
                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current));
                        Vector256<T> currentNegative = IsNegative(current);
                        Vector256<T> sameSign = Vector256.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As<int, T>();
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative);
                    }
                    else
                    {
                        useResult |= equalMask & lessThanIndexMask;
                    }
                }

                result = ElementWiseSelect(useResult, result, current);
                resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex);
            }

            /// <summary>
            /// Lane-wise merge: lanes where <paramref name="result"/> wins keep their value and index; all other lanes
            /// take <paramref name="current"/> and <paramref name="currentIndex"/>.
            /// </summary>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static void Invoke(ref Vector512<T> result, Vector512<T> current, ref Vector512<T> resultIndex, Vector512<T> currentIndex)
            {
                // Lanes where the running result is strictly larger always keep it.
                Vector512<T> useResult = Vector512.GreaterThan(result, current);
                Vector512<T> equalMask = Vector512.Equals(result, current);

                // Equal lanes need a tie-break; skip the extra work when there are none.
                if (equalMask != Vector512<T>.Zero)
                {
                    Vector512<T> lessThanIndexMask = IndexLessThan(resultIndex, currentIndex);
                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        // bool useResult = equal && ((IsNegative(result) == IsNegative(current)) ? (resultIndex < currentIndex) : IsNegative(current));
                        Vector512<T> currentNegative = IsNegative(current);
                        Vector512<T> sameSign = Vector512.Equals(IsNegative(result).AsInt32(), currentNegative.AsInt32()).As<int, T>();
                        useResult |= equalMask & ElementWiseSelect(sameSign, lessThanIndexMask, currentNegative);
                    }
                    else
                    {
                        useResult |= equalMask & lessThanIndexMask;
                    }
                }

                result = ElementWiseSelect(useResult, result, current);
                resultIndex = ElementWiseSelect(useResult, resultIndex, currentIndex);
            }

            /// <summary>
            /// Scalar merge: replaces <paramref name="result"/> with <paramref name="current"/> when the latter wins.
            /// </summary>
            /// <returns>
            /// <paramref name="currentIndex"/> when <paramref name="current"/> wins; otherwise <paramref name="resultIndex"/>.
            /// </returns>
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            public static int Invoke(ref T result, T current, int resultIndex, int currentIndex)
            {
                if (result == current)
                {
                    // Same sign: prefer the smaller index. Different sign (+0 vs -0): prefer the non-negative value.
                    bool resultNegative = IsNegative(result);
                    if ((resultNegative == IsNegative(current)) ? (currentIndex < resultIndex) : resultNegative)
                    {
                        result = current;
                        return currentIndex;
                    }
                }
                else if (current > result)
                {
                    result = current;
                    return currentIndex;
                }

                return resultIndex;
            }
        }

        /// <summary>
        /// Shared driver for the index-of-min/max searches. Walks <paramref name="x"/> (vectorized where possible) and uses
        /// <typeparamref name="TIndexOfMinMax"/> to decide which value and index to keep.
        /// </summary>
        /// <param name="x">The values to search.</param>
        /// <returns>
        /// -1 if <paramref name="x"/> is empty. Otherwise, the index of the first NaN if one is present, or else the index
        /// chosen by <typeparamref name="TIndexOfMinMax"/>.
        /// </returns>
        private static unsafe int IndexOfMinMaxCore<T, TIndexOfMinMax>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TIndexOfMinMax : struct, IIndexOfOperator<T>
        {
            if (x.IsEmpty)
            {
                return -1;
            }

            // This matches the IEEE 754:2019 `maximum`/`minimum` functions.
            // It propagates NaN inputs back to the caller and
            // otherwise returns the index of the greater of the inputs.
            // It treats +0 as greater than -0 as per the specification.

            // The vector paths track each lane's element index in a vector whose lanes have the same width as T.
            // For 1- and 2-byte element types those index lanes overflow (byte lanes wrap at 256, and 16-bit lanes are
            // compared as signed), so a long enough input would produce a wrong index. Only vectorize when every
            // index 0..x.Length-1 fits in the signed lane type. Longer inputs use the scalar loop, which tracks int indices.
            // NOTE(review): the 1-byte bound is conservative. The Vector256/Vector512 IndexLessThan overloads compare byte
            // lanes as unsigned, so 256 would suffice if IndexOfFinalAggregate also reads byte index lanes as unsigned.
            // Confirm before relaxing.
            bool indicesFitInLanes =
                sizeof(T) == 1 ? x.Length <= sbyte.MaxValue + 1 :
                sizeof(T) == 2 ? x.Length <= short.MaxValue + 1 :
                true;

            if (indicesFitInLanes && Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && x.Length >= Vector512<T>.Count)
            {
                Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8);

                [MethodImpl(MethodImplOptions.AggressiveInlining)]
                static Vector512<T> CreateVector512T(int i) =>
                    sizeof(T) == sizeof(long) ? Vector512.Create((long)i).As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector512.Create(i).As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector512.Create((short)i).As<short, T>() :
                    Vector512.Create((byte)i).As<byte, T>();

                ref T xRef = ref MemoryMarshal.GetReference(x);
                Vector512<T> resultIndex =
#if NET9_0_OR_GREATER
                    sizeof(T) == sizeof(long) ? Vector512<long>.Indices.As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector512<int>.Indices.As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector512<short>.Indices.As<short, T>() :
                    Vector512<byte>.Indices.As<byte, T>();
#else
                    sizeof(T) == sizeof(long) ? Vector512.Create(0L, 1, 2, 3, 4, 5, 6, 7).As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector512.Create(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15).As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector512.Create(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31).As<short, T>() :
                    Vector512.Create((byte)0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63).As<byte, T>();
#endif
                Vector512<T> currentIndex = resultIndex;
                Vector512<T> increment = CreateVector512T(Vector512<T>.Count);

                // Load the first vector as the initial set of results, and bail immediately
                // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
                Vector512<T> result = Vector512.LoadUnsafe(ref xRef);
                Vector512<T> current;

                Vector512<T> nanMask;
                if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                {
                    nanMask = ~Vector512.Equals(result, result);
                    if (nanMask != Vector512<T>.Zero)
                    {
                        return IndexOfFirstMatch(nanMask);
                    }
                }

                int oneVectorFromEnd = x.Length - Vector512<T>.Count;
                int i = Vector512<T>.Count;

                // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
                while (i <= oneVectorFromEnd)
                {
                    // Load the next vector, and early exit on NaN.
                    current = Vector512.LoadUnsafe(ref xRef, (uint)i);
                    currentIndex += increment;

                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        nanMask = ~Vector512.Equals(current, current);
                        if (nanMask != Vector512<T>.Zero)
                        {
                            return i + IndexOfFirstMatch(nanMask);
                        }
                    }

                    TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex);

                    i += Vector512<T>.Count;
                }

                // If any elements remain, handle them in one final vector. That load overlaps elements already
                // processed, so the lane indices advance by the same (x.Length - i) shift to stay aligned with it.
                if (i != x.Length)
                {
                    current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512<T>.Count));
                    currentIndex += CreateVector512T(x.Length - i);

                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        nanMask = ~Vector512.Equals(current, current);
                        if (nanMask != Vector512<T>.Zero)
                        {
                            int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask);
                            return typeof(T) == typeof(double) ?
                                (int)(long)(object)currentIndex.As<T, long>()[indexInVectorOfFirstMatch] :
                                (int)(object)currentIndex.As<T, int>()[indexInVectorOfFirstMatch];
                        }
                    }

                    TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex);
                }

                // Aggregate the lanes in the vector to create the final scalar result.
                return IndexOfFinalAggregate<T, TIndexOfMinMax>(result, resultIndex);
            }

            if (indicesFitInLanes && Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && x.Length >= Vector256<T>.Count)
            {
                Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8);

                [MethodImpl(MethodImplOptions.AggressiveInlining)]
                static Vector256<T> CreateVector256T(int i) =>
                    sizeof(T) == sizeof(long) ? Vector256.Create((long)i).As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector256.Create(i).As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector256.Create((short)i).As<short, T>() :
                    Vector256.Create((byte)i).As<byte, T>();

                ref T xRef = ref MemoryMarshal.GetReference(x);
                Vector256<T> resultIndex =
#if NET9_0_OR_GREATER
                    sizeof(T) == sizeof(long) ? Vector256<long>.Indices.As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector256<int>.Indices.As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector256<short>.Indices.As<short, T>() :
                    Vector256<byte>.Indices.As<byte, T>();
#else
                    sizeof(T) == sizeof(long) ? Vector256.Create(0L, 1, 2, 3).As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7).As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15).As<short, T>() :
                    Vector256.Create((byte)0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31).As<byte, T>();
#endif
                Vector256<T> currentIndex = resultIndex;
                Vector256<T> increment = CreateVector256T(Vector256<T>.Count);

                // Load the first vector as the initial set of results, and bail immediately
                // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
                Vector256<T> result = Vector256.LoadUnsafe(ref xRef);
                Vector256<T> current;

                Vector256<T> nanMask;
                if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                {
                    nanMask = ~Vector256.Equals(result, result);
                    if (nanMask != Vector256<T>.Zero)
                    {
                        return IndexOfFirstMatch(nanMask);
                    }
                }

                int oneVectorFromEnd = x.Length - Vector256<T>.Count;
                int i = Vector256<T>.Count;

                // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
                while (i <= oneVectorFromEnd)
                {
                    // Load the next vector, and early exit on NaN.
                    current = Vector256.LoadUnsafe(ref xRef, (uint)i);
                    currentIndex += increment;

                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        nanMask = ~Vector256.Equals(current, current);
                        if (nanMask != Vector256<T>.Zero)
                        {
                            return i + IndexOfFirstMatch(nanMask);
                        }
                    }

                    TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex);

                    i += Vector256<T>.Count;
                }

                // If any elements remain, handle them in one final vector. That load overlaps elements already
                // processed, so the lane indices advance by the same (x.Length - i) shift to stay aligned with it.
                if (i != x.Length)
                {
                    current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256<T>.Count));
                    currentIndex += CreateVector256T(x.Length - i);

                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        nanMask = ~Vector256.Equals(current, current);
                        if (nanMask != Vector256<T>.Zero)
                        {
                            int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask);
                            return typeof(T) == typeof(double) ?
                                (int)(long)(object)currentIndex.As<T, long>()[indexInVectorOfFirstMatch] :
                                (int)(object)currentIndex.As<T, int>()[indexInVectorOfFirstMatch];
                        }
                    }

                    TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex);
                }

                // Aggregate the lanes in the vector to create the final scalar result.
                return IndexOfFinalAggregate<T, TIndexOfMinMax>(result, resultIndex);
            }

            if (indicesFitInLanes && Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && x.Length >= Vector128<T>.Count)
            {
                Debug.Assert(sizeof(T) is 1 or 2 or 4 or 8);

                [MethodImpl(MethodImplOptions.AggressiveInlining)]
                static Vector128<T> CreateVector128T(int i) =>
                    sizeof(T) == sizeof(long) ? Vector128.Create((long)i).As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector128.Create(i).As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector128.Create((short)i).As<short, T>() :
                    Vector128.Create((byte)i).As<byte, T>();

                ref T xRef = ref MemoryMarshal.GetReference(x);
                Vector128<T> resultIndex =
#if NET9_0_OR_GREATER
                    sizeof(T) == sizeof(long) ? Vector128<long>.Indices.As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector128<int>.Indices.As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector128<short>.Indices.As<short, T>() :
                    Vector128<byte>.Indices.As<byte, T>();
#else
                    sizeof(T) == sizeof(long) ? Vector128.Create(0L, 1).As<long, T>() :
                    sizeof(T) == sizeof(int) ? Vector128.Create(0, 1, 2, 3).As<int, T>() :
                    sizeof(T) == sizeof(short) ? Vector128.Create(0, 1, 2, 3, 4, 5, 6, 7).As<short, T>() :
                    Vector128.Create((byte)0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15).As<byte, T>();
#endif
                Vector128<T> currentIndex = resultIndex;
                Vector128<T> increment = CreateVector128T(Vector128<T>.Count);

                // Load the first vector as the initial set of results, and bail immediately
                // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
                Vector128<T> result = Vector128.LoadUnsafe(ref xRef);
                Vector128<T> current;

                Vector128<T> nanMask;
                if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                {
                    nanMask = ~Vector128.Equals(result, result);
                    if (nanMask != Vector128<T>.Zero)
                    {
                        return IndexOfFirstMatch(nanMask);
                    }
                }

                int oneVectorFromEnd = x.Length - Vector128<T>.Count;
                int i = Vector128<T>.Count;

                // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
                while (i <= oneVectorFromEnd)
                {
                    // Load the next vector, and early exit on NaN.
                    current = Vector128.LoadUnsafe(ref xRef, (uint)i);
                    currentIndex += increment;

                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        nanMask = ~Vector128.Equals(current, current);
                        if (nanMask != Vector128<T>.Zero)
                        {
                            return i + IndexOfFirstMatch(nanMask);
                        }
                    }

                    TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex);

                    i += Vector128<T>.Count;
                }

                // If any elements remain, handle them in one final vector. That load overlaps elements already
                // processed, so the lane indices advance by the same (x.Length - i) shift to stay aligned with it.
                if (i != x.Length)
                {
                    current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128<T>.Count));
                    currentIndex += CreateVector128T(x.Length - i);

                    if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                    {
                        nanMask = ~Vector128.Equals(current, current);
                        if (nanMask != Vector128<T>.Zero)
                        {
                            int indexInVectorOfFirstMatch = IndexOfFirstMatch(nanMask);
                            return typeof(T) == typeof(double) ?
                                (int)(long)(object)currentIndex.As<T, long>()[indexInVectorOfFirstMatch] :
                                (int)(object)currentIndex.As<T, int>()[indexInVectorOfFirstMatch];
                        }
                    }

                    TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, currentIndex);
                }

                // Aggregate the lanes in the vector to create the final scalar result.
                return IndexOfFinalAggregate<T, TIndexOfMinMax>(result, resultIndex);
            }

            // Scalar path, used in three cases: vectorization is not supported, the input is too small to vectorize,
            // or the input is too long for its indices to fit in T-sized lanes (see indicesFitInLanes above).
            T curResult = x[0];
            int curIn = 0;
            if (T.IsNaN(curResult))
            {
                return curIn;
            }

            for (int i = 1; i < x.Length; i++)
            {
                T current = x[i];
                if (T.IsNaN(current))
                {
                    return i;
                }

                curIn = TIndexOfMinMax.Invoke(ref curResult, current, curIn, i);
            }

            return curIn;
        }

        /// <summary>Returns the number of the first (lowest) lane whose most significant bit is set in <paramref name="mask"/>.</summary>
        /// <remarks>Callers pass a non-zero mask. For an all-zero mask the result would be 32 (the width of the extracted bit mask).</remarks>
        private static int IndexOfFirstMatch<T>(Vector128<T> mask) =>
            BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());

        /// <summary>Returns the number of the first (lowest) lane whose most significant bit is set in <paramref name="mask"/>.</summary>
        /// <remarks>Callers pass a non-zero mask. For an all-zero mask the result would be 32 (the width of the extracted bit mask).</remarks>
        private static int IndexOfFirstMatch<T>(Vector256<T> mask) =>
            BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());

        /// <summary>Returns the number of the first (lowest) lane whose most significant bit is set in <paramref name="mask"/>.</summary>
        /// <remarks>Callers pass a non-zero mask. For an all-zero mask the result would be 64 (the width of the extracted bit mask).</remarks>
        private static int IndexOfFirstMatch<T>(Vector512<T> mask) =>
            BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());

        /// <summary>
        /// Lane-wise <c>indices1 &lt; indices2</c> for index vectors stored in T-sized lanes. The lanes are reinterpreted
        /// as integers of the same width: 64-, 32- and 16-bit lanes are compared as signed, 8-bit lanes as unsigned.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static unsafe Vector256<T> IndexLessThan<T>(Vector256<T> indices1, Vector256<T> indices2) =>
            sizeof(T) == sizeof(long) ? Vector256.LessThan(indices1.AsInt64(), indices2.AsInt64()).As<long, T>() :
            sizeof(T) == sizeof(int) ? Vector256.LessThan(indices1.AsInt32(), indices2.AsInt32()).As<int, T>() :
            sizeof(T) == sizeof(short) ? Vector256.LessThan(indices1.AsInt16(), indices2.AsInt16()).As<short, T>() :
            Vector256.LessThan(indices1.AsByte(), indices2.AsByte()).As<byte, T>();

        /// <summary>
        /// Lane-wise <c>indices1 &lt; indices2</c> for index vectors stored in T-sized lanes. The lanes are reinterpreted
        /// as integers of the same width: 64-, 32- and 16-bit lanes are compared as signed, 8-bit lanes as unsigned.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static unsafe Vector512<T> IndexLessThan<T>(Vector512<T> indices1, Vector512<T> indices2) =>
            sizeof(T) == sizeof(long) ? Vector512.LessThan(indices1.AsInt64(), indices2.AsInt64()).As<long, T>() :
            sizeof(T) == sizeof(int) ? Vector512.LessThan(indices1.AsInt32(), indices2.AsInt32()).As<int, T>() :
            sizeof(T) == sizeof(short) ? Vector512.LessThan(indices1.AsInt16(), indices2.AsInt16()).As<short, T>() :
            Vector512.LessThan(indices1.AsByte(), indices2.AsByte()).As<byte, T>();

        /// <summary>Gets whether the specified value is negative, as reported by <c>T.IsNegative</c>.</summary>
        /// <remarks>
        /// For IEEE floating-point types this checks the sign, so -0.0 is reported negative. The max operator relies on this
        /// to resolve an equality between +0 and -0.
        /// </remarks>
        private static bool IsNegative<T>(T f) where T : INumberBase<T> => T.IsNegative(f);

        /// <summary>
        /// Lane-wise select: returns <paramref name="left"/> where the <paramref name="mask"/> lane is set and
        /// <paramref name="right"/> otherwise.
        /// </summary>
        /// <remarks>
        /// Mask lanes are expected to be all-ones or all-zeros. The BlendVariable fast paths inspect only each lane's most
        /// significant bit, whereas the <see cref="Vector128.ConditionalSelect{T}"/> fallback selects bit by bit. The
        /// mask is inverted because BlendVariable picks its second operand where the mask bit is set.
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static unsafe Vector128<T> ElementWiseSelect<T>(Vector128<T> mask, Vector128<T> left, Vector128<T> right)
        {
            if (Sse41.IsSupported)
            {
                if (typeof(T) == typeof(float)) return Sse41.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As<float, T>();
                if (typeof(T) == typeof(double)) return Sse41.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As<double, T>();

                if (sizeof(T) == 1) return Sse41.BlendVariable(left.AsByte(), right.AsByte(), (~mask).AsByte()).As<byte, T>();
                if (sizeof(T) == 2) return Sse41.BlendVariable(left.AsUInt16(), right.AsUInt16(), (~mask).AsUInt16()).As<ushort, T>();
                if (sizeof(T) == 4) return Sse41.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
                if (sizeof(T) == 8) return Sse41.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
            }

            return Vector128.ConditionalSelect(mask, left, right);
        }

        /// <summary>
        /// Lane-wise select: returns <paramref name="left"/> where the <paramref name="mask"/> lane is set and
        /// <paramref name="right"/> otherwise.
        /// </summary>
        /// <remarks>
        /// Mask lanes are expected to be all-ones or all-zeros. The BlendVariable fast paths inspect only each lane's most
        /// significant bit, whereas the <see cref="Vector256.ConditionalSelect{T}"/> fallback selects bit by bit.
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static unsafe Vector256<T> ElementWiseSelect<T>(Vector256<T> mask, Vector256<T> left, Vector256<T> right)
        {
            if (Avx2.IsSupported)
            {
                if (typeof(T) == typeof(float)) return Avx2.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As<float, T>();
                if (typeof(T) == typeof(double)) return Avx2.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As<double, T>();

                if (sizeof(T) == 1) return Avx2.BlendVariable(left.AsByte(), right.AsByte(), (~mask).AsByte()).As<byte, T>();
                if (sizeof(T) == 2) return Avx2.BlendVariable(left.AsUInt16(), right.AsUInt16(), (~mask).AsUInt16()).As<ushort, T>();
                if (sizeof(T) == 4) return Avx2.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
                if (sizeof(T) == 8) return Avx2.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
            }

            return Vector256.ConditionalSelect(mask, left, right);
        }

        /// <summary>
        /// Lane-wise select: returns <paramref name="left"/> where the <paramref name="mask"/> lane is set and
        /// <paramref name="right"/> otherwise.
        /// </summary>
        /// <remarks>
        /// Mask lanes are expected to be all-ones or all-zeros. Only 32- and 64-bit lanes take the Avx512F BlendVariable
        /// path. Other widths use the bitwise <see cref="Vector512.ConditionalSelect{T}"/> fallback.
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static unsafe Vector512<T> ElementWiseSelect<T>(Vector512<T> mask, Vector512<T> left, Vector512<T> right)
        {
            if (Avx512F.IsSupported)
            {
                if (typeof(T) == typeof(float)) return Avx512F.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As<float, T>();
                if (typeof(T) == typeof(double)) return Avx512F.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As<double, T>();

                if (sizeof(T) == 4) return Avx512F.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
                if (sizeof(T) == 8) return Avx512F.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
            }

            return Vector512.ConditionalSelect(mask, left, right);
        }
    }
}