// File: System\Numerics\Tensors\netcore\Tensor_1.cs
// Web Access
// Project: src\src\libraries\System.Numerics.Tensors\src\System.Numerics.Tensors.csproj (System.Numerics.Tensors)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
 
namespace System.Numerics.Tensors
{
    /// <summary>
    /// Represents a tensor.
    /// </summary>
    [Experimental(Experimentals.TensorTDiagId, UrlFormat = Experimentals.SharedUrlFormat)]
    public sealed class Tensor<T> : ITensor<Tensor<T>, T>
    {
        /// <summary>Gets an empty tensor.</summary>
        public static Tensor<T> Empty { get; } = new();

        // Shape metadata; the public Lengths, Strides, Rank, FlattenedLength and IsEmpty properties all read from it.
        internal readonly TensorShape _shape;
        // Backing storage for the elements. Constructors that accept a nullable array substitute an empty array for null.
        internal readonly T[] _values;

        // Offset into _values that the indexers and span conversions add to every computed element offset.
        internal readonly int _start;
        // Value reported by IsPinned: the 'pinned'/'isPinned' constructor argument, or false for the array-wrapping constructors.
        internal readonly bool _isPinned;
 
        /// <summary>Allocates a new backing array for a tensor whose shape is built from <paramref name="lengths" />.</summary>
        /// <param name="lengths">The lengths passed to <c>TensorShape.Create</c>.</param>
        /// <param name="pinned">Forwarded to <see cref="GC.AllocateArray{T}(int, bool)" /> and reported by <see cref="IsPinned" />.</param>
        internal Tensor(scoped ReadOnlySpan<nint> lengths, bool pinned)
        {
            _isPinned = pinned;
            _start = 0;

            _shape = TensorShape.Create(lengths);

            // The checked conversion throws when the shape's linear length does not fit in an int.
            int elementCount = checked((int)_shape.LinearLength);
            _values = GC.AllocateArray<T>(elementCount, pinned);
        }
 
        /// <summary>Allocates a new backing array for a tensor whose shape is built from <paramref name="lengths" /> and <paramref name="strides" />.</summary>
        /// <param name="lengths">The lengths passed to <c>TensorShape.Create</c>.</param>
        /// <param name="strides">The strides passed to <c>TensorShape.Create</c>.</param>
        /// <param name="pinned">Forwarded to <see cref="GC.AllocateArray{T}(int, bool)" /> and reported by <see cref="IsPinned" />.</param>
        internal Tensor(scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, bool pinned)
        {
            _isPinned = pinned;
            _start = 0;

            _shape = TensorShape.Create(lengths, strides);

            // The checked conversion throws when the shape's linear length does not fit in an int.
            int elementCount = checked((int)_shape.LinearLength);
            _values = GC.AllocateArray<T>(elementCount, pinned);
        }
 
        /// <summary>Wraps <paramref name="array" /> (no copy) with the shape that <c>TensorShape.Create</c> builds for it.</summary>
        /// <param name="array">The array to wrap; <see langword="null" /> is stored as an empty array.</param>
        internal Tensor(T[]? array)
        {
            _isPinned = false;
            _start = 0;

            _shape = TensorShape.Create(array);
            _values = array ?? [];
        }
 
        /// <summary>Wraps <paramref name="array" /> (no copy) with a shape built from the array and <paramref name="lengths" />.</summary>
        /// <param name="array">The array to wrap; <see langword="null" /> is stored as an empty array.</param>
        /// <param name="lengths">The lengths passed to <c>TensorShape.Create</c>.</param>
        internal Tensor(T[]? array, scoped ReadOnlySpan<nint> lengths)
        {
            _isPinned = false;
            _start = 0;

            _shape = TensorShape.Create(array, lengths);
            _values = array ?? [];
        }
 
