File: System\Numerics\Tensors\netcore\TensorDimensionSpan_1.cs
Web Access
Project: src\src\libraries\System.Numerics.Tensors\src\System.Numerics.Tensors.csproj (System.Numerics.Tensors)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
 
namespace System.Numerics.Tensors
{
    /// <summary>Represents the slices that exist within a dimension of a tensor span.</summary>
    /// <typeparam name="T">The type of the elements within the tensor span.</typeparam>
    /// <remarks>
    /// For a tracked dimension <c>d</c>, dimensions <c>0</c> through <c>d</c> (inclusive) are the leading dimensions
    /// that are iterated over. Each slice is a <see cref="TensorSpan{T}" /> over the remaining trailing dimensions.
    /// For example, tracking dimension <c>1</c> of a tensor with lengths <c>[2, 3, 4]</c> yields <c>2 * 3 = 6</c>
    /// slices, each with lengths <c>[4]</c>. Tracking the last dimension yields one single-element slice per element.
    /// </remarks>
    public readonly ref struct TensorDimensionSpan<T>
    {
        // The tensor span whose slices are exposed.
        private readonly TensorSpan<T> _tensor;

        // The number of slices: the product of the lengths of the leading dimensions.
        private readonly nint _length;

        // The number of leading dimensions (tracked dimension + 1). This is NOT the tracked dimension index itself;
        // it is also the index at which the trailing, per-slice dimensions begin.
        private readonly int _dimension;

        // The shape shared by every slice: the lengths/strides of the trailing dimensions.
        private readonly TensorShape _sliceShape;

        /// <summary>Creates a view over the slices of <paramref name="tensor" /> along <paramref name="dimension" />.</summary>
        /// <param name="tensor">The tensor span whose slices are exposed.</param>
        /// <param name="dimension">The zero-based dimension to track; must be less than <see cref="TensorSpan{T}.Rank" />.</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="dimension" /> is negative or not less than the rank of <paramref name="tensor" />.
        /// </exception>
        internal TensorDimensionSpan(TensorSpan<T> tensor, int dimension)
        {
            // The unsigned cast turns a negative dimension into a huge value, so one comparison checks both bounds.
            if ((uint)dimension >= tensor.Rank)
            {
                ThrowHelper.ThrowArgumentOutOfRangeException();
            }

            // Dimensions 0..dimension (inclusive) are enumerated; the trailing dimensions form each slice.
            int leadingRank = dimension + 1;

            _tensor = tensor;
            _length = TensorPrimitives.Product(tensor.Lengths[..leadingRank]);
            _dimension = leadingRank;

            // When the tracked dimension is the last one there are no trailing dimensions left, so each slice is
            // a single element and is given the lengths [1]; the strides slice passed alongside is empty in that case.
            _sliceShape = TensorShape.Create((leadingRank != tensor.Rank) ? tensor.Lengths[leadingRank..] : [1], tensor.Strides[leadingRank..]);
        }

        /// <summary>Gets the number of slices in the tracked dimension.</summary>
        /// <remarks>This is the product of the lengths of dimensions <c>0</c> through the tracked dimension (inclusive).</remarks>
        public nint Length => _length;

        /// <summary>Gets the tensor span representing a slice of the tracked dimension using the specified index.</summary>
        /// <param name="index">The zero-based index of the slice, in the range <c>[0, Length)</c>.</param>
        /// <returns>
        /// The tensor span representing the slice at <paramref name="index" />. It refers to the same memory as the
        /// source tensor span; it is not a copy.
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="index" /> is negative or not less than <see cref="Length" />.
        /// </exception>
        public TensorSpan<T> this[nint index]
        {
            get
            {
                // The unsigned cast rejects negative indices in the same comparison as the upper bound.
                if ((nuint)index >= (nuint)_length)
                {
                    ThrowHelper.ThrowArgumentOutOfRangeException();
                }

                // The parent shape maps the flat slice index over the first _dimension dimensions to an element offset.
                // NOTE(review): presumably row-major over those leading dimensions — confirm against TensorShape.GetLinearOffset.
                nint linearOffset = _tensor._shape.GetLinearOffset(index, _dimension);
                return new TensorSpan<T>(ref Unsafe.Add(ref _tensor._reference, linearOffset), _sliceShape);
            }
        }

        /// <summary>Gets an enumerator over the slices of the tensor dimension span.</summary>
        /// <returns>An enumerator positioned before the first slice.</returns>
        public Enumerator GetEnumerator() => new Enumerator(this);

        /// <summary>Enumerates the slices of a tensor dimension span.</summary>
        public ref struct Enumerator
#if NET9_0_OR_GREATER
            : IEnumerator<TensorSpan<T>>
#endif
        {
            // The dimension span being enumerated.
            private readonly TensorDimensionSpan<T> _span;

            // The index of the current slice; -1 means "before the first slice".
            private nint _index;

            internal Enumerator(TensorDimensionSpan<T> span)
            {
                _span = span;
                _index = -1;
            }

            /// <summary>Gets the slice at the current position of the enumerator.</summary>
            /// <exception cref="ArgumentOutOfRangeException">
            /// <see cref="MoveNext" /> has not yet been called (the enumerator is before the first slice).
            /// </exception>
            public readonly TensorSpan<T> Current => _span[_index];

            /// <summary>Advances the enumerator to the next slice of the tensor dimension span.</summary>
            /// <returns>
            /// <see langword="true" /> if the enumerator advanced to the next slice; <see langword="false" /> if it has
            /// passed the last slice.
            /// </returns>
            public bool MoveNext()
            {
                nint index = _index + 1;

                // Only commit the new position when it is in range, so repeated calls past the end keep returning false.
                if (index < _span.Length)
                {
                    _index = index;
                    return true;
                }
                return false;
            }

            /// <summary>Sets the enumerator to its initial position, which is before the first slice.</summary>
            public void Reset()
            {
                _index = -1;
            }

#if NET9_0_OR_GREATER
            //
            // IDisposable
            //

            // Nothing to release; present only to satisfy IEnumerator<T>.
            void IDisposable.Dispose() { }

            //
            // IEnumerator
            //

            // The non-generic Current would have to box a TensorSpan<T> to object, which it presumably cannot do
            // (TensorSpan<T> holds a managed reference and is expected to be a ref struct), so it is unsupported.
            readonly object? IEnumerator.Current => throw new NotSupportedException();
#endif
        }
    }
}