// File: System\Numerics\Tensors\netcore\TensorPrimitives.CosineSimilarity.cs
// Web Access
// Project: src\src\libraries\System.Numerics.Tensors\src\System.Numerics.Tensors.csproj (System.Numerics.Tensors)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
 
namespace System.Numerics.Tensors
{
    public partial class TensorPrimitives
    {
        /// <summary>Computes the cosine similarity between the two specified non-empty, equal-length tensors of numbers.</summary>
        /// <typeparam name="T">The element type of the tensors.</typeparam>
        /// <param name="x">The first tensor, represented as a span.</param>
        /// <param name="y">The second tensor, represented as a span.</param>
        /// <returns>The cosine similarity of the two tensors.</returns>
        /// <exception cref="ArgumentException">Length of <paramref name="x" /> must be same as length of <paramref name="y" />.</exception>
        /// <exception cref="ArgumentException"><paramref name="x" /> and <paramref name="y" /> must not be empty.</exception>
        /// <remarks>
        /// <para>
        /// This method effectively computes
        /// <c>TensorPrimitives.Dot(x, y) / (<typeparamref name="T"/>.Sqrt(TensorPrimitives.SumOfSquares(x)) * <typeparamref name="T"/>.Sqrt(TensorPrimitives.SumOfSquares(y)))</c>.
        /// </para>
        /// <para>
        /// If any element in either input tensor is equal to <see cref="IFloatingPointIeee754{TSelf}.NegativeInfinity"/>,
        /// <see cref="IFloatingPointIeee754{TSelf}.PositiveInfinity"/>, or <see cref="IFloatingPointIeee754{TSelf}.NaN"/>,
        /// NaN is returned.
        /// </para>
        /// <para>
        /// This method may call into the underlying C runtime or employ instructions specific to the current architecture.
        /// Exact results may differ between different operating systems or architectures.
        /// </para>
        /// </remarks>
        public static T CosineSimilarity<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y)
            where T : IRootFunctions<T>
        {
            // Every element type other than Half is handled by the generic implementation.
            if (typeof(T) != typeof(Half))
            {
                return CosineSimilarityCore(x, y);
            }

            // Half is not implicitly vectorizable, so it is routed to a dedicated implementation that vectorizes manually.
            return (T)(object)CosineSimilarityHalfCore(Rename<T, Half>(x), Rename<T, Half>(y));
        }
 
