File: System\Formats\Nrbf\BinaryArrayRecord.cs
Web Access
Project: src\src\libraries\System.Formats.Nrbf\src\System.Formats.Nrbf.csproj (System.Formats.Nrbf)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
 
namespace System.Formats.Nrbf;
 
/// <summary>
/// Represents an array other than single dimensional array of primitive types or <see cref="object"/>.
/// </summary>
/// <remarks>
/// BinaryArray records are described in <see href="https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/9c62c928-db4e-43ca-aeba-146256ef67c2">[MS-NRBF] 2.4.3.1</see>.
/// </remarks>
internal sealed class BinaryArrayRecord : ArrayRecord
{
    // Element types that MapElementType treats as directly instantiable (also when wrapped in
    // Nullable<T>). Any other element type is surfaced as ClassRecord instead.
    private static HashSet<Type> PrimitiveTypes { get; } =
    [
        typeof(bool), typeof(char), typeof(byte), typeof(sbyte),
        typeof(short), typeof(ushort), typeof(int), typeof(uint),
        typeof(long), typeof(ulong), typeof(IntPtr), typeof(UIntPtr),
        typeof(float), typeof(double), typeof(decimal), typeof(DateTime),
        typeof(TimeSpan), typeof(string), typeof(object)
    ];

    private TypeName? _typeName;

    // Cached result of FlattenedLength; -1 means "not computed yet".
    private long _totalElementsCount;

    private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
        : base(arrayInfo)
    {
        MemberTypeInfo = memberTypeInfo;
        Values = [];
        // We need to parse all elements of the jagged array to obtain total elements count.
        _totalElementsCount = -1;
    }

    public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;

    /// <inheritdoc/>
    /// <remarks>
    /// Always a single length. Decode rejects rank != 1 for non-rectangular arrays and hands
    /// rectangular arrays to RectangularArrayRecord, so instances of this type have rank 1.
    /// </remarks>
    public override ReadOnlySpan<int> Lengths => new int[1] { Length };

    /// <inheritdoc/>
    /// <remarks>
    /// For jagged arrays this walks all nested array records on first access and caches the
    /// result. For other arrays it returns the value pre-computed in <see cref="ArrayInfo"/>.
    /// </remarks>
    public override long FlattenedLength
    {
        get
        {
            if (_totalElementsCount < 0)
            {
                _totalElementsCount = IsJagged
                    ? GetJaggedArrayFlattenedLength(this)
                    : ArrayInfo.FlattenedLength;
            }

            return _totalElementsCount;
        }
    }

