File: System\Numerics\Tensors\netcore\TensorHelpers.cs
Web Access
Project: src\src\libraries\System.Numerics.Tensors\src\System.Numerics.Tensors.csproj (System.Numerics.Tensors)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
 
namespace System.Numerics.Tensors
{
    [Experimental(Experimentals.TensorTDiagId, UrlFormat = Experimentals.SharedUrlFormat)]
    internal static class TensorHelpers
    {
        /// <summary>
        /// Counts the number of <see langword="true"/> elements in a boolean filter tensor so callers know how much space they will need.
        /// </summary>
        /// <param name="filter">The boolean filter whose backing memory is scanned.</param>
        /// <returns>How many boolean values are <see langword="true"/>.</returns>
        /// <exception cref="OverflowException">The filter's backing memory length does not fit in an <see cref="int"/>.</exception>
        /// <remarks>
        /// This scans <c>_memoryLength</c> contiguous elements starting at the filter's reference.
        /// NOTE(review): for a non-dense filter (e.g. strided or broadcast) this counts backing memory rather than
        /// logical elements — confirm callers only pass dense filters.
        /// </remarks>
        public static nint CountTrueElements(scoped in ReadOnlyTensorSpan<bool> filter)
        {
            // checked: a silently truncating cast would scan only a prefix of the memory and return a wrong count.
            ReadOnlySpan<bool> filterSpan = MemoryMarshal.CreateReadOnlySpan(ref filter._reference, checked((int)filter._shape._memoryLength));
            nint count = 0;
            foreach (bool value in filterSpan)
            {
                if (value)
                {
                    count++;
                }
            }

            return count;
        }

        /// <summary>
        /// Determines whether the shapes of two tensors are broadcast-compatible.
        /// </summary>
        /// <param name="tensor1">The first tensor.</param>
        /// <param name="tensor2">The second tensor.</param>
        /// <returns><see langword="true"/> if the lengths of the two tensors are broadcast-compatible.</returns>
        internal static bool IsBroadcastableTo<T>(Tensor<T> tensor1, Tensor<T> tensor2)
            where T : IEquatable<T>, IEqualityOperators<T, T, bool> => IsBroadcastableTo(tensor1.Lengths, tensor2.Lengths);

        /// <summary>
        /// Determines whether two shapes are broadcast-compatible.
        /// </summary>
        /// <param name="lengths1">The lengths of the first shape.</param>
        /// <param name="lengths2">The lengths of the second shape.</param>
        /// <returns>
        /// <see langword="true"/> if, comparing from the trailing dimension, every pair of lengths is equal or one of them is 1
        /// and the other is greater than 1. A dimension missing from the shorter shape is treated as 1.
        /// Returns <see langword="false"/> if either shape has rank 0. A length of 1 paired with a length of 0 is
        /// reported as incompatible.
        /// </returns>
        internal static bool IsBroadcastableTo(ReadOnlySpan<nint> lengths1, ReadOnlySpan<nint> lengths2)
        {
            if (lengths1.Length == 0 || lengths2.Length == 0)
            {
                return false;
            }

            // Walk both shapes from the trailing dimension; once an index goes negative it stays negative,
            // and that missing dimension is considered to be 1.
            for (int i1 = lengths1.Length - 1, i2 = lengths2.Length - 1; i1 >= 0 || i2 >= 0; i1--, i2--)
            {
                nint s1 = i1 >= 0 ? lengths1[i1] : 1;
                nint s2 = i2 >= 0 ? lengths2[i2] : 1;

                bool compatible = s1 == s2 || (s1 == 1 && s2 > 1) || (s2 == 1 && s1 > 1);
                if (!compatible)
                {
                    return false;
                }
            }

            return true;
        }

