|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
namespace System.Numerics.Tensors
{
    public static unsafe partial class TensorPrimitives
    {
        /// <summary>
        /// Strategy used by the index-of reductions to merge a candidate (value, index) pair into the running
        /// (value, index) pair. Which candidate is kept is decided by the implementing struct, which is defined
        /// outside this file.
        /// </summary>
        /// <typeparam name="T">
        /// Element type. The vector overloads store lane indices in vectors of this same element type.
        /// </typeparam>
        private interface IIndexOfOperator<T>
        {
            /// <summary>
            /// Scalar form; the running value is passed by <see langword="ref"/>.
            /// NOTE(review): the returned <see cref="int"/> is presumably the index of the kept candidate — confirm
            /// against the implementations.
            /// </summary>
            static abstract int Invoke(ref T result, T current, int resultIndex, int currentIndex);

            /// <summary>
            /// 128-bit lane-wise form. The running values and indices are passed by <see langword="ref"/> so the
            /// operator can update them in place from <paramref name="current"/>/<paramref name="currentIndex"/>.
            /// </summary>
            static abstract void Invoke(ref Vector128<T> result, Vector128<T> current, ref Vector128<T> resultIndex, Vector128<T> currentIndex);

            /// <summary>256-bit lane-wise form; same contract as the 128-bit overload.</summary>
            static abstract void Invoke(ref Vector256<T> result, Vector256<T> current, ref Vector256<T> resultIndex, Vector256<T> currentIndex);

            /// <summary>512-bit lane-wise form; same contract as the 128-bit overload.</summary>
            static abstract void Invoke(ref Vector512<T> result, Vector512<T> current, ref Vector512<T> resultIndex, Vector512<T> currentIndex);
        }

        /// <summary>
        /// Collapses per-lane (value, index) candidates into a single index. Each step folds the upper half of
        /// the still-live lanes onto the lower half with <typeparamref name="TIndexOfOperator"/>. Afterwards
        /// lane 0 of <paramref name="resultIndex"/> is read back as an <see cref="int"/>.
        /// </summary>
        /// <remarks>
        /// Only lane 0 is consumed at the end, so the contents of the other lanes after a step are irrelevant.
        /// Every shuffle mask is spelled out as a literal at its call site rather than hoisted into a local.
        /// NOTE(review): <c>Vector128.Shuffle</c> is presumably only lowered to a single hardware shuffle when the
        /// JIT sees constant indices — confirm for the target runtime before refactoring the masks.
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static int IndexOfFinalAggregate<T, TIndexOfOperator>(Vector128<T> result, Vector128<T> resultIndex)
            where TIndexOfOperator : struct, IIndexOfOperator<T>
        {
            if (sizeof(T) == sizeof(long))
            {
                // Two lanes: fold lane 1 onto lane 0.
                TIndexOfOperator.Invoke(
                    ref result, Vector128.Shuffle(result.AsInt64(), Vector128.Create(1L, 0L)).As<long, T>(),
                    ref resultIndex, Vector128.Shuffle(resultIndex.AsInt64(), Vector128.Create(1L, 0L)).As<long, T>());

                return (int)resultIndex.As<T, long>().ToScalar();
            }

            if (sizeof(T) == sizeof(int))
            {
                // Four lanes: fold lanes 2,3 onto 0,1 ...
                TIndexOfOperator.Invoke(
                    ref result, Vector128.Shuffle(result.AsInt32(), Vector128.Create(2, 3, 0, 1)).As<int, T>(),
                    ref resultIndex, Vector128.Shuffle(resultIndex.AsInt32(), Vector128.Create(2, 3, 0, 1)).As<int, T>());

                // ... then lane 1 onto lane 0.
                TIndexOfOperator.Invoke(
                    ref result, Vector128.Shuffle(result.AsInt32(), Vector128.Create(1, 0, 3, 2)).As<int, T>(),
                    ref resultIndex, Vector128.Shuffle(resultIndex.AsInt32(), Vector128.Create(1, 0, 3, 2)).As<int, T>());

                return resultIndex.As<T, int>().ToScalar();
            }

            if (sizeof(T) == sizeof(short))
            {
                // Eight lanes: fold 4..7 onto 0..3 ...
                TIndexOfOperator.Invoke(
                    ref result, Vector128.Shuffle(result.AsInt16(), Vector128.Create((short)4, 5, 6, 7, 0, 1, 2, 3)).As<short, T>(),
                    ref resultIndex, Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create((short)4, 5, 6, 7, 0, 1, 2, 3)).As<short, T>());

                // ... then 2,3 onto 0,1 ...
                TIndexOfOperator.Invoke(
                    ref result, Vector128.Shuffle(result.AsInt16(), Vector128.Create((short)2, 3, 0, 1, 4, 5, 6, 7)).As<short, T>(),
                    ref resultIndex, Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create((short)2, 3, 0, 1, 4, 5, 6, 7)).As<short, T>());