    public override TypeName TypeName
        => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);

    private int Length => ArrayInfo.GetSZArrayLength();

    private MemberTypeInfo MemberTypeInfo { get; }

    // Raw decoded members, in order. Entries are either primitive values or SerializationRecords
    // (including MemberReferenceRecords and NullsRecords that stand for several null slots).
    private List<object> Values { get; }

    /// <summary>
    /// Materializes the decoded members into a single-dimensional array.
    /// </summary>
    /// <param name="arrayType">The requested array type. Complex element types are mapped to ClassRecord.</param>
    /// <param name="allowNulls">When false, encountering a nulls record throws.</param>
    [RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")]
    private protected override Array Deserialize(Type arrayType, bool allowNulls)
    {
        // We can not deserialize non-primitive types.
        // This method returns arrays of ClassRecord for arrays of complex types.
        Type elementType = MapElementType(arrayType, out bool isClassRecord);
        Type actualElementType = arrayType.GetElementType()!;
        Array array =
#if NET9_0_OR_GREATER
            isClassRecord
                ? Array.CreateInstance(elementType, Length)
                : Array.CreateInstanceFromArrayType(arrayType, Length);
#else
            Array.CreateInstance(elementType, Length);
#endif

        int resultIndex = 0;
        foreach (object value in Values)
        {
            object item = value is MemberReferenceRecord referenceRecord
                ? referenceRecord.GetReferencedRecord()
                : value;

            if (item is not SerializationRecord record)
            {
                array.SetValue(item, resultIndex++);
                continue;
            }

            switch (record.RecordType)
            {
                case SerializationRecordType.BinaryArray:
                case SerializationRecordType.ArraySinglePrimitive:
                case SerializationRecordType.ArraySingleObject:
                case SerializationRecordType.ArraySingleString:

                    // Recursion depth is bounded by the depth of arrayType, which is
                    // a trustworthy Type instance. Don't need to worry about stack overflow.

                    ArrayRecord nestedArrayRecord = (ArrayRecord)record;
                    Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls);
                    array.SetValue(nestedArray, resultIndex++);
                    break;
                case SerializationRecordType.ObjectNull:
                case SerializationRecordType.ObjectNullMultiple256:
                case SerializationRecordType.ObjectNullMultiple:
                    if (!allowNulls)
                    {
                        ThrowHelper.ThrowArrayContainedNulls();
                    }

                    // A single nulls record may stand for several consecutive null slots.
                    int nullCount = ((NullsRecord)item).NullCount;
                    Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
                    do
                    {
                        array.SetValue(null, resultIndex++);
                        nullCount--;
                    }
                    while (nullCount > 0);
                    break;
                default:
                    array.SetValue(record.GetValue(), resultIndex++);
                    break;
            }
        }

        Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array.");

        return array;
    }

    /// <summary>
    /// Decodes a BinaryArray record header from <paramref name="reader"/>. The array members
    /// themselves are not read here.
    /// </summary>
    /// <returns>
    /// A RectangularArrayRecord for rectangular arrays. Otherwise an ArrayOfClassesRecord when the
    /// member type says elements should be represented as ClassRecords, else a BinaryArrayRecord.
    /// </returns>
    internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options)
    {
        SerializationRecordId objectId = SerializationRecordId.Decode(reader);
        BinaryArrayType arrayType = reader.ReadArrayType();
        int rank = reader.ReadInt32();

        bool isRectangular = arrayType is BinaryArrayType.Rectangular;

        // It is an arbitrary limit in the current CoreCLR type loader.
        // Don't change this value without reviewing the loop a few lines below.
        const int MaxSupportedArrayRank = 32;

        if (rank < 1 || rank > MaxSupportedArrayRank
            || (rank != 1 && !isRectangular)
            || (rank == 1 && isRectangular))
        {
            ThrowHelper.ThrowInvalidValue(rank);
        }

        int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32
        long totalElementCount = 1; // to avoid integer overflow during the multiplication below
        for (int i = 0; i < lengths.Length; i++)
        {
            lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
            totalElementCount *= lengths[i];

            // n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]"
            // but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But
            // that's the same behavior that newarr and Array.CreateInstance exhibit, so at least
            // we're consistent.

            if (totalElementCount > ArrayInfo.MaxArrayLength)
            {
                ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
            }
        }

        // Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so
        // we don't need to read the NRBF stream 'LowerBounds' field here.

        MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap);
        ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank);

        if (isRectangular)
        {
            return RectangularArrayRecord.Create(reader, arrayInfo, memberTypeInfo, lengths);
        }

        return memberTypeInfo.ShouldBeRepresentedAsArrayOfClassRecords()
            ? new ArrayOfClassesRecord(arrayInfo, memberTypeInfo)
            : new BinaryArrayRecord(arrayInfo, memberTypeInfo);
    }

    /// <summary>
    /// Computes the flattened length of a jagged array. This is the sum of the slot counts of every
    /// distinct array record reachable from <paramref name="jaggedArrayRecord"/>, the root included.
    /// </summary>
    /// <remarks>
    /// <para>
    /// The walk is breadth-first with an explicit queue rather than recursion, so deeply nested
    /// payloads cannot exhaust the stack.
    /// </para>
    /// <para>
    /// The payload is untrusted, and member references are resolved to arbitrary records. A nested
    /// reference may therefore point back at an array already being counted (a cycle), or share one
    /// nested array many times. Without a visited set, the first case never terminates and the second
    /// revisits records exponentially often. Each array record is therefore counted at most once, so
    /// an array referenced from several slots contributes its length a single time.
    /// </para>
    /// </remarks>
    /// <exception cref="OverflowException">The total exceeds <see cref="long.MaxValue"/>.</exception>
    private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord)
    {
        HashSet<ArrayRecord> visited = [jaggedArrayRecord];
        Queue<BinaryArrayRecord> pending = new();
        pending.Enqueue(jaggedArrayRecord);

        long result = 0;

        while (pending.Count > 0)
        {
            BinaryArrayRecord current = pending.Dequeue();

            Debug.Assert(current.IsJagged);

            // In theory somebody could create a payload that would represent
            // a very nested array with total elements count > long.MaxValue.
            // That is why this method is using checked arithmetic.
            result = checked(result + current.Length); // count the arrays themselves

            foreach (object value in current.Values)
            {
                if (value is not SerializationRecord record)
                {
                    continue;
                }

                if (record.RecordType == SerializationRecordType.MemberReference)
                {
                    record = ((MemberReferenceRecord)record).GetReferencedRecord();
                }

                switch (record.RecordType)
                {
                    case SerializationRecordType.ArraySinglePrimitive:
                    case SerializationRecordType.ArraySingleObject:
                    case SerializationRecordType.ArraySingleString:
                    case SerializationRecordType.BinaryArray:
                        ArrayRecord nestedArrayRecord = (ArrayRecord)record;
                        if (!visited.Add(nestedArrayRecord))
                        {
                            // Already counted: shared or cyclic reference.
                        }
                        else if (nestedArrayRecord.IsJagged)
                        {
                            pending.Enqueue((BinaryArrayRecord)nestedArrayRecord);
                        }
                        else
                        {
                            // Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion,
                            // just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value.
                            result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength);
                        }
                        break;
                    default:
                        break;
                }
            }
        }

        return result;
    }

    private protected override void AddValue(object value) => Values.Add(value);

    /// <summary>
    /// Returns the record types the reader may accept for the next member. Whenever anything is
    /// allowed at all, the nulls records are allowed too.
    /// </summary>
    internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
    {
        (AllowedRecordTypes allowed, PrimitiveType primitiveType) = MemberTypeInfo.GetNextAllowedRecordType(0);

        if (allowed != AllowedRecordTypes.None)
        {
            // It's an array, it can also contain multiple nulls
            return (allowed | AllowedRecordTypes.Nulls, primitiveType);
        }

        return (allowed, primitiveType);
    }

    /// <summary>
    /// Complex types must not be instantiated, but represented as ClassRecord.
    /// For arrays of primitive types like int, string and object this method returns the element type.
    /// For array of complex types, it returns ClassRecord.
    /// It takes arrays of arrays into account:
    /// - int[][] => int[]
    /// - MyClass[][][] => ClassRecord[][]
    /// </summary>
    /// <param name="arrayType">The requested array type.</param>
    /// <param name="isClassRecord">Set to true when the innermost element type was mapped to ClassRecord.</param>
    [RequiresDynamicCode("May call Type.MakeArrayType().")]
    private static Type MapElementType(Type arrayType, out bool isClassRecord)
    {
        Type elementType = arrayType;
        int arrayNestingDepth = 0;

        // Loop iteration counts are bound by the nesting depth of arrayType,
        // which is a trustworthy input. No DoS concerns.

        while (elementType.IsArray)
        {
            elementType = elementType.GetElementType()!;
            arrayNestingDepth++;
        }

        if (PrimitiveTypes.Contains(elementType) || (Nullable.GetUnderlyingType(elementType) is Type nullable && PrimitiveTypes.Contains(nullable)))
        {
            isClassRecord = false;
            return arrayNestingDepth == 1 ? elementType : arrayType.GetElementType()!;
        }

        // Complex types are never instantiated, but represented as ClassRecord
        isClassRecord = true;
        Type complexType = typeof(ClassRecord);
        for (int i = 1; i < arrayNestingDepth; i++)
        {
            complexType = complexType.MakeArrayType();
        }

        return complexType;
    }
}