        /// <summary>Wraps <paramref name="array" /> (no copy) with a shape built from the array, <paramref name="lengths" /> and <paramref name="strides" />.</summary>
        /// <param name="array">The array to wrap; <see langword="null" /> is stored as an empty array.</param>
        /// <param name="lengths">The lengths passed to <c>TensorShape.Create</c>.</param>
        /// <param name="strides">The strides passed to <c>TensorShape.Create</c>.</param>
        internal Tensor(T[]? array, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
        {
            _isPinned = false;
            _start = 0;

            _shape = TensorShape.Create(array, lengths, strides);
            _values = array ?? [];
        }
 
        /// <summary>Wraps <paramref name="array" /> (no copy) beginning at <paramref name="start" />, with a shape built from the supplied lengths and strides.</summary>
        /// <param name="array">The array to wrap; <see langword="null" /> is stored as an empty array.</param>
        /// <param name="start">The offset into <paramref name="array" /> recorded as the tensor's starting position.</param>
        /// <param name="lengths">The lengths passed to <c>TensorShape.Create</c>.</param>
        /// <param name="strides">The strides passed to <c>TensorShape.Create</c>.</param>
        internal Tensor(T[]? array, int start, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
        {
            _isPinned = false;
            _start = start;

            _shape = TensorShape.Create(array, start, lengths, strides);
            _values = array ?? [];
        }
 
        /// <summary>Wraps <paramref name="array" /> (no copy) beginning at <paramref name="start" />, with a shape built from the supplied lengths, strides and rank order.</summary>
        /// <param name="array">The array to wrap; <see langword="null" /> is stored as an empty array.</param>
        /// <param name="start">The offset into <paramref name="array" /> recorded as the tensor's starting position.</param>
        /// <param name="lengths">The lengths passed to <c>TensorShape.Create</c>.</param>
        /// <param name="strides">The strides passed to <c>TensorShape.Create</c>.</param>
        /// <param name="linearRankOrder">The rank order passed to <c>TensorShape.Create</c>.</param>
        internal Tensor(T[]? array, int start, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, scoped ReadOnlySpan<int> linearRankOrder)
        {
            _isPinned = false;
            _start = start;

            _shape = TensorShape.Create(array, start, lengths, strides, linearRankOrder);
            _values = array ?? [];
        }
 
        /// <summary>Wraps <paramref name="array" /> with an already-built <paramref name="shape" />, starting at offset zero.</summary>
        /// <param name="array">The backing array; checked by <c>ThrowHelper.ThrowIfArrayTypeMismatch</c>.</param>
        /// <param name="shape">The shape to store as-is.</param>
        /// <param name="isPinned">The value to report from <see cref="IsPinned" />.</param>
        internal Tensor(T[] array, in TensorShape shape, bool isPinned)
            : this(array, 0, in shape, isPinned)
        {
            // All validation and field assignment happens in the overload that also takes a start offset.
        }
 
        /// <summary>Wraps <paramref name="array" /> with an already-built <paramref name="shape" />, starting at <paramref name="start" />.</summary>
        /// <param name="array">The backing array; checked by <c>ThrowHelper.ThrowIfArrayTypeMismatch</c>.</param>
        /// <param name="start">The offset into <paramref name="array" /> recorded as the tensor's starting position.</param>
        /// <param name="shape">The shape to store as-is.</param>
        /// <param name="isPinned">The value to report from <see cref="IsPinned" />.</param>
        internal Tensor(T[] array, int start, in TensorShape shape, bool isPinned)
        {
            // Validate the array before any field is written.
            ThrowHelper.ThrowIfArrayTypeMismatch<T>(array);

            _values = array;
            _start = start;
            _shape = shape;
            _isPinned = isPinned;
        }
 
        /// <summary>Creates the tensor used for <see cref="Empty" />: default shape, empty backing array, unpinned.</summary>
        private Tensor()
        {
            _values = [];
            _shape = default;

            _isPinned = false;
            _start = 0;
        }
 
        /// <inheritdoc cref="TensorSpan{T}.this[ReadOnlySpan{nint}]" />
        public ref T this[params scoped ReadOnlySpan<nint> indexes]
        {
            get
            {
                // Element address = array base + this tensor's start offset + the shape's linear offset for the indexes.
                ref T origin = ref MemoryMarshal.GetArrayDataReference(_values);
                return ref Unsafe.Add(ref origin, _start + _shape.GetLinearOffset<TensorShape.GetOffsetAndLengthForNInt, nint>(indexes));
            }
        }
 
        /// <inheritdoc cref="TensorSpan{T}.this[ReadOnlySpan{NIndex}]" />
        public ref T this[params scoped ReadOnlySpan<NIndex> indexes]
        {
            get
            {
                // Element address = array base + this tensor's start offset + the shape's linear offset for the indexes.
                ref T origin = ref MemoryMarshal.GetArrayDataReference(_values);
                return ref Unsafe.Add(ref origin, _start + _shape.GetLinearOffset<TensorShape.GetOffsetAndLengthForNIndex, NIndex>(indexes));
            }
        }
 
        /// <inheritdoc cref="TensorSpan{T}.this[ReadOnlySpan{NRange}]" />
        public Tensor<T> this[params ReadOnlySpan<NRange> ranges]
        {
            get
            {
                return Slice(ranges);
            }

            set
            {
                // Select the target region first, then copy the assigned tensor's elements into it.
                Tensor<T> target = Slice(ranges);
                value.CopyTo(target);
            }
        }
 
        /// <inheritdoc cref="IReadOnlyTensor.FlattenedLength" />
        public nint FlattenedLength
        {
            get { return _shape.FlattenedLength; }
        }

        // Forwards the shape's own contiguity/density flag.
        internal bool IsContiguousAndDense
        {
            get { return _shape.IsContiguousAndDense; }
        }

        /// <inheritdoc cref="IReadOnlyTensor.IsEmpty" />
        public bool IsEmpty
        {
            get { return _shape.IsEmpty; }
        }

        /// <inheritdoc cref="IReadOnlyTensor.IsPinned" />
        public bool IsPinned
        {
            get { return _isPinned; }
        }

        /// <inheritdoc cref="IReadOnlyTensor.Lengths" />
        public ReadOnlySpan<nint> Lengths
        {
            get { return _shape.Lengths; }
        }

        /// <inheritdoc cref="IReadOnlyTensor.Rank" />
        public int Rank
        {
            get { return _shape.Rank; }
        }

        /// <inheritdoc cref="IReadOnlyTensor.Strides" />
        public ReadOnlySpan<nint> Strides
        {
            get { return _shape.Strides; }
        }
 
        /// <summary>Defines an implicit conversion of an array to a tensor.</summary>
        /// <param name="array">The array to convert to a tensor.</param>
        /// <returns>The tensor that corresponds to <paramref name="array" />.</returns>
        public static implicit operator Tensor<T>(T[] array)
        {
            return Tensor.Create(array);
        }

        /// <summary>Defines an implicit conversion of a tensor to a tensor span.</summary>
        /// <param name="tensor">The tensor to convert to a tensor span.</param>
        /// <returns>The tensor span that corresponds to <paramref name="tensor" />.</returns>
        public static implicit operator TensorSpan<T>(Tensor<T> tensor)
        {
            return tensor.AsTensorSpan();
        }

        /// <inheritdoc cref="TensorSpan{T}.implicit operator ReadOnlyTensorSpan{T}(in TensorSpan{T})" />
        public static implicit operator ReadOnlyTensorSpan<T>(Tensor<T> tensor)
        {
            return tensor.AsReadOnlyTensorSpan();
        }
 
        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.AsReadOnlyTensorSpan()" />
        public ReadOnlyTensorSpan<T> AsReadOnlyTensorSpan()
        {
            // The span begins at this tensor's start offset within the backing array and shares its shape.
            ref T first = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(_values), _start);
            return new ReadOnlyTensorSpan<T>(ref first, in _shape);
        }

        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.AsReadOnlyTensorSpan(ReadOnlySpan{nint})" />
        public ReadOnlyTensorSpan<T> AsReadOnlyTensorSpan(params scoped ReadOnlySpan<nint> startIndexes)
        {
            ReadOnlyTensorSpan<T> whole = AsReadOnlyTensorSpan();
            return whole.Slice(startIndexes);
        }

        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.AsReadOnlyTensorSpan(ReadOnlySpan{NIndex})" />
        public ReadOnlyTensorSpan<T> AsReadOnlyTensorSpan(params scoped ReadOnlySpan<NIndex> startIndexes)
        {
            ReadOnlyTensorSpan<T> whole = AsReadOnlyTensorSpan();
            return whole.Slice(startIndexes);
        }

        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.AsReadOnlyTensorSpan(ReadOnlySpan{NRange})" />
        public ReadOnlyTensorSpan<T> AsReadOnlyTensorSpan(params scoped ReadOnlySpan<NRange> ranges)
        {
            ReadOnlyTensorSpan<T> whole = AsReadOnlyTensorSpan();
            return whole.Slice(ranges);
        }
 
        /// <inheritdoc cref="ITensor{TSelf, T}.AsTensorSpan()" />
        public TensorSpan<T> AsTensorSpan()
        {
            // The span begins at this tensor's start offset within the backing array and shares its shape.
            ref T first = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(_values), _start);
            return new TensorSpan<T>(ref first, in _shape);
        }