                // ... then lane 1 onto lane 0.
                TIndexOfOperator.Invoke(
                    ref result, Vector128.Shuffle(result.AsInt16(), Vector128.Create((short)1, 0, 2, 3, 4, 5, 6, 7)).As<short, T>(),
                    ref resultIndex, Vector128.Shuffle(resultIndex.AsInt16(), Vector128.Create((short)1, 0, 2, 3, 4, 5, 6, 7)).As<short, T>());

                return resultIndex.As<T, short>().ToScalar();
            }

            // Only 1-byte elements remain.
            Debug.Assert(sizeof(T) == 1);

            // Sixteen lanes: fold 8..15 onto 0..7 ...
            TIndexOfOperator.Invoke(
                ref result, Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7)).As<byte, T>(),
                ref resultIndex, Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7)).As<byte, T>());

            // ... then 4..7 onto 0..3 ...
            TIndexOfOperator.Invoke(
                ref result, Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>(),
                ref resultIndex, Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>());

            // ... then 2,3 onto 0,1 ...
            TIndexOfOperator.Invoke(
                ref result, Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)2, 3, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>(),
                ref resultIndex, Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)2, 3, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>());

            // ... then lane 1 onto lane 0.
            TIndexOfOperator.Invoke(
                ref result, Vector128.Shuffle(result.AsByte(), Vector128.Create((byte)1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>(),
                ref resultIndex, Vector128.Shuffle(resultIndex.AsByte(), Vector128.Create((byte)1, 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)).As<byte, T>());

            return resultIndex.As<T, byte>().ToScalar();
        }

        /// <summary>
        /// Folds the upper 128 bits onto the lower 128 bits with <typeparamref name="TIndexOfOperator"/>
        /// (whatever comparison it implements), then finishes the reduction at <see cref="Vector128{T}"/> width.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static int IndexOfFinalAggregate<T, TIndexOfOperator>(Vector256<T> result, Vector256<T> resultIndex)
            where TIndexOfOperator : struct, IIndexOfOperator<T>
        {
            Vector128<T> upperValues = result.GetUpper();
            Vector128<T> upperIndices = resultIndex.GetUpper();
            Vector128<T> mergedValues = result.GetLower();
            Vector128<T> mergedIndices = resultIndex.GetLower();

            TIndexOfOperator.Invoke(ref mergedValues, upperValues, ref mergedIndices, upperIndices);

            return IndexOfFinalAggregate<T, TIndexOfOperator>(mergedValues, mergedIndices);
        }

        /// <summary>
        /// Folds the upper 256 bits onto the lower 256 bits with <typeparamref name="TIndexOfOperator"/>,
        /// then continues the reduction at <see cref="Vector256{T}"/> width.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static int IndexOfFinalAggregate<T, TIndexOfOperator>(Vector512<T> result, Vector512<T> resultIndex)
            where TIndexOfOperator : struct, IIndexOfOperator<T>
        {
            Vector256<T> upperValues = result.GetUpper();
            Vector256<T> upperIndices = resultIndex.GetUpper();
            Vector256<T> mergedValues = result.GetLower();
            Vector256<T> mergedIndices = resultIndex.GetLower();

            TIndexOfOperator.Invoke(ref mergedValues, upperValues, ref mergedIndices, upperIndices);

            return IndexOfFinalAggregate<T, TIndexOfOperator>(mergedValues, mergedIndices);
        }

        /// <summary>
        /// Lane-wise <c>indices1 &lt; indices2</c>. Each lane is reinterpreted as an integer with the same width
        /// as <typeparamref name="T"/>, and the result is the <c>Vector128.LessThan</c> mask for that
        /// interpretation.
        /// </summary>
        /// <remarks>
        /// The 8-, 4- and 2-byte cases compare as signed integers, but the 1-byte case compares as unsigned
        /// (<see cref="byte"/>). This matches the widths <c>IndexOfFinalAggregate</c> reads back.
        /// NOTE(review): callers presumably keep indices within the range of these types — confirm.
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static Vector128<T> IndexLessThan<T>(Vector128<T> indices1, Vector128<T> indices2)
        {
            if (sizeof(T) == sizeof(long))
            {
                return Vector128.LessThan(indices1.AsInt64(), indices2.AsInt64()).As<long, T>();
            }

            if (sizeof(T) == sizeof(int))
            {
                return Vector128.LessThan(indices1.AsInt32(), indices2.AsInt32()).As<int, T>();
            }

            if (sizeof(T) == sizeof(short))
            {
                return Vector128.LessThan(indices1.AsInt16(), indices2.AsInt16()).As<short, T>();
            }

            return Vector128.LessThan(indices1.AsByte(), indices2.AsByte()).As<byte, T>();
        }
    }
}
|