File: System\Formats\Nrbf\SZArrayOfRecords.cs
Web Access
Project: src\src\libraries\System.Formats.Nrbf\src\System.Formats.Nrbf.csproj (System.Formats.Nrbf)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
 
namespace System.Formats.Nrbf;
 
// This library tries to minimize the number of concepts users need to learn in order to use it.
// Since SZArrays (single-dimensional, zero-based arrays) are the most common, it provides an SZArrayRecord<T> abstraction.
// Every other array (jagged, multi-dimensional, etc.) is represented using ArrayRecord.
// The goal of this class is to let users consume arrays of records through the SZArrayRecord<SerializationRecord> abstraction.
internal sealed class SZArrayOfRecords : SZArrayRecord<SerializationRecord>
{
    // Computed on first access of TypeName and reused afterwards.
    private TypeName? _typeName;

    internal SZArrayOfRecords(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
        : base(arrayInfo)
    {
        Records = new List<SerializationRecord>();
        MemberTypeInfo = memberTypeInfo;
    }

    public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;

    /// <summary>
    /// The raw element records, in the order they were added by <see cref="AddValue(object)"/>.
    /// Entries may be <see cref="MemberReferenceRecord"/> or <see cref="NullsRecord"/> instances;
    /// <see cref="GetArray(bool)"/> resolves the former and expands the latter.
    /// </summary>
    internal List<SerializationRecord> Records { get; }

    /// <summary>
    /// Element type information, used to build <see cref="TypeName"/> and to decide which records may come next.
    /// </summary>
    private MemberTypeInfo MemberTypeInfo { get; }

    /// <inheritdoc/>
    public override TypeName TypeName
    {
        get
        {
            if (_typeName is null)
            {
                _typeName = MemberTypeInfo.GetArrayTypeName(ArrayInfo);
            }

            return _typeName;
        }
    }

    /// <inheritdoc/>
    public override SerializationRecord?[] GetArray(bool allowNulls = true)
    {
        // Each flavor has its own cache field; once built successfully, the same array instance is returned.
        if (allowNulls)
        {
            return (SerializationRecord?[])(_arrayNullsAllowed ??= BuildArray(allowNulls: true));
        }

        return (SerializationRecord?[])(_arrayNullsNotAllowed ??= BuildArray(allowNulls: false));
    }

    /// <summary>
    /// Materializes <see cref="Records"/> into a new array with <c>Length</c> elements.
    /// </summary>
    /// <param name="allowNulls">
    /// When <see langword="false"/>, encountering a <see cref="NullsRecord"/> is reported through
    /// <c>ThrowHelper.ThrowArrayContainedNulls</c>.
    /// </param>
    /// <returns>The resolved elements, with null runs expanded into individual null slots.</returns>
    /// <remarks>
    /// A <see cref="MemberReferenceRecord"/> is replaced by the record it refers to.
    /// Every <see cref="NullsRecord"/> expands into <c>NullCount</c> consecutive null slots.
    /// </remarks>
    private SerializationRecord?[] BuildArray(bool allowNulls)
    {
        SerializationRecord?[] values = new SerializationRecord?[Length];
        int next = 0;

        foreach (SerializationRecord item in Records)
        {
            SerializationRecord resolved = item switch
            {
                MemberReferenceRecord reference => reference.GetReferencedRecord(),
                _ => item,
            };

            if (resolved is NullsRecord nulls)
            {
                if (!allowNulls)
                {
                    ThrowHelper.ThrowArrayContainedNulls();
                }

                int remaining = nulls.NullCount;
                Debug.Assert(remaining > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");

                // The first slot is written unconditionally. We continue while the run still has entries left.
                do
                {
                    values[next++] = null;
                }
                while (--remaining > 0);

                continue;
            }

            values[next++] = resolved;
        }

        Debug.Assert(next == values.Length, "We should have traversed the entirety of the newly created array.");

        return values;
    }

    private protected override void AddValue(object value)
    {
        SerializationRecord element = (SerializationRecord)value;
        Records.Add(element);
    }

    internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
    {
        var (permitted, primitive) = MemberTypeInfo.GetNextAllowedRecordType(0);

        // If nothing is allowed, report that unchanged.
        // Otherwise, the array may additionally contain (possibly multiple) runs of nulls.
        AllowedRecordTypes result = permitted == AllowedRecordTypes.None
            ? permitted
            : permitted | AllowedRecordTypes.Nulls;

        return (result, primitive);
    }
}