// File: System\Formats\Nrbf\RectangularArrayRecord.cs
// Web Access
// Project: src\src\libraries\System.Formats.Nrbf\src\System.Formats.Nrbf.csproj (System.Formats.Nrbf)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
 
namespace System.Formats.Nrbf;
 
/// <summary>
/// An array record with explicit per-dimension lengths (reported as
/// <see cref="SerializationRecordType.BinaryArray"/>). Elements are buffered in a flat
/// list as they are added and are copied into a real <see cref="Array"/> on deserialization.
/// </summary>
internal sealed class RectangularArrayRecord : ArrayRecord
{
    // Length of every dimension, in declaration order.
    private readonly int[] _lengths;
    // Elements in the order they were added via AddValue; each entry is either a raw value
    // or a SerializationRecord (see GetActualValue).
    private readonly List<object> _values;
    // Lazily computed by the TypeName property.
    private TypeName? _typeName;

    /// <summary>
    /// Initializes the record and sizes the backing list.
    /// </summary>
    /// <param name="elementType">The CLR type used for the elements of the materialized array.</param>
    /// <param name="arrayInfo">The array information passed to the base <see cref="ArrayRecord"/>.</param>
    /// <param name="memberTypeInfo">The element type information for this array.</param>
    /// <param name="lengths">The length of every dimension.</param>
    /// <param name="canPreAllocate">
    /// <see langword="true"/> to reserve capacity for all elements up front; otherwise the
    /// initial capacity is capped at 4 and the list grows as elements are added.
    /// </param>
    private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
        MemberTypeInfo memberTypeInfo, int[] lengths, bool canPreAllocate) : base(arrayInfo)
    {
        ElementType = elementType;
        MemberTypeInfo = memberTypeInfo;
        _lengths = lengths;

        // ArrayInfo.GetSZArrayLength is guaranteed to return a value <= Array.MaxLength.
        _values = new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()));
    }

    /// <inheritdoc/>
    public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;

    /// <inheritdoc/>
    public override ReadOnlySpan<int> Lengths => _lengths.AsSpan();

    /// <inheritdoc/>
    public override TypeName TypeName
        => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);

    private Type ElementType { get; }

    private MemberTypeInfo MemberTypeInfo { get; }

    /// <summary>
    /// Creates an array with the dimensions given by <see cref="Lengths"/> and fills it
    /// with the buffered values.
    /// </summary>
    /// <param name="arrayType">
    /// The requested array type; only consulted on .NET 9+ when the element type is not
    /// <see cref="ClassRecord"/>.
    /// </param>
    /// <param name="allowNulls">Not consulted by this implementation.</param>
    /// <returns>The populated array.</returns>
    /// <exception cref="InvalidOperationException">
    /// On .NET 8+, the element type is a value type that has no copy routine below.
    /// </exception>
    [RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")]
    private protected override Array Deserialize(Type arrayType, bool allowNulls)
    {
        // We can not deserialize non-primitive types.
        // This method returns arrays of ClassRecord for arrays of complex types.
        Array result =
#if NET9_0_OR_GREATER
            ElementType == typeof(ClassRecord)
                ? Array.CreateInstance(ElementType, _lengths)
                : Array.CreateInstanceFromArrayType(arrayType, _lengths);
#else
            Array.CreateInstance(ElementType, _lengths);
#endif

#if !NET8_0_OR_GREATER
        // Walk the destination with an explicit index vector: the last dimension advances
        // fastest and carries into the previous dimension when it reaches its length.
        int[] indices = new int[_lengths.Length];
        nuint numElementsWritten = 0; // only for debugging; not used in release builds

        foreach (object value in _values)
        {
            result.SetValue(GetActualValue(value), indices);
            numElementsWritten++;

            int dimension = indices.Length - 1;
            while (dimension >= 0)
            {
                indices[dimension]++;
                if (indices[dimension] < Lengths[dimension])
                {
                    break;
                }
                indices[dimension] = 0;
                dimension--;
            }

            // Every dimension wrapped around: the destination array is full.
            if (dimension < 0)
            {
                break;
            }
        }

        Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection.");
        Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array.");

        return result;
#else
        // Idea from Array.CoreCLR that maps an array of int indices into
        // an internal flat index.
        if (ElementType.IsValueType)
        {
            // Every value type that MapPrimitive can produce needs a branch here;
            // anything else is rejected below.
            if (ElementType == typeof(bool)) CopyTo<bool>(_values, result);
            else if (ElementType == typeof(byte)) CopyTo<byte>(_values, result);
            else if (ElementType == typeof(sbyte)) CopyTo<sbyte>(_values, result);
            else if (ElementType == typeof(short)) CopyTo<short>(_values, result);
            else if (ElementType == typeof(ushort)) CopyTo<ushort>(_values, result);
            else if (ElementType == typeof(char)) CopyTo<char>(_values, result);
            else if (ElementType == typeof(int)) CopyTo<int>(_values, result);
            // MapPrimitive maps PrimitiveType.UInt32 to typeof(uint), so it must be handled too.
            else if (ElementType == typeof(uint)) CopyTo<uint>(_values, result);
            else if (ElementType == typeof(float)) CopyTo<float>(_values, result);
            else if (ElementType == typeof(long)) CopyTo<long>(_values, result);
            else if (ElementType == typeof(ulong)) CopyTo<ulong>(_values, result);
            else if (ElementType == typeof(double)) CopyTo<double>(_values, result);
            else if (ElementType == typeof(TimeSpan)) CopyTo<TimeSpan>(_values, result);
            else if (ElementType == typeof(DateTime)) CopyTo<DateTime>(_values, result);
            else if (ElementType == typeof(decimal)) CopyTo<decimal>(_values, result);
            else throw new InvalidOperationException();
        }
        else
        {
            CopyTo<object>(_values, result);
        }

        return result;

        // Writes the values, in list order, to consecutive element slots starting at the
        // first element of the array's data.
        // NOTE(review): there is no release-mode bounds check here; this presumably relies on
        // the caller never adding more values than the array can hold - confirm.
        static void CopyTo<T>(List<object> list, Array array)
        {
            ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array);
            ref T firstElementRef = ref Unsafe.As<byte, T>(ref arrayDataRef);
            nuint flattenedIndex = 0;
            foreach (object value in list)
            {
                ref T targetElement = ref Unsafe.Add(ref firstElementRef, flattenedIndex);
                targetElement = (T)GetActualValue(value)!;
                flattenedIndex++;
            }

            Debug.Assert(flattenedIndex == (ulong)array.LongLength, "We should have traversed the entirety of the array.");
        }
#endif
    }

    /// <summary>Appends an element to the buffered values.</summary>
    private protected override void AddValue(object value) => _values.Add(value);

    /// <summary>
    /// Returns the record types allowed for the element type described by the first member
    /// info, additionally allowing null records whenever anything is allowed at all.
    /// </summary>
    internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
    {
        (AllowedRecordTypes allowed, PrimitiveType primitiveType) = MemberTypeInfo.GetNextAllowedRecordType(0);

        if (allowed != AllowedRecordTypes.None)
        {
            // It's an array, it can also contain multiple nulls
            return (allowed | AllowedRecordTypes.Nulls, primitiveType);
        }

        return (allowed, primitiveType);
    }

    /// <summary>
    /// Creates a record for a rectangular array, choosing the element type from the first
    /// member info and deciding whether the backing list may be pre-allocated.
    /// </summary>
    /// <param name="reader">The reader positioned before the array payload; only queried for data availability.</param>
    /// <param name="arrayInfo">The array information.</param>
    /// <param name="memberTypeInfo">The element type information.</param>
    /// <param name="lengths">The length of every dimension.</param>
    /// <returns>A new, empty <see cref="RectangularArrayRecord"/>.</returns>
    /// <exception cref="InvalidOperationException">The primitive type is not one of the mapped values.</exception>
    /// <remarks>
    /// For fixed-size primitive elements, when the reader reports that fewer bytes are
    /// available than the minimal payload size, <c>ThrowHelper.ThrowEndOfStreamException</c>
    /// is invoked.
    /// </remarks>
    internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arrayInfo,
        MemberTypeInfo memberTypeInfo, int[] lengths)
    {
        BinaryType binaryType = memberTypeInfo.Infos[0].BinaryType;
        Type elementType = binaryType switch
        {
            BinaryType.Primitive => MapPrimitive((PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!),
            BinaryType.PrimitiveArray => MapPrimitiveArray((PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!),
            BinaryType.String => typeof(string),
            BinaryType.Object => typeof(object),
            _ => typeof(ClassRecord)
        };

        // Pre-allocation is only enabled once we know the payload can actually be present.
        bool canPreAllocate = false;
        if (binaryType == BinaryType.Primitive)
        {
            int sizeOfSingleValue = (PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo! switch
            {
                PrimitiveType.Boolean => sizeof(bool),
                PrimitiveType.Byte => sizeof(byte),
                PrimitiveType.SByte => sizeof(sbyte),
                PrimitiveType.Char => sizeof(byte), // it's UTF8 (see comment below)
                PrimitiveType.Int16 => sizeof(short),
                PrimitiveType.UInt16 => sizeof(ushort),
                PrimitiveType.Int32 => sizeof(int),
                PrimitiveType.UInt32 => sizeof(uint),
                PrimitiveType.Single => sizeof(float),
                PrimitiveType.Int64 => sizeof(long),
                PrimitiveType.UInt64 => sizeof(ulong),
                PrimitiveType.Double => sizeof(double),
                PrimitiveType.TimeSpan => sizeof(ulong),
                PrimitiveType.DateTime => sizeof(ulong),
                PrimitiveType.Decimal => -1, // represented as variable-length string
                _ => throw new InvalidOperationException()
            };

            if (sizeOfSingleValue > 0)
            {
                // NRBF encodes rectangular char[,,,...] by converting each standalone UTF-16 code point into
                // its UTF-8 encoding. This means that surrogate code points (including adjacent surrogate
                // pairs) occurring within a char[,,,...] cannot be encoded by NRBF. BinaryReader will detect
                // that they're ill-formed and reject them on read.
                //
                // Per the comment in ArraySinglePrimitiveRecord.DecodePrimitiveTypes, we'll assume best-case
                // encoding where 1 UTF-16 char encodes as a single UTF-8 byte, even though this might lead
                // to encountering an EOF if we realize later that we actually need to read more bytes in
                // order to fully populate the char[,,,...] array. Any such allocation is still linearly
                // proportional to the length of the incoming payload, so it's not a DoS vector.
                // The multiplication below is guaranteed not to overflow because FlattenedLength is bounded
                // to <= Array.MaxLength (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8.
                Debug.Assert(arrayInfo.FlattenedLength >= 0 && arrayInfo.FlattenedLength <= long.MaxValue / sizeOfSingleValue);

                long size = arrayInfo.FlattenedLength * sizeOfSingleValue;
                // A null result leaves canPreAllocate false; presumably it means availability
                // could not be determined for this stream - confirm against IsDataAvailable.
                bool? isDataAvailable = reader.IsDataAvailable(size);
                if (isDataAvailable.HasValue)
                {
                    if (!isDataAvailable.Value)
                    {
                        ThrowHelper.ThrowEndOfStreamException();
                    }

                    canPreAllocate = true;
                }
            }
        }

        return new RectangularArrayRecord(elementType, arrayInfo, memberTypeInfo, lengths, canPreAllocate);
    }

    /// <summary>Maps a <see cref="PrimitiveType"/> to the corresponding CLR element type.</summary>
    /// <exception cref="InvalidOperationException">The value has no mapping.</exception>
    private static Type MapPrimitive(PrimitiveType primitiveType)
        => primitiveType switch
        {
            PrimitiveType.Boolean => typeof(bool),
            PrimitiveType.Byte => typeof(byte),
            PrimitiveType.Char => typeof(char),
            PrimitiveType.Decimal => typeof(decimal),
            PrimitiveType.Double => typeof(double),
            PrimitiveType.Int16 => typeof(short),
            PrimitiveType.Int32 => typeof(int),
            PrimitiveType.Int64 => typeof(long),
            PrimitiveType.SByte => typeof(sbyte),
            PrimitiveType.Single => typeof(float),
            PrimitiveType.TimeSpan => typeof(TimeSpan),
            PrimitiveType.DateTime => typeof(DateTime),
            PrimitiveType.UInt16 => typeof(ushort),
            PrimitiveType.UInt32 => typeof(uint),
            PrimitiveType.UInt64 => typeof(ulong),
            _ => throw new InvalidOperationException()
        };

    /// <summary>
    /// Maps a <see cref="PrimitiveType"/> to the single-dimensional array type of the
    /// corresponding CLR type (used when each element is itself a primitive array).
    /// </summary>
    /// <exception cref="InvalidOperationException">The value has no mapping.</exception>
    private static Type MapPrimitiveArray(PrimitiveType primitiveType)
        => primitiveType switch
        {
            PrimitiveType.Boolean => typeof(bool[]),
            PrimitiveType.Byte => typeof(byte[]),
            PrimitiveType.Char => typeof(char[]),
            PrimitiveType.Decimal => typeof(decimal[]),
            PrimitiveType.Double => typeof(double[]),
            PrimitiveType.Int16 => typeof(short[]),
            PrimitiveType.Int32 => typeof(int[]),
            PrimitiveType.Int64 => typeof(long[]),
            PrimitiveType.SByte => typeof(sbyte[]),
            PrimitiveType.Single => typeof(float[]),
            PrimitiveType.TimeSpan => typeof(TimeSpan[]),
            PrimitiveType.DateTime => typeof(DateTime[]),
            PrimitiveType.UInt16 => typeof(ushort[]),
            PrimitiveType.UInt32 => typeof(uint[]),
            PrimitiveType.UInt64 => typeof(ulong[]),
            _ => throw new InvalidOperationException()
        };

    /// <summary>
    /// Unwraps a buffered element: a <see cref="SerializationRecord"/> yields the result of
    /// its <c>GetValue()</c>; any other object is returned unchanged.
    /// </summary>
    private static object? GetActualValue(object value)
        => value is SerializationRecord serializationRecord
            ? serializationRecord.GetValue()
            : value; // it must be a primitive type
}