        /// <inheritdoc cref="ITensor{TSelf, T}.AsTensorSpan(ReadOnlySpan{nint})" />
        public TensorSpan<T> AsTensorSpan(params scoped ReadOnlySpan<nint> startIndexes)
        {
            TensorSpan<T> whole = AsTensorSpan();
            return whole.Slice(startIndexes);
        }

        /// <inheritdoc cref="ITensor{TSelf, T}.AsTensorSpan(ReadOnlySpan{NIndex})" />
        public TensorSpan<T> AsTensorSpan(params scoped ReadOnlySpan<NIndex> startIndexes)
        {
            TensorSpan<T> whole = AsTensorSpan();
            return whole.Slice(startIndexes);
        }

        /// <inheritdoc cref="ITensor{TSelf, T}.AsTensorSpan(ReadOnlySpan{NRange})" />
        public TensorSpan<T> AsTensorSpan(params scoped ReadOnlySpan<NRange> ranges)
        {
            TensorSpan<T> whole = AsTensorSpan();
            return whole.Slice(ranges);
        }
 
        /// <inheritdoc cref="ITensor.Clear()" />
        public unsafe void Clear()
        {
            // Clearing is delegated to the span view over this tensor's elements.
            TensorSpan<T> view = AsTensorSpan();
            view.Clear();
        }
 
        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.CopyTo(in TensorSpan{T})" />
        public void CopyTo(scoped in TensorSpan<T> destination)
        {
            bool copied = TryCopyTo(destination);
            if (copied)
            {
                return;
            }

            // The non-throwing variant reported failure; surface it as the "destination too short" argument error.
            ThrowHelper.ThrowArgument_DestinationTooShort();
        }
 
