// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;

namespace System.Numerics.Tensors
{
    public static unsafe partial class TensorPrimitives
    {
        /// <summary>
        /// Comparison and aggregation primitives used by <c>IndexOfMinMaxCore</c> to locate the index of the
        /// element selected by the operator (for example, the minimum or the maximum).
        /// </summary>
        /// <typeparam name="T">The element type being searched.</typeparam>
        private interface IIndexOfMinMaxOperator<T>
        {
            /// <summary>Reduces all lanes of <paramref name="value"/> to the single selected element.</summary>
            static abstract T Aggregate(Vector128<T> value);

            /// <summary>Reduces all lanes of <paramref name="value"/> to the single selected element.</summary>
            static abstract T Aggregate(Vector256<T> value);

            /// <summary>Reduces all lanes of <paramref name="value"/> to the single selected element.</summary>
            static abstract T Aggregate(Vector512<T> value);

            /// <summary>
            /// Returns <see langword="true"/> when <paramref name="x"/> should replace the current best
            /// candidate <paramref name="y"/>.
            /// </summary>
            static abstract bool Compare(T x, T y);

            /// <summary>
            /// Per-lane form of the scalar <c>Compare</c>: returns a mask selecting the lanes in which
            /// <paramref name="x"/> should replace <paramref name="y"/>. Callers sign-extend this mask to select
            /// indices, so each lane is expected to be either all bits set or zero.
            /// </summary>
            static abstract Vector128<T> Compare(Vector128<T> x, Vector128<T> y);

            /// <inheritdoc cref="Compare(Vector128{T}, Vector128{T})"/>
            static abstract Vector256<T> Compare(Vector256<T> x, Vector256<T> y);

            /// <inheritdoc cref="Compare(Vector128{T}, Vector128{T})"/>
            static abstract Vector512<T> Compare(Vector512<T> x, Vector512<T> y);
        }

        /// <summary>
        /// Finds the index of the element selected by <typeparamref name="TOperator"/> in <paramref name="x"/>.
        /// </summary>
        /// <remarks>
        /// Dispatches to the widest hardware-accelerated vector implementation (512, then 256, then 128 bits)
        /// that supports <typeparamref name="T"/> and whose element count fits within <paramref name="x"/>.
        /// Within a width, the kernel is chosen by <c>sizeof(T)</c>. Inputs that qualify for no vector path use
        /// the scalar fallback.
        /// </remarks>
        /// <returns>
        /// -1 if <paramref name="x"/> is empty. Otherwise, the index of the first NaN if one is found (the
        /// scalar path checks <c>T.IsNaN</c>; the vector paths check only <see cref="float"/> and
        /// <see cref="double"/>). Failing that, the index of the element selected by the operator.
        /// </returns>
        private static int IndexOfMinMaxCore<T, TOperator>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
        {
            if (x.IsEmpty)
            {
                return -1;
            }

            if (Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && x.Length >= Vector512<T>.Count)
            {
                return sizeof(T) == 8 ? IndexOfMinMaxVectorized512Size4Plus<T, TOperator, ulong>(x) :
                       sizeof(T) == 4 ? IndexOfMinMaxVectorized512Size4Plus<T, TOperator, uint>(x) :
                       sizeof(T) == 2 ? IndexOfMinMaxVectorized512Size2<T, TOperator>(x) :
                       IndexOfMinMaxVectorized512Size1<T, TOperator>(x);
            }

            if (Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && x.Length >= Vector256<T>.Count)
            {
                return sizeof(T) == 8 ? IndexOfMinMaxVectorized256Size4Plus<T, TOperator, ulong>(x) :
                       sizeof(T) == 4 ? IndexOfMinMaxVectorized256Size4Plus<T, TOperator, uint>(x) :
                       sizeof(T) == 2 ? IndexOfMinMaxVectorized256Size2<T, TOperator>(x) :
                       IndexOfMinMaxVectorized256Size1<T, TOperator>(x);
            }

            if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && x.Length >= Vector128<T>.Count)
            {
                return sizeof(T) == 8 ? IndexOfMinMaxVectorized128Size4Plus<T, TOperator, ulong>(x) :
                       sizeof(T) == 4 ? IndexOfMinMaxVectorized128Size4Plus<T, TOperator, uint>(x) :
                       sizeof(T) == 2 ? IndexOfMinMaxVectorized128Size2<T, TOperator>(x) :
                       IndexOfMinMaxVectorized128Size1<T, TOperator>(x);
            }