        /// <summary>
        /// Left-pads <paramref name="shape1"/> with dimensions of length 1 so that its rank becomes
        /// <c>Math.Max(shape1.Length, shape2Length)</c>.
        /// </summary>
        /// <param name="shape1">The shape to pad.</param>
        /// <param name="shape2Length">The rank of the other shape.</param>
        /// <returns>A new array holding the padded shape; if <paramref name="shape1"/> is already at least as long, it is a copy of it.</returns>
        internal static nint[] GetIntermediateShape(ReadOnlySpan<nint> shape1, int shape2Length)
        {
            int rank = Math.Max(shape1.Length, shape2Length);
            nint[] newShape = new nint[rank];
            int padding = rank - shape1.Length;

            // Missing leading dimensions are considered to be 1; the original dimensions fill the tail.
            newShape.AsSpan(0, padding).Fill(1);
            shape1.CopyTo(newShape.AsSpan(padding));

            return newShape;
        }

        /// <summary>
        /// Determines whether two tensor spans have the same backing memory length.
        /// </summary>
        internal static bool IsUnderlyingStorageSameSize<T>(scoped in ReadOnlyTensorSpan<T> tensor1, scoped in ReadOnlyTensorSpan<T> tensor2)
            => tensor1._shape._memoryLength == tensor2._shape._memoryLength;

        /// <summary>
        /// Determines whether two tensors have backing arrays of the same length.
        /// </summary>
        internal static bool IsUnderlyingStorageSameSize<T>(Tensor<T> tensor1, Tensor<T> tensor2)
            => tensor1._values.Length == tensor2._values.Length;

        /// <summary>
        /// Determines whether two tensor spans have identical lengths in every dimension.
        /// </summary>
        internal static bool AreLengthsTheSame<T>(scoped in ReadOnlyTensorSpan<T> tensor1, scoped in ReadOnlyTensorSpan<T> tensor2)
            => AreLengthsTheSame(tensor1.Lengths, tensor2.Lengths);

        /// <summary>
        /// Determines whether two shapes have the same rank and identical lengths in every dimension.
        /// </summary>
        internal static bool AreLengthsTheSame(ReadOnlySpan<nint> lengths1, ReadOnlySpan<nint> lengths2)
            => lengths1.SequenceEqual(lengths2);

        /// <summary>
        /// Determines whether a tensor span is laid out row-major, contiguous, and without gaps.
        /// </summary>
        /// <param name="tensor">The tensor span to inspect.</param>
        /// <returns>
        /// <see langword="true"/> if the right-most stride is 1 and every other stride equals the product of the
        /// lengths to its right. A shape with no strides (rank 0) is treated as trivially dense.
        /// </returns>
        internal static bool IsContiguousAndDense<T>(scoped in ReadOnlyTensorSpan<T> tensor)
        {
            ReadOnlySpan<nint> strides = tensor._shape.Strides;
            ReadOnlySpan<nint> lengths = tensor.Lengths;

            // Indexing strides[^1] on an empty span would throw; there is nothing to check for rank 0.
            if (strides.Length == 0)
            {
                return true;
            }

            // Running product of the lengths to the right of dimension i. It starts at 1 because the right-most
            // stride must be 1. Maintaining it incrementally avoids re-multiplying the trailing lengths per dimension.
            nint expectedStride = 1;
            for (int i = strides.Length - 1; i >= 0; i--)
            {
                if (strides[i] != expectedStride)
                {
                    return false;
                }

                expectedStride *= lengths[i];
            }

            return true;
        }

        /// <summary>
        /// Writes <c>indices[permutation[i]]</c> into <c>permutedIndices[i]</c> for every position of <paramref name="indices"/>.
        /// </summary>
        /// <param name="indices">The source indices.</param>
        /// <param name="permutedIndices">The destination; must be at least as long as <paramref name="indices"/>.</param>
        /// <param name="permutation">For each destination position, the source position to read; must be at least as long as <paramref name="indices"/>.</param>
        internal static void PermuteIndices(Span<nint> indices, Span<nint> permutedIndices, ReadOnlySpan<int> permutation)
        {
            for (int i = 0; i < indices.Length; i++)
            {
                permutedIndices[i] = indices[permutation[i]];
            }
        }
    }
}