        /// <inheritdoc cref="ITensor{TSelf, T}.Fill(T)" />
        public void Fill(T value)
        {
            // Filling is delegated to the span view over this tensor's elements.
            TensorSpan<T> view = AsTensorSpan();
            view.Fill(value);
        }
 
        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.FlattenTo(Span{T})" />
        public void FlattenTo(scoped Span<T> destination)
        {
            bool flattened = TryFlattenTo(destination);
            if (flattened)
            {
                return;
            }

            // The non-throwing variant reported failure; surface it as the "destination too short" argument error.
            ThrowHelper.ThrowArgument_DestinationTooShort();
        }
 
        /// <summary>Gets an enumerator for the tensor.</summary>
        /// <returns>A new <see cref="Enumerator" /> over this tensor.</returns>
        public Enumerator GetEnumerator()
        {
            return new Enumerator(this);
        }
 
        /// <inheritdoc cref="ITensor{TSelf, T}.GetPinnableReference()" />
        /// <remarks>
        /// Returns a null reference when the tensor has no elements. Otherwise the reference is offset by the
        /// tensor's start position within the backing array, so that a sliced tensor pins its own first element
        /// (the same location the indexers and <see cref="AsTensorSpan()" /> use) rather than element zero of the
        /// shared array.
        /// </remarks>
        [EditorBrowsable(EditorBrowsableState.Never)]
        public ref T GetPinnableReference()
        {
            // Ensure that the native code has just one forward branch that is predicted-not-taken.
            ref T ret = ref Unsafe.NullRef<T>();
            if (_shape.FlattenedLength != 0)
            {
                ret = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(_values), _start);
            }
            return ref ret;
        }
 