            return IndexOfMinMaxFallback<T, TOperator>(x);
        }

        /// <summary>
        /// Scalar implementation. Walks <paramref name="x"/> front to back, returning immediately at the first NaN.
        /// Otherwise, an element replaces the running best only when <c>TOperator.Compare(current, best)</c>
        /// returns <see langword="true"/>.
        /// </summary>
        /// <remarks><paramref name="x"/> must be non-empty; <c>IndexOfMinMaxCore</c> handles the empty case.</remarks>
        private static int IndexOfMinMaxFallback<T, TOperator>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
        {
            T result = x[0];
            int resultIndex = 0;

            if (T.IsNaN(result))
            {
                return resultIndex;
            }

            for (int i = 1; i < x.Length; i++)
            {
                T current = x[i];

                if (T.IsNaN(current))
                {
                    return i;
                }

                if (TOperator.Compare(current, result))
                {
                    result = current;
                    resultIndex = i;
                }
            }

            return resultIndex;
        }

        /// <summary>
        /// 128-bit kernel for 4- and 8-byte element types. <typeparamref name="TInt"/> is the unsigned integer
        /// type of the same width as <typeparamref name="T"/>, so each result lane carries its index in the
        /// matching lane of a single index vector.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector128&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized128Size4Plus<T, TOperator, TInt>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
            where TInt : IBinaryInteger<TInt>
        {
            Debug.Assert(sizeof(T) == 4 || sizeof(T) == 8);
            Debug.Assert(typeof(TInt) == typeof(uint) || typeof(TInt) == typeof(ulong));
            Debug.Assert(sizeof(TInt) == sizeof(T));

            // Initialize result by reading the first vector; return early if it contains a NaN.
            Vector128<T> result = Vector128.Create(x);
            if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
            {
                Vector128<T> nanMask = IsNaN(result);
                if (nanMask != Vector128<T>.Zero)
                {
                    return IndexOfFirstMatch(nanMask);
                }
            }

            // Initialize indices: resultIndex[i] is the element index of result[i].
            Vector128<TInt> indexIncrement = Vector128.Create(TInt.CreateTruncating(Vector128<TInt>.Count));
            Vector128<TInt> resultIndex = Vector128<TInt>.Indices;
            Vector128<TInt> currentIndex = resultIndex + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector128<T>.Count);
            while (!span.IsEmpty)
            {
                Vector128<T> current;
                if (span.Length >= Vector128<T>.Count)
                {
                    current = Vector128.Create(span);
                    span = span.Slice(Vector128<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector128<T>.Count;
                    current = Vector128.Create(x.Slice(start));
                    currentIndex = Vector128.Create(TInt.CreateTruncating(start)) + Vector128<TInt>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Return early on NaN; lane 0 of currentIndex is the element index of current[0].
                if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                {
                    Vector128<T> nanMask = IsNaN(current);
                    if (nanMask != Vector128<T>.Zero)
                    {
                        return int.CreateTruncating(currentIndex.ToScalar()) + IndexOfFirstMatch(nanMask);
                    }
                }

                // Get the mask of lanes whose result should be updated.
                Vector128<T> mask = TOperator.Compare(current, result);

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex = ElementWiseSelect(mask.As<T, TInt>(), currentIndex, resultIndex);

                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with an
                // all-bits-set TInt (TInt.MaxValue, i.e. uint.MaxValue or ulong.MaxValue). Then take the
                // smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector128<TInt> aggMask = ~Vector128.Equals(result.As<T, TInt>(), Vector128.Create(aggResult).As<T, TInt>());
                Vector128<TInt> aggIndex = resultIndex | aggMask;

                return int.CreateTruncating(HorizontalAggregate<TInt, MinOperator<TInt>>(aggIndex));
            }
        }

        /// <summary>
        /// 128-bit kernel for 2-byte element types. Indices are tracked as <see cref="uint"/>, which is twice
        /// as wide as <typeparamref name="T"/>. resultIndex1 therefore covers the lower half of the result
        /// lanes and resultIndex2 covers the upper half.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector128&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized128Size2<T, TOperator>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
        {
            Debug.Assert(sizeof(T) == 2);

            // sizeof(T) == 2, so T is never float or double; no NaN quick-return is performed here.
            // NOTE(review): if a 2-byte floating-point type is ever routed here, NaN handling must be added.

            // Initialize result by reading the first vector.
            Vector128<T> result = Vector128.Create(x);

            // Initialize indices.
            Vector128<uint> indexIncrement = Vector128.Create((uint)Vector128<uint>.Count);
            Vector128<uint> resultIndex1 = Vector128<uint>.Indices;
            Vector128<uint> resultIndex2 = resultIndex1 + indexIncrement;
            Vector128<uint> currentIndex = resultIndex2 + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector128<T>.Count);
            while (!span.IsEmpty)
            {
                Vector128<T> current;
                if (span.Length >= Vector128<T>.Count)
                {
                    current = Vector128.Create(span);
                    span = span.Slice(Vector128<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector128<T>.Count;
                    current = Vector128.Create(x.Slice(start));
                    currentIndex = Vector128.Create((uint)start) + Vector128<uint>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Get the mask of lanes whose result should be updated, and sign-extend it to 32-bit lanes so
                // it can also select the corresponding indices.
                Vector128<T> mask = TOperator.Compare(current, result);
                (Vector128<int> mask1, Vector128<int> mask2) = Vector128.Widen(mask.AsInt16());

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1);
                currentIndex += indexIncrement;
                resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2);
                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with
                // uint.MaxValue. Then take the smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector128<short> aggMask = ~Vector128.Equals(result.AsInt16(), Vector128.Create(aggResult).AsInt16());
                (Vector128<int> mask1, Vector128<int> mask2) = Vector128.Widen(aggMask);

                Vector128<uint> aggIndex = resultIndex1 | mask1.AsUInt32();
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32());

                return (int)HorizontalAggregate<uint, MinOperator<uint>>(aggIndex);
            }
        }

        /// <summary>
        /// 128-bit kernel for 1-byte element types. Indices are tracked as <see cref="uint"/>, four times as
        /// wide as <typeparamref name="T"/>. resultIndex1..resultIndex4 therefore each cover one quarter of
        /// the result lanes, in order.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector128&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized128Size1<T, TOperator>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
        {
            Debug.Assert(sizeof(T) == 1);

            // sizeof(T) == 1, so T is never float or double; no NaN quick-return is performed here.

            // Initialize result by reading the first vector.
            Vector128<T> result = Vector128.Create(x);

            // Initialize indices.
            Vector128<uint> indexIncrement = Vector128.Create((uint)Vector128<uint>.Count);
            Vector128<uint> resultIndex1 = Vector128<uint>.Indices;
            Vector128<uint> resultIndex2 = resultIndex1 + indexIncrement;
            Vector128<uint> resultIndex3 = resultIndex2 + indexIncrement;
            Vector128<uint> resultIndex4 = resultIndex3 + indexIncrement;
            Vector128<uint> currentIndex = resultIndex4 + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector128<T>.Count);
            while (!span.IsEmpty)
            {
                Vector128<T> current;
                if (span.Length >= Vector128<T>.Count)
                {
                    current = Vector128.Create(span);
                    span = span.Slice(Vector128<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector128<T>.Count;
                    current = Vector128.Create(x.Slice(start));
                    currentIndex = Vector128.Create((uint)start) + Vector128<uint>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Get the mask of lanes whose result should be updated, and sign-extend it (8 -> 16 -> 32 bits)
                // so it can also select the corresponding indices.
                Vector128<T> mask = TOperator.Compare(current, result);
                (Vector128<short> lowerMask, Vector128<short> upperMask) = Vector128.Widen(mask.AsSByte());
                (Vector128<int> mask1, Vector128<int> mask2) = Vector128.Widen(lowerMask);
                (Vector128<int> mask3, Vector128<int> mask4) = Vector128.Widen(upperMask);

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1);
                currentIndex += indexIncrement;
                resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2);
                currentIndex += indexIncrement;
                resultIndex3 = ElementWiseSelect(mask3.AsUInt32(), currentIndex, resultIndex3);
                currentIndex += indexIncrement;
                resultIndex4 = ElementWiseSelect(mask4.AsUInt32(), currentIndex, resultIndex4);
                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with
                // uint.MaxValue. Then take the smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector128<sbyte> aggMask = ~Vector128.Equals(result.AsSByte(), Vector128.Create(aggResult).AsSByte());
                (Vector128<short> lowerMask, Vector128<short> upperMask) = Vector128.Widen(aggMask);
                (Vector128<int> mask1, Vector128<int> mask2) = Vector128.Widen(lowerMask);
                (Vector128<int> mask3, Vector128<int> mask4) = Vector128.Widen(upperMask);

                Vector128<uint> aggIndex = resultIndex1 | mask1.AsUInt32();
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32());
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex3 | mask3.AsUInt32());
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex4 | mask4.AsUInt32());

                return (int)HorizontalAggregate<uint, MinOperator<uint>>(aggIndex);
            }
        }

        /// <summary>
        /// 256-bit kernel for 4- and 8-byte element types. <typeparamref name="TInt"/> is the unsigned integer
        /// type of the same width as <typeparamref name="T"/>, so each result lane carries its index in the
        /// matching lane of a single index vector.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector256&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized256Size4Plus<T, TOperator, TInt>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
            where TInt : IBinaryInteger<TInt>
        {
            Debug.Assert(sizeof(T) == 4 || sizeof(T) == 8);
            Debug.Assert(typeof(TInt) == typeof(uint) || typeof(TInt) == typeof(ulong));
            Debug.Assert(sizeof(TInt) == sizeof(T));

            // Initialize result by reading the first vector; return early if it contains a NaN.
            Vector256<T> result = Vector256.Create(x);
            if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
            {
                Vector256<T> nanMask = IsNaN(result);
                if (nanMask != Vector256<T>.Zero)
                {
                    return IndexOfFirstMatch(nanMask);
                }
            }

            // Initialize indices: resultIndex[i] is the element index of result[i].
            Vector256<TInt> indexIncrement = Vector256.Create(TInt.CreateTruncating(Vector256<TInt>.Count));
            Vector256<TInt> resultIndex = Vector256<TInt>.Indices;
            Vector256<TInt> currentIndex = resultIndex + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector256<T>.Count);
            while (!span.IsEmpty)
            {
                Vector256<T> current;
                if (span.Length >= Vector256<T>.Count)
                {
                    current = Vector256.Create(span);
                    span = span.Slice(Vector256<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector256<T>.Count;
                    current = Vector256.Create(x.Slice(start));
                    currentIndex = Vector256.Create(TInt.CreateTruncating(start)) + Vector256<TInt>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Return early on NaN; lane 0 of currentIndex is the element index of current[0].
                if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                {
                    Vector256<T> nanMask = IsNaN(current);
                    if (nanMask != Vector256<T>.Zero)
                    {
                        return int.CreateTruncating(currentIndex.ToScalar()) + IndexOfFirstMatch(nanMask);
                    }
                }

                // Get the mask of lanes whose result should be updated.
                Vector256<T> mask = TOperator.Compare(current, result);

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex = ElementWiseSelect(mask.As<T, TInt>(), currentIndex, resultIndex);

                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with an
                // all-bits-set TInt (TInt.MaxValue, i.e. uint.MaxValue or ulong.MaxValue). Then take the
                // smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector256<TInt> aggMask = ~Vector256.Equals(result.As<T, TInt>(), Vector256.Create(aggResult).As<T, TInt>());
                Vector256<TInt> aggIndex = resultIndex | aggMask;

                return int.CreateTruncating(HorizontalAggregate<TInt, MinOperator<TInt>>(aggIndex));
            }
        }

        /// <summary>
        /// 256-bit kernel for 2-byte element types. Indices are tracked as <see cref="uint"/>, which is twice
        /// as wide as <typeparamref name="T"/>. resultIndex1 therefore covers the lower half of the result
        /// lanes and resultIndex2 covers the upper half.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector256&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized256Size2<T, TOperator>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
        {
            Debug.Assert(sizeof(T) == 2);

            // sizeof(T) == 2, so T is never float or double; no NaN quick-return is performed here.
            // NOTE(review): if a 2-byte floating-point type is ever routed here, NaN handling must be added.

            // Initialize result by reading the first vector.
            Vector256<T> result = Vector256.Create(x);

            // Initialize indices.
            Vector256<uint> indexIncrement = Vector256.Create((uint)Vector256<uint>.Count);
            Vector256<uint> resultIndex1 = Vector256<uint>.Indices;
            Vector256<uint> resultIndex2 = resultIndex1 + indexIncrement;
            Vector256<uint> currentIndex = resultIndex2 + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector256<T>.Count);
            while (!span.IsEmpty)
            {
                Vector256<T> current;
                if (span.Length >= Vector256<T>.Count)
                {
                    current = Vector256.Create(span);
                    span = span.Slice(Vector256<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector256<T>.Count;
                    current = Vector256.Create(x.Slice(start));
                    currentIndex = Vector256.Create((uint)start) + Vector256<uint>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Get the mask of lanes whose result should be updated, and sign-extend it to 32-bit lanes so
                // it can also select the corresponding indices.
                Vector256<T> mask = TOperator.Compare(current, result);
                (Vector256<int> mask1, Vector256<int> mask2) = Vector256.Widen(mask.AsInt16());

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1);
                currentIndex += indexIncrement;
                resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2);
                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with
                // uint.MaxValue. Then take the smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector256<short> aggMask = ~Vector256.Equals(result.AsInt16(), Vector256.Create(aggResult).AsInt16());
                (Vector256<int> mask1, Vector256<int> mask2) = Vector256.Widen(aggMask);

                Vector256<uint> aggIndex = resultIndex1 | mask1.AsUInt32();
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32());

                return (int)HorizontalAggregate<uint, MinOperator<uint>>(aggIndex);
            }
        }

        /// <summary>
        /// 256-bit kernel for 1-byte element types. Indices are tracked as <see cref="uint"/>, four times as
        /// wide as <typeparamref name="T"/>. resultIndex1..resultIndex4 therefore each cover one quarter of
        /// the result lanes, in order.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector256&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized256Size1<T, TOperator>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
        {
            Debug.Assert(sizeof(T) == 1);

            // sizeof(T) == 1, so T is never float or double; no NaN quick-return is performed here.

            // Initialize result by reading the first vector.
            Vector256<T> result = Vector256.Create(x);

            // Initialize indices.
            Vector256<uint> indexIncrement = Vector256.Create((uint)Vector256<uint>.Count);
            Vector256<uint> resultIndex1 = Vector256<uint>.Indices;
            Vector256<uint> resultIndex2 = resultIndex1 + indexIncrement;
            Vector256<uint> resultIndex3 = resultIndex2 + indexIncrement;
            Vector256<uint> resultIndex4 = resultIndex3 + indexIncrement;
            Vector256<uint> currentIndex = resultIndex4 + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector256<T>.Count);
            while (!span.IsEmpty)
            {
                Vector256<T> current;
                if (span.Length >= Vector256<T>.Count)
                {
                    current = Vector256.Create(span);
                    span = span.Slice(Vector256<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector256<T>.Count;
                    current = Vector256.Create(x.Slice(start));
                    currentIndex = Vector256.Create((uint)start) + Vector256<uint>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Get the mask of lanes whose result should be updated, and sign-extend it (8 -> 16 -> 32 bits)
                // so it can also select the corresponding indices.
                Vector256<T> mask = TOperator.Compare(current, result);
                (Vector256<short> lowerMask, Vector256<short> upperMask) = Vector256.Widen(mask.AsSByte());
                (Vector256<int> mask1, Vector256<int> mask2) = Vector256.Widen(lowerMask);
                (Vector256<int> mask3, Vector256<int> mask4) = Vector256.Widen(upperMask);

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1);
                currentIndex += indexIncrement;
                resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2);
                currentIndex += indexIncrement;
                resultIndex3 = ElementWiseSelect(mask3.AsUInt32(), currentIndex, resultIndex3);
                currentIndex += indexIncrement;
                resultIndex4 = ElementWiseSelect(mask4.AsUInt32(), currentIndex, resultIndex4);
                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with
                // uint.MaxValue. Then take the smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector256<sbyte> aggMask = ~Vector256.Equals(result.AsSByte(), Vector256.Create(aggResult).AsSByte());
                (Vector256<short> lowerMask, Vector256<short> upperMask) = Vector256.Widen(aggMask);
                (Vector256<int> mask1, Vector256<int> mask2) = Vector256.Widen(lowerMask);
                (Vector256<int> mask3, Vector256<int> mask4) = Vector256.Widen(upperMask);

                Vector256<uint> aggIndex = resultIndex1 | mask1.AsUInt32();
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32());
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex3 | mask3.AsUInt32());
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex4 | mask4.AsUInt32());

                return (int)HorizontalAggregate<uint, MinOperator<uint>>(aggIndex);
            }
        }

        /// <summary>
        /// 512-bit kernel for 4- and 8-byte element types. <typeparamref name="TInt"/> is the unsigned integer
        /// type of the same width as <typeparamref name="T"/>, so each result lane carries its index in the
        /// matching lane of a single index vector.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector512&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized512Size4Plus<T, TOperator, TInt>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
            where TInt : IBinaryInteger<TInt>
        {
            Debug.Assert(sizeof(T) == 4 || sizeof(T) == 8);
            Debug.Assert(typeof(TInt) == typeof(uint) || typeof(TInt) == typeof(ulong));
            Debug.Assert(sizeof(TInt) == sizeof(T));

            // Initialize result by reading the first vector; return early if it contains a NaN.
            Vector512<T> result = Vector512.Create(x);
            if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
            {
                Vector512<T> nanMask = IsNaN(result);
                if (nanMask != Vector512<T>.Zero)
                {
                    return IndexOfFirstMatch(nanMask);
                }
            }

            // Initialize indices: resultIndex[i] is the element index of result[i].
            Vector512<TInt> indexIncrement = Vector512.Create(TInt.CreateTruncating(Vector512<TInt>.Count));
            Vector512<TInt> resultIndex = Vector512<TInt>.Indices;
            Vector512<TInt> currentIndex = resultIndex + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector512<T>.Count);
            while (!span.IsEmpty)
            {
                Vector512<T> current;
                if (span.Length >= Vector512<T>.Count)
                {
                    current = Vector512.Create(span);
                    span = span.Slice(Vector512<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector512<T>.Count;
                    current = Vector512.Create(x.Slice(start));
                    currentIndex = Vector512.Create(TInt.CreateTruncating(start)) + Vector512<TInt>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Return early on NaN; lane 0 of currentIndex is the element index of current[0].
                if (typeof(T) == typeof(float) || typeof(T) == typeof(double))
                {
                    Vector512<T> nanMask = IsNaN(current);
                    if (nanMask != Vector512<T>.Zero)
                    {
                        return int.CreateTruncating(currentIndex.ToScalar()) + IndexOfFirstMatch(nanMask);
                    }
                }

                // Get the mask of lanes whose result should be updated.
                Vector512<T> mask = TOperator.Compare(current, result);

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex = ElementWiseSelect(mask.As<T, TInt>(), currentIndex, resultIndex);

                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with an
                // all-bits-set TInt (TInt.MaxValue, i.e. uint.MaxValue or ulong.MaxValue). Then take the
                // smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector512<TInt> aggMask = ~Vector512.Equals(result.As<T, TInt>(), Vector512.Create(aggResult).As<T, TInt>());
                Vector512<TInt> aggIndex = resultIndex | aggMask;

                return int.CreateTruncating(HorizontalAggregate<TInt, MinOperator<TInt>>(aggIndex));
            }
        }

        /// <summary>
        /// 512-bit kernel for 2-byte element types. Indices are tracked as <see cref="uint"/>, which is twice
        /// as wide as <typeparamref name="T"/>. resultIndex1 therefore covers the lower half of the result
        /// lanes and resultIndex2 covers the upper half.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector512&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized512Size2<T, TOperator>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
        {
            Debug.Assert(sizeof(T) == 2);

            // sizeof(T) == 2, so T is never float or double; no NaN quick-return is performed here.
            // NOTE(review): if a 2-byte floating-point type is ever routed here, NaN handling must be added.

            // Initialize result by reading the first vector.
            Vector512<T> result = Vector512.Create(x);

            // Initialize indices.
            Vector512<uint> indexIncrement = Vector512.Create((uint)Vector512<uint>.Count);
            Vector512<uint> resultIndex1 = Vector512<uint>.Indices;
            Vector512<uint> resultIndex2 = resultIndex1 + indexIncrement;
            Vector512<uint> currentIndex = resultIndex2 + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector512<T>.Count);
            while (!span.IsEmpty)
            {
                Vector512<T> current;
                if (span.Length >= Vector512<T>.Count)
                {
                    current = Vector512.Create(span);
                    span = span.Slice(Vector512<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector512<T>.Count;
                    current = Vector512.Create(x.Slice(start));
                    currentIndex = Vector512.Create((uint)start) + Vector512<uint>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Get the mask of lanes whose result should be updated, and sign-extend it to 32-bit lanes so
                // it can also select the corresponding indices.
                Vector512<T> mask = TOperator.Compare(current, result);
                (Vector512<int> mask1, Vector512<int> mask2) = Vector512.Widen(mask.AsInt16());

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1);
                currentIndex += indexIncrement;
                resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2);
                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with
                // uint.MaxValue. Then take the smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector512<short> aggMask = ~Vector512.Equals(result.AsInt16(), Vector512.Create(aggResult).AsInt16());
                (Vector512<int> mask1, Vector512<int> mask2) = Vector512.Widen(aggMask);

                Vector512<uint> aggIndex = resultIndex1 | mask1.AsUInt32();
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32());

                return (int)HorizontalAggregate<uint, MinOperator<uint>>(aggIndex);
            }
        }

        /// <summary>
        /// 512-bit kernel for 1-byte element types. Indices are tracked as <see cref="uint"/>, four times as
        /// wide as <typeparamref name="T"/>. resultIndex1..resultIndex4 therefore each cover one quarter of
        /// the result lanes, in order.
        /// </summary>
        /// <remarks>Requires <c>x.Length &gt;= Vector512&lt;T&gt;.Count</c>; the first vector is read unconditionally.</remarks>
        private static int IndexOfMinMaxVectorized512Size1<T, TOperator>(ReadOnlySpan<T> x)
            where T : INumber<T>
            where TOperator : struct, IIndexOfMinMaxOperator<T>
        {
            Debug.Assert(sizeof(T) == 1);

            // sizeof(T) == 1, so T is never float or double; no NaN quick-return is performed here.

            // Initialize result by reading the first vector.
            Vector512<T> result = Vector512.Create(x);

            // Initialize indices.
            Vector512<uint> indexIncrement = Vector512.Create((uint)Vector512<uint>.Count);
            Vector512<uint> resultIndex1 = Vector512<uint>.Indices;
            Vector512<uint> resultIndex2 = resultIndex1 + indexIncrement;
            Vector512<uint> resultIndex3 = resultIndex2 + indexIncrement;
            Vector512<uint> resultIndex4 = resultIndex3 + indexIncrement;
            Vector512<uint> currentIndex = resultIndex4 + indexIncrement;

            ReadOnlySpan<T> span = x.Slice(Vector512<T>.Count);
            while (!span.IsEmpty)
            {
                Vector512<T> current;
                if (span.Length >= Vector512<T>.Count)
                {
                    current = Vector512.Create(span);
                    span = span.Slice(Vector512<T>.Count);
                }
                else
                {
                    // Process a final vector, back-shifted to end at the last element of x, so the remaining
                    // elements are covered in one load. Some of its elements were already processed earlier.
                    int start = x.Length - Vector512<T>.Count;
                    current = Vector512.Create(x.Slice(start));
                    currentIndex = Vector512.Create((uint)start) + Vector512<uint>.Indices;
                    span = ReadOnlySpan<T>.Empty;
                }

                // Get the mask of lanes whose result should be updated, and sign-extend it (8 -> 16 -> 32 bits)
                // so it can also select the corresponding indices.
                Vector512<T> mask = TOperator.Compare(current, result);
                (Vector512<short> lowerMask, Vector512<short> upperMask) = Vector512.Widen(mask.AsSByte());
                (Vector512<int> mask1, Vector512<int> mask2) = Vector512.Widen(lowerMask);
                (Vector512<int> mask3, Vector512<int> mask4) = Vector512.Widen(upperMask);

                // Update result and indices.
                result = ElementWiseSelect(mask, current, result);
                resultIndex1 = ElementWiseSelect(mask1.AsUInt32(), currentIndex, resultIndex1);
                currentIndex += indexIncrement;
                resultIndex2 = ElementWiseSelect(mask2.AsUInt32(), currentIndex, resultIndex2);
                currentIndex += indexIncrement;
                resultIndex3 = ElementWiseSelect(mask3.AsUInt32(), currentIndex, resultIndex3);
                currentIndex += indexIncrement;
                resultIndex4 = ElementWiseSelect(mask4.AsUInt32(), currentIndex, resultIndex4);
                currentIndex += indexIncrement;
            }

            {
                // Where result does not bitwise-equal the aggregate min/max value, replace the index with
                // uint.MaxValue. Then take the smallest remaining index.
                T aggResult = TOperator.Aggregate(result);
                Vector512<sbyte> aggMask = ~Vector512.Equals(result.AsSByte(), Vector512.Create(aggResult).AsSByte());
                (Vector512<short> lowerMask, Vector512<short> upperMask) = Vector512.Widen(aggMask);
                (Vector512<int> mask1, Vector512<int> mask2) = Vector512.Widen(lowerMask);
                (Vector512<int> mask3, Vector512<int> mask4) = Vector512.Widen(upperMask);

                Vector512<uint> aggIndex = resultIndex1 | mask1.AsUInt32();
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32());
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex3 | mask3.AsUInt32());
                aggIndex = MinOperator<uint>.Invoke(aggIndex, resultIndex4 | mask4.AsUInt32());

                return (int)HorizontalAggregate<uint, MinOperator<uint>>(aggIndex);
            }
        }
    }
}