        /// <summary>Computes the cosine similarity between the two specified non-empty, equal-length tensors of numbers.</summary>
        /// <typeparam name="T">The element type of the tensors.</typeparam>
        /// <param name="x">The first tensor, represented as a span.</param>
        /// <param name="y">The second tensor, represented as a span.</param>
        /// <returns>The dot product of the tensors divided by the product of the square roots of their sums of squares.</returns>
        /// <remarks>
        /// This method validates its own arguments: it invokes the throw helpers when <paramref name="x"/> is empty or when the
        /// two spans differ in length. The result is then computed in a single pass, using the first of the 512-, 256- and 128-bit
        /// vector paths that is hardware accelerated, supports <typeparamref name="T"/> and fits within the input length,
        /// and falling back to a scalar loop otherwise.
        /// </remarks>
        private static T CosineSimilarityCore<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IRootFunctions<T>
        {
            if (x.IsEmpty)
            {
                ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
            }

            if (x.Length != y.Length)
            {
                ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
            }

            // Compute the same as:
            // TensorPrimitives.Dot(x, y) / (T.Sqrt(TensorPrimitives.SumOfSquares(x)) * T.Sqrt(TensorPrimitives.SumOfSquares(y)))
            // but only looping over each span once.

            if (Vector512.IsHardwareAccelerated && Vector512<T>.IsSupported && x.Length >= Vector512<T>.Count)
            {
                ref T xRef = ref MemoryMarshal.GetReference(x);
                ref T yRef = ref MemoryMarshal.GetReference(y);

                // One accumulator vector per quantity; lanes are summed together only in Finalize.
                Vector512<T> dotProductVector = Vector512<T>.Zero;
                Vector512<T> xSumOfSquaresVector = Vector512<T>.Zero;
                Vector512<T> ySumOfSquaresVector = Vector512<T>.Zero;

                // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining.
                // The do/while is safe because the enclosing condition guarantees at least one full vector of elements.
                int oneVectorFromEnd = x.Length - Vector512<T>.Count;
                int i = 0;
                do
                {
                    Update(
                        Vector512.LoadUnsafe(ref xRef, (uint)i),
                        Vector512.LoadUnsafe(ref yRef, (uint)i),
                        ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);

                    i += Vector512<T>.Count;
                }
                while (i <= oneVectorFromEnd);

                // Process the last vector in the span, masking off elements already processed.
                // The final load starts one vector's width before the end of the span, so it overlaps elements the loop
                // has already consumed; the remainder mask (built for the x.Length - i unprocessed elements) is applied
                // to both operands so that, presumably, only those trailing elements contribute to the accumulators.
                if (i != x.Length)
                {
                    Vector512<T> remainderMask = CreateRemainderMaskVector512<T>(x.Length - i);

                    Update(
                        Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512<T>.Count)) & remainderMask,
                        Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512<T>.Count)) & remainderMask,
                        ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                }

                return Finalize(dotProductVector, xSumOfSquaresVector, ySumOfSquaresVector);
            }

            // Same algorithm as the 512-bit path above, using 256-bit vectors.
            if (Vector256.IsHardwareAccelerated && Vector256<T>.IsSupported && x.Length >= Vector256<T>.Count)
            {
                ref T xRef = ref MemoryMarshal.GetReference(x);
                ref T yRef = ref MemoryMarshal.GetReference(y);

                Vector256<T> dotProductVector = Vector256<T>.Zero;
                Vector256<T> xSumOfSquaresVector = Vector256<T>.Zero;
                Vector256<T> ySumOfSquaresVector = Vector256<T>.Zero;

                // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining.
                int oneVectorFromEnd = x.Length - Vector256<T>.Count;
                int i = 0;
                do
                {
                    Update(
                        Vector256.LoadUnsafe(ref xRef, (uint)i),
                        Vector256.LoadUnsafe(ref yRef, (uint)i),
                        ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);

                    i += Vector256<T>.Count;
                }
                while (i <= oneVectorFromEnd);

                // Process the last vector in the span, masking off elements already processed.
                if (i != x.Length)
                {
                    Vector256<T> remainderMask = CreateRemainderMaskVector256<T>(x.Length - i);

                    Update(
                        Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256<T>.Count)) & remainderMask,
                        Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256<T>.Count)) & remainderMask,
                        ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                }

                return Finalize(dotProductVector, xSumOfSquaresVector, ySumOfSquaresVector);
            }

            // Same algorithm as the 512-bit path above, using 128-bit vectors.
            if (Vector128.IsHardwareAccelerated && Vector128<T>.IsSupported && x.Length >= Vector128<T>.Count)
            {
                ref T xRef = ref MemoryMarshal.GetReference(x);
                ref T yRef = ref MemoryMarshal.GetReference(y);

                Vector128<T> dotProductVector = Vector128<T>.Zero;
                Vector128<T> xSumOfSquaresVector = Vector128<T>.Zero;
                Vector128<T> ySumOfSquaresVector = Vector128<T>.Zero;

                // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining.
                int oneVectorFromEnd = x.Length - Vector128<T>.Count;
                int i = 0;
                do
                {
                    Update(
                        Vector128.LoadUnsafe(ref xRef, (uint)i),
                        Vector128.LoadUnsafe(ref yRef, (uint)i),
                        ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);

                    i += Vector128<T>.Count;
                }
                while (i <= oneVectorFromEnd);

                // Process the last vector in the span, masking off elements already processed.
                if (i != x.Length)
                {
                    Vector128<T> remainderMask = CreateRemainderMaskVector128<T>(x.Length - i);

                    Update(
                        Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128<T>.Count)) & remainderMask,
                        Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128<T>.Count)) & remainderMask,
                        ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                }

                return Finalize(dotProductVector, xSumOfSquaresVector, ySumOfSquaresVector);
            }

            // Vectorization isn't supported or there are too few elements to vectorize.
            // Use a scalar implementation.
            T dotProduct = T.Zero, xSumOfSquares = T.Zero, ySumOfSquares = T.Zero;
            for (int i = 0; i < x.Length; i++)
            {
                Update(x[i], y[i], ref dotProduct, ref xSumOfSquares, ref ySumOfSquares);
            }

            return Finalize(dotProduct, xSumOfSquares, ySumOfSquares);
        }
 
        /// <summary>Provides the same implementation as <see cref="CosineSimilarityCore{T}"/>, but specifically for <see cref="Half"/>.</summary>
        /// <param name="x">The first tensor, represented as a span.</param>
        /// <param name="y">The second tensor, represented as a span.</param>
        /// <returns>The cosine similarity of the two tensors, accumulated in <see cref="float"/> and converted to <see cref="Half"/> at the end.</returns>
        /// <remarks>
        /// This method validates its own arguments: it invokes the throw helpers when <paramref name="x"/> is empty or when the
        /// two spans differ in length. In the vector paths the <see cref="Half"/> elements are reinterpreted as <see cref="short"/>,
        /// loaded as vectors, and widened into pairs of <see cref="float"/> vectors before being accumulated.
        /// </remarks>
        private static Half CosineSimilarityHalfCore(ReadOnlySpan<Half> x, ReadOnlySpan<Half> y)
        {
            if (x.IsEmpty)
            {
                ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
            }

            if (x.Length != y.Length)
            {
                ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
            }

            // Compute the same as:
            // TensorPrimitives.Dot(x, y) / (Math.Sqrt(TensorPrimitives.SumOfSquares(x)) * Math.Sqrt(TensorPrimitives.SumOfSquares(y)))
            // but only looping over each span once. As Half can't be vectorized implicitly, we reinterpret as shorts, and then
            // widen to float vectors, which can be vectorized.

            if (Vector512.IsHardwareAccelerated && x.Length >= Vector512<short>.Count)
            {
                // View the Half data as raw 16-bit integers so it can be loaded with vector instructions.
                ref short xRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(x));
                ref short yRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(y));

                // Accumulate with float vectors, casting at the very end back to Half.
                Vector512<float> dotProductVector = Vector512<float>.Zero;
                Vector512<float> xSumOfSquaresVector = Vector512<float>.Zero;
                Vector512<float> ySumOfSquaresVector = Vector512<float>.Zero;

                // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining.
                // The do/while is safe because the enclosing condition guarantees at least one full vector of elements.
                int oneVectorFromEnd = x.Length - Vector512<short>.Count;
                int i = 0;
                do
                {
                    // Each vector of shorts widens into two float vectors (lower and upper), both of which are accumulated.
                    (Vector512<float> xVecLower, Vector512<float> xVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i));
                    (Vector512<float> yVecLower, Vector512<float> yVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(Vector512.LoadUnsafe(ref yRef, (uint)i));

                    Update(xVecLower, yVecLower, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                    Update(xVecUpper, yVecUpper, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);

                    i += Vector512<short>.Count;
                }
                while (i <= oneVectorFromEnd);

                // Process the last vector in the span, masking off elements already processed.
                // The final load starts one vector's width before the end of the span, so it overlaps elements the loop
                // has already consumed; the remainder mask (built for the x.Length - i unprocessed elements) is applied
                // to the raw shorts before widening so that, presumably, only the trailing elements contribute.
                if (i != x.Length)
                {
                    Vector512<short> remainderMask = CreateRemainderMaskVector512<short>(x.Length - i);

                    (Vector512<float> xVecLower, Vector512<float> xVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(
                        Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512<short>.Count)) & remainderMask);
                    (Vector512<float> yVecLower, Vector512<float> yVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(
                        Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512<short>.Count)) & remainderMask);

                    Update(xVecLower, yVecLower, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                    Update(xVecUpper, yVecUpper, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                }

                return (Half)Finalize(dotProductVector, xSumOfSquaresVector, ySumOfSquaresVector);
            }

            // Same algorithm as the 512-bit path above, using 256-bit vectors.
            if (Vector256.IsHardwareAccelerated && x.Length >= Vector256<short>.Count)
            {
                ref short xRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(x));
                ref short yRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(y));

                // Accumulate with float vectors, casting at the very end back to Half.
                Vector256<float> dotProductVector = Vector256<float>.Zero;
                Vector256<float> xSumOfSquaresVector = Vector256<float>.Zero;
                Vector256<float> ySumOfSquaresVector = Vector256<float>.Zero;

                // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining.
                int oneVectorFromEnd = x.Length - Vector256<short>.Count;
                int i = 0;
                do
                {
                    (Vector256<float> xVecLower, Vector256<float> xVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i));
                    (Vector256<float> yVecLower, Vector256<float> yVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(Vector256.LoadUnsafe(ref yRef, (uint)i));

                    Update(xVecLower, yVecLower, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                    Update(xVecUpper, yVecUpper, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);

                    i += Vector256<short>.Count;
                }
                while (i <= oneVectorFromEnd);

                // Process the last vector in the span, masking off elements already processed.
                if (i != x.Length)
                {
                    Vector256<short> remainderMask = CreateRemainderMaskVector256<short>(x.Length - i);

                    (Vector256<float> xVecLower, Vector256<float> xVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(
                        Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256<short>.Count)) & remainderMask);
                    (Vector256<float> yVecLower, Vector256<float> yVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(
                        Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256<short>.Count)) & remainderMask);

                    Update(xVecLower, yVecLower, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                    Update(xVecUpper, yVecUpper, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                }

                return (Half)Finalize(dotProductVector, xSumOfSquaresVector, ySumOfSquaresVector);
            }

            // Same algorithm as the 512-bit path above, using 128-bit vectors.
            if (Vector128.IsHardwareAccelerated && x.Length >= Vector128<short>.Count)
            {
                ref short xRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(x));
                ref short yRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(y));

                // Accumulate with float vectors, casting at the very end back to Half.
                Vector128<float> dotProductVector = Vector128<float>.Zero;
                Vector128<float> xSumOfSquaresVector = Vector128<float>.Zero;
                Vector128<float> ySumOfSquaresVector = Vector128<float>.Zero;

                // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining.
                int oneVectorFromEnd = x.Length - Vector128<short>.Count;
                int i = 0;
                do
                {
                    (Vector128<float> xVecLower, Vector128<float> xVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i));
                    (Vector128<float> yVecLower, Vector128<float> yVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(Vector128.LoadUnsafe(ref yRef, (uint)i));

                    Update(xVecLower, yVecLower, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                    Update(xVecUpper, yVecUpper, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);

                    i += Vector128<short>.Count;
                }
                while (i <= oneVectorFromEnd);

                // Process the last vector in the span, masking off elements already processed.
                if (i != x.Length)
                {
                    Vector128<short> remainderMask = CreateRemainderMaskVector128<short>(x.Length - i);

                    (Vector128<float> xVecLower, Vector128<float> xVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(
                        Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128<short>.Count)) & remainderMask);
                    (Vector128<float> yVecLower, Vector128<float> yVecUpper) = WidenHalfAsInt16ToSingleOperator.Invoke(
                        Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128<short>.Count)) & remainderMask);

                    Update(xVecLower, yVecLower, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                    Update(xVecUpper, yVecUpper, ref dotProductVector, ref xSumOfSquaresVector, ref ySumOfSquaresVector);
                }

                return (Half)Finalize(dotProductVector, xSumOfSquaresVector, ySumOfSquaresVector);
            }

            // Vectorization isn't supported or there are too few elements to vectorize.
            // Use a scalar implementation, still accumulating in float and converting to Half only at the end.
            float dotProduct = 0, xSumOfSquares = 0, ySumOfSquares = 0;
            for (int i = 0; i < x.Length; i++)
            {
                Update((float)x[i], (float)y[i], ref dotProduct, ref xSumOfSquares, ref ySumOfSquares);
            }

            return (Half)Finalize(dotProduct, xSumOfSquares, ySumOfSquares);
        }
 
        /// <summary>Folds one pair of scalar elements into the three running accumulators.</summary>
        /// <param name="x">The current element of the first tensor.</param>
        /// <param name="y">The current element of the second tensor.</param>
        /// <param name="dotProduct">Running accumulator fed with <paramref name="x"/> and <paramref name="y"/>.</param>
        /// <param name="xSumOfSquares">Running accumulator fed with <paramref name="x"/> and itself.</param>
        /// <param name="ySumOfSquares">Running accumulator fed with <paramref name="y"/> and itself.</param>
        /// <remarks>
        /// Each accumulator is replaced by the result of <c>MultiplyAddEstimateOperator</c> applied to the two factors and the
        /// accumulator's previous value (presumably <c>factor1 * factor2 + accumulator</c>; confirm against the operator's definition).
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static void Update<T>(
            T x,
            T y,
            ref T dotProduct,
            ref T xSumOfSquares,
            ref T ySumOfSquares)
            where T : INumberBase<T>
        {
            // Cross term first, then the two self terms; each statement touches only its own accumulator.
            dotProduct = MultiplyAddEstimateOperator<T>.Invoke(x, y, dotProduct);
            xSumOfSquares = MultiplyAddEstimateOperator<T>.Invoke(x, x, xSumOfSquares);
            ySumOfSquares = MultiplyAddEstimateOperator<T>.Invoke(y, y, ySumOfSquares);
        }
 
        /// <summary>Folds one pair of 128-bit vectors into the three running accumulator vectors, lane by lane.</summary>
        /// <param name="xVec">The current vector of elements from the first tensor.</param>
        /// <param name="yVec">The current vector of elements from the second tensor.</param>
        /// <param name="dotProductVector">Running accumulator fed with <paramref name="xVec"/> and <paramref name="yVec"/>.</param>
        /// <param name="xSumOfSquaresVector">Running accumulator fed with <paramref name="xVec"/> and itself.</param>
        /// <param name="ySumOfSquaresVector">Running accumulator fed with <paramref name="yVec"/> and itself.</param>
        /// <remarks>
        /// Each accumulator is replaced by the result of <c>MultiplyAddEstimateOperator</c> applied to the two factors and the
        /// accumulator's previous value (presumably <c>factor1 * factor2 + accumulator</c>; confirm against the operator's definition).
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static void Update<T>(
            Vector128<T> xVec,
            Vector128<T> yVec,
            ref Vector128<T> dotProductVector,
            ref Vector128<T> xSumOfSquaresVector,
            ref Vector128<T> ySumOfSquaresVector)
            where T : INumberBase<T>
        {
            // Cross term first, then the two self terms; each statement touches only its own accumulator.
            dotProductVector = MultiplyAddEstimateOperator<T>.Invoke(xVec, yVec, dotProductVector);
            xSumOfSquaresVector = MultiplyAddEstimateOperator<T>.Invoke(xVec, xVec, xSumOfSquaresVector);
            ySumOfSquaresVector = MultiplyAddEstimateOperator<T>.Invoke(yVec, yVec, ySumOfSquaresVector);
        }
 
        /// <summary>Folds one pair of 256-bit vectors into the three running accumulator vectors, lane by lane.</summary>
        /// <param name="xVec">The current vector of elements from the first tensor.</param>
        /// <param name="yVec">The current vector of elements from the second tensor.</param>
        /// <param name="dotProductVector">Running accumulator fed with <paramref name="xVec"/> and <paramref name="yVec"/>.</param>
        /// <param name="xSumOfSquaresVector">Running accumulator fed with <paramref name="xVec"/> and itself.</param>
        /// <param name="ySumOfSquaresVector">Running accumulator fed with <paramref name="yVec"/> and itself.</param>
        /// <remarks>
        /// Each accumulator is replaced by the result of <c>MultiplyAddEstimateOperator</c> applied to the two factors and the
        /// accumulator's previous value (presumably <c>factor1 * factor2 + accumulator</c>; confirm against the operator's definition).
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static void Update<T>(
            Vector256<T> xVec,
            Vector256<T> yVec,
            ref Vector256<T> dotProductVector,
            ref Vector256<T> xSumOfSquaresVector,
            ref Vector256<T> ySumOfSquaresVector)
            where T : INumberBase<T>
        {
            // Cross term first, then the two self terms; each statement touches only its own accumulator.
            dotProductVector = MultiplyAddEstimateOperator<T>.Invoke(xVec, yVec, dotProductVector);
            xSumOfSquaresVector = MultiplyAddEstimateOperator<T>.Invoke(xVec, xVec, xSumOfSquaresVector);
            ySumOfSquaresVector = MultiplyAddEstimateOperator<T>.Invoke(yVec, yVec, ySumOfSquaresVector);
        }
 
        /// <summary>Folds one pair of 512-bit vectors into the three running accumulator vectors, lane by lane.</summary>
        /// <param name="xVec">The current vector of elements from the first tensor.</param>
        /// <param name="yVec">The current vector of elements from the second tensor.</param>
        /// <param name="dotProductVector">Running accumulator fed with <paramref name="xVec"/> and <paramref name="yVec"/>.</param>
        /// <param name="xSumOfSquaresVector">Running accumulator fed with <paramref name="xVec"/> and itself.</param>
        /// <param name="ySumOfSquaresVector">Running accumulator fed with <paramref name="yVec"/> and itself.</param>
        /// <remarks>
        /// Each accumulator is replaced by the result of <c>MultiplyAddEstimateOperator</c> applied to the two factors and the
        /// accumulator's previous value (presumably <c>factor1 * factor2 + accumulator</c>; confirm against the operator's definition).
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static void Update<T>(
            Vector512<T> xVec,
            Vector512<T> yVec,
            ref Vector512<T> dotProductVector,
            ref Vector512<T> xSumOfSquaresVector,
            ref Vector512<T> ySumOfSquaresVector)
            where T : INumberBase<T>
        {
            // Cross term first, then the two self terms; each statement touches only its own accumulator.
            dotProductVector = MultiplyAddEstimateOperator<T>.Invoke(xVec, yVec, dotProductVector);
            xSumOfSquaresVector = MultiplyAddEstimateOperator<T>.Invoke(xVec, xVec, xSumOfSquaresVector);
            ySumOfSquaresVector = MultiplyAddEstimateOperator<T>.Invoke(yVec, yVec, ySumOfSquaresVector);
        }
 
        /// <summary>Turns the three scalar accumulators into the final cosine similarity value.</summary>
        /// <param name="dotProduct">The accumulated dot product of the two tensors.</param>
        /// <param name="xSumOfSquares">The accumulated sum of squares of the first tensor.</param>
        /// <param name="ySumOfSquares">The accumulated sum of squares of the second tensor.</param>
        /// <returns><paramref name="dotProduct"/> divided by the product of the square roots of the two sums of squares.</returns>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static T Finalize<T>(T dotProduct, T xSumOfSquares, T ySumOfSquares)
            where T : IRootFunctions<T>
        {
            T xMagnitude = T.Sqrt(xSumOfSquares);
            T yMagnitude = T.Sqrt(ySumOfSquares);

            return dotProduct / (xMagnitude * yMagnitude);
        }
 
        /// <summary>Reduces the three 128-bit accumulator vectors to scalars and produces the final cosine similarity value.</summary>
        /// <param name="dotProductVector">Per-lane partial dot products.</param>
        /// <param name="xSumOfSquaresVector">Per-lane partial sums of squares of the first tensor.</param>
        /// <param name="ySumOfSquaresVector">Per-lane partial sums of squares of the second tensor.</param>
        /// <returns>The summed dot product divided by the product of the square roots of the two summed sums of squares.</returns>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static T Finalize<T>(
            Vector128<T> dotProductVector,
            Vector128<T> xSumOfSquaresVector,
            Vector128<T> ySumOfSquaresVector)
            where T : IRootFunctions<T>
        {
            // Collapse each accumulator's lanes into a single scalar before combining them.
            T dotProduct = Vector128.Sum(dotProductVector);
            T xMagnitude = T.Sqrt(Vector128.Sum(xSumOfSquaresVector));
            T yMagnitude = T.Sqrt(Vector128.Sum(ySumOfSquaresVector));

            return dotProduct / (xMagnitude * yMagnitude);
        }
 
        /// <summary>Reduces the three 256-bit accumulator vectors to scalars and produces the final cosine similarity value.</summary>
        /// <param name="dotProductVector">Per-lane partial dot products.</param>
        /// <param name="xSumOfSquaresVector">Per-lane partial sums of squares of the first tensor.</param>
        /// <param name="ySumOfSquaresVector">Per-lane partial sums of squares of the second tensor.</param>
        /// <returns>The summed dot product divided by the product of the square roots of the two summed sums of squares.</returns>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static T Finalize<T>(
            Vector256<T> dotProductVector,
            Vector256<T> xSumOfSquaresVector,
            Vector256<T> ySumOfSquaresVector)
            where T : IRootFunctions<T>
        {
            // Collapse each accumulator's lanes into a single scalar before combining them.
            T dotProduct = Vector256.Sum(dotProductVector);
            T xMagnitude = T.Sqrt(Vector256.Sum(xSumOfSquaresVector));
            T yMagnitude = T.Sqrt(Vector256.Sum(ySumOfSquaresVector));

            return dotProduct / (xMagnitude * yMagnitude);
        }
 
        /// <summary>Reduces the three 512-bit accumulator vectors to scalars and produces the final cosine similarity value.</summary>
        /// <param name="dotProductVector">Per-lane partial dot products.</param>
        /// <param name="xSumOfSquaresVector">Per-lane partial sums of squares of the first tensor.</param>
        /// <param name="ySumOfSquaresVector">Per-lane partial sums of squares of the second tensor.</param>
        /// <returns>The summed dot product divided by the product of the square roots of the two summed sums of squares.</returns>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static T Finalize<T>(
            Vector512<T> dotProductVector,
            Vector512<T> xSumOfSquaresVector,
            Vector512<T> ySumOfSquaresVector)
            where T : IRootFunctions<T>
        {
            // Collapse each accumulator's lanes into a single scalar before combining them.
            T dotProduct = Vector512.Sum(dotProductVector);
            T xMagnitude = T.Sqrt(Vector512.Sum(xSumOfSquaresVector));
            T yMagnitude = T.Sqrt(Vector512.Sum(ySumOfSquaresVector));

            return dotProduct / (xMagnitude * yMagnitude);
        }
    }
}