        /// <inheritdoc cref="IReadOnlyTensor.GetPinnedHandle()" />
        public unsafe MemoryHandle GetPinnedHandle()
        {
            // Pin the backing array before taking a raw pointer into it.
            GCHandle pin = GCHandle.Alloc(_values, GCHandleType.Pinned);
            void* address = Unsafe.AsPointer(ref GetPinnableReference());

            return new MemoryHandle(address, pin);
        }
 
        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.Slice(ReadOnlySpan{nint})" />
        public Tensor<T> Slice(params ReadOnlySpan<nint> startIndexes)
        {
            TensorShape slicedShape = _shape.Slice<TensorShape.GetOffsetAndLengthForNInt, nint>(startIndexes, out nint linearOffset);

            // The source tensor can have no more than int.MaxValue elements so linearOffset will always be in range of int.
            Debug.Assert((int)(linearOffset) == linearOffset);

            // The slice shares the backing array; only the start offset and the shape differ.
            int slicedStart = (int)(_start + linearOffset);
            return new Tensor<T>(_values, slicedStart, in slicedShape, _isPinned);
        }
 
        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.Slice(ReadOnlySpan{NIndex})" />
        public Tensor<T> Slice(params ReadOnlySpan<NIndex> startIndexes)
        {
            TensorShape slicedShape = _shape.Slice<TensorShape.GetOffsetAndLengthForNIndex, NIndex>(startIndexes, out nint linearOffset);

            // The source tensor can have no more than int.MaxValue elements so linearOffset will always be in range of int.
            Debug.Assert((int)(linearOffset) == linearOffset);

            // The slice shares the backing array; only the start offset and the shape differ.
            int slicedStart = (int)(_start + linearOffset);
            return new Tensor<T>(_values, slicedStart, in slicedShape, _isPinned);
        }
 
        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.Slice(ReadOnlySpan{NRange})" />
        public Tensor<T> Slice(params ReadOnlySpan<NRange> ranges)
        {
            TensorShape slicedShape = _shape.Slice<TensorShape.GetOffsetAndLengthForNRange, NRange>(ranges, out nint linearOffset);

            // The source tensor can have no more than int.MaxValue elements so linearOffset will always be in range of int.
            Debug.Assert((int)(linearOffset) == linearOffset);

            // The slice shares the backing array; only the start offset and the shape differ.
            int slicedStart = (int)(_start + linearOffset);
            return new Tensor<T>(_values, slicedStart, in slicedShape, _isPinned);
        }
 
        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.TryCopyTo(in TensorSpan{T})" />
        public bool TryCopyTo(scoped in TensorSpan<T> destination)
        {
            // Copying is delegated to the read-only span view over this tensor's elements.
            ReadOnlyTensorSpan<T> source = AsReadOnlyTensorSpan();
            return source.TryCopyTo(destination);
        }
 
        /// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.TryFlattenTo(Span{T})" />
        public bool TryFlattenTo(scoped Span<T> destination)
        {
            // Flattening is delegated to the read-only span view over this tensor's elements.
            ReadOnlyTensorSpan<T> source = AsReadOnlyTensorSpan();
            return source.TryFlattenTo(destination);
        }
 
        /// <summary>
        /// Creates a <see cref="string"/> representation of the <see cref="Tensor{T}"/>.
        /// </summary>
        /// <param name="maximumLengths">The maximum length of each dimension.</param>
        /// <returns>A <see cref="string"/> representation of the <see cref="Tensor{T}"/>.</returns>
        public string ToString(params ReadOnlySpan<nint> maximumLengths)
        {
            // Header: the type name with the element type and the shape, followed by the brace-wrapped contents.
            string header = $"System.Numerics.Tensors.Tensor<{typeof(T).Name}>[{_shape}]";
            StringBuilder builder = new StringBuilder(header);

            builder.AppendLine("{");
            Tensor.ToString(AsReadOnlyTensorSpan(), maximumLengths, builder);
            builder.AppendLine("}");

            return builder.ToString();
        }
 
        //
        // IEnumerable
        //
 
        // Non-generic enumeration reuses the public struct enumerator.
        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        //
        // IEnumerable<T>
        //

        // Generic enumeration reuses the public struct enumerator.
        IEnumerator<T> IEnumerable<T>.GetEnumerator()
        {
            return GetEnumerator();
        }
 
        //
        // IReadOnlyTensor
        //
 
        // Untyped read access: reads the element through the typed indexer and returns it as an object.
        object? IReadOnlyTensor.this[params scoped ReadOnlySpan<NIndex> indexes]
        {
            get { return this[indexes]; }
        }

        // Untyped read access: reads the element through the typed indexer and returns it as an object.
        object? IReadOnlyTensor.this[params scoped ReadOnlySpan<nint> indexes]
        {
            get { return this[indexes]; }
        }

        //
        // IReadOnlyTensor<TSelf, T>
        //

        // Read-only view of the same element reference the public indexer returns.
        ref readonly T IReadOnlyTensor<Tensor<T>, T>.this[params ReadOnlySpan<nint> indexes]
        {
            get { return ref this[indexes]; }
        }

        // Read-only view of the same element reference the public indexer returns.
        ref readonly T IReadOnlyTensor<Tensor<T>, T>.this[params ReadOnlySpan<NIndex> indexes]
        {
            get { return ref this[indexes]; }
        }

        // Read-only view of the reference returned by the public GetPinnableReference.
        [EditorBrowsable(EditorBrowsableState.Never)]
        ref readonly T IReadOnlyTensor<Tensor<T>, T>.GetPinnableReference()
        {
            return ref GetPinnableReference();
        }
 
        //
        // ITensor
        //
 
        // Tensor<T> always reports itself as writable.
        bool ITensor.IsReadOnly
        {
            get { return false; }
        }

        // Untyped element access by NIndex; the setter casts the incoming object to T before storing it.
        object? ITensor.this[params scoped ReadOnlySpan<NIndex> indexes]
        {
            get
            {
                return this[indexes];
            }

            set
            {
                ref T slot = ref this[indexes];
                slot = (T)value!;
            }
        }

        // Untyped element access by nint; the setter casts the incoming object to T before storing it.
        object? ITensor.this[params scoped ReadOnlySpan<nint> indexes]
        {
            get
            {
                return this[indexes];
            }

            set
            {
                ref T slot = ref this[indexes];
                slot = (T)value!;
            }
        }

        // Untyped fill: rejects any value that is not a T with an ArgumentException, otherwise forwards to Fill(T).
        void ITensor.Fill(object value)
        {
            if (value is not T typed)
            {
                throw new ArgumentException($"Cannot convert {value} to {typeof(T)}");
            }

            Fill(typed);
        }
 
        //
        // ITensor<TSelf, T>
        //
 
        // Static factory required by ITensor<TSelf, T>; forwards to Tensor.Create.
        static Tensor<T> ITensor<Tensor<T>, T>.Create(scoped ReadOnlySpan<nint> lengths, bool pinned)
        {
            return Tensor.Create<T>(lengths, pinned);
        }

        // Static factory required by ITensor<TSelf, T>; forwards to Tensor.Create.
        static Tensor<T> ITensor<Tensor<T>, T>.Create(scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, bool pinned)
        {
            return Tensor.Create<T>(lengths, strides, pinned);
        }

        // NOTE(review): forwards to the same Tensor.Create overload as Create above rather than to a dedicated
        // uninitialized factory - confirm this is intentional.
        static Tensor<T> ITensor<Tensor<T>, T>.CreateUninitialized(scoped ReadOnlySpan<nint> lengths, bool pinned)
        {
            return Tensor.Create<T>(lengths, pinned);
        }

        // NOTE(review): forwards to the same Tensor.Create overload as Create above rather than to a dedicated
        // uninitialized factory - confirm this is intentional.
        static Tensor<T> ITensor<Tensor<T>, T>.CreateUninitialized(scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, bool pinned)
        {
            return Tensor.Create<T>(lengths, strides, pinned);
        }
 
        /// <summary>Enumerates the elements of a tensor.</summary>
        /// <remarks>
        /// The enumerator tracks one index per dimension plus the linear offset of the current element within the
        /// tensor's backing array. It starts positioned before the first element, so <see cref="Current" /> is only
        /// meaningful after <see cref="MoveNext" /> has returned <see langword="true" />.
        /// </remarks>
        public struct Enumerator : IEnumerator<T>
        {
            private readonly Tensor<T> _tensor;
            private readonly nint[] _indexes;
            private nint _linearOffset;
            private nint _itemsEnumerated;

            internal Enumerator(Tensor<T> tensor)
            {
                _tensor = tensor;
                _indexes = new nint[tensor.Rank];
                _linearOffset = 0;
                _itemsEnumerated = 0;

                MoveBeforeFirst();
            }

            /// <inheritdoc cref="IEnumerator{T}.Current" />
            public readonly ref T Current => ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(_tensor._values), _linearOffset);

            /// <inheritdoc cref="IEnumerator.MoveNext()" />
            public bool MoveNext()
            {
                // Stop once every element of the flattened tensor has been produced.
                if (_itemsEnumerated == _tensor._shape.FlattenedLength)
                {
                    return false;
                }

                _linearOffset = _tensor._shape.AdjustToNextIndex(_tensor._shape, _linearOffset, _indexes);

                _itemsEnumerated++;
                return true;
            }

            /// <inheritdoc cref="IEnumerator.Reset()" />
            public void Reset() => MoveBeforeFirst();

            // Positions the enumerator one step before the first element: all indexes zero except the last,
            // which is -1, and the linear offset one trailing stride before the tensor's start.
            private void MoveBeforeFirst()
            {
                Array.Clear(_indexes);

                // A tensor whose rank is zero has no index slots; writing _indexes[^1] unconditionally would throw
                // IndexOutOfRangeException, so the "before first" marker is only applied when a slot exists.
                bool hasDimensions = _indexes.Length != 0;

                if (hasDimensions)
                {
                    _indexes[^1] = -1;
                }

                nint lastStride = (hasDimensions && !_tensor.IsEmpty) ? _tensor.Strides[^1] : 0;

                _linearOffset = _tensor._start - lastStride;
                _itemsEnumerated = 0;
            }

            //
            // IDisposable
            //

            readonly void IDisposable.Dispose() { }

            //
            // IEnumerator
            //

            readonly object? IEnumerator.Current => Current;

            //
            // IEnumerator<T>
            //

            readonly T IEnumerator<T>.Current => Current;
        }
    }
}