File: System\Runtime\Serialization\BinaryFormat\ArraySinglePrimitiveRecord.cs
Web Access
Project: src\src\libraries\System.Runtime.Serialization.BinaryFormat\src\System.Runtime.Serialization.BinaryFormat.csproj (System.Runtime.Serialization.BinaryFormat)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Serialization.BinaryFormat.Utils;
 
namespace System.Runtime.Serialization.BinaryFormat;
 
/// <summary>
/// Represents a single-dimensional array of a primitive type.
/// </summary>
/// <remarks>
/// ArraySinglePrimitive records are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/3a50a305-5f32-48a1-a42a-c34054db310b">[MS-NRBF] 2.4.3.3</see>.
/// </remarks>
internal sealed class ArraySinglePrimitiveRecord<T> : ArrayRecord<T>
    where T : unmanaged
{
    private static TypeName? s_elementTypeName;
 
    internal ArraySinglePrimitiveRecord(ArrayInfo arrayInfo, IReadOnlyList<T> values) : base(arrayInfo)
    {
        Values = values;
        ValuesToRead = 0; // there is nothing to read anymore
    }
 
    public override RecordType RecordType => RecordType.ArraySinglePrimitive;
 
    public override TypeName ElementTypeName
        => s_elementTypeName ??= TypeName.Parse(typeof(T).FullName.AsSpan()).WithAssemblyName(typeof(T).GetAssemblyNameIncludingTypeForwards());
 
    internal IReadOnlyList<T> Values { get; }
 
    public override bool IsTypeNameMatching(Type type) => typeof(T[]) == type;
 
    internal override bool IsElementType(Type typeElement) => typeElement == typeof(T);
 
    /// <inheritdoc/>
    public override T[] GetArray(bool allowNulls = true)
        => (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray()));
 
    internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
    {
        Debug.Fail("GetAllowedRecordType should never be called on ArraySinglePrimitiveRecord");
        throw new InvalidOperationException();
    }
 
    private protected override void AddValue(object value)
    {
        Debug.Fail("AddValue should never be called on ArraySinglePrimitiveRecord");
        throw new InvalidOperationException();
    }
 
    internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
    {
        // For decimals, the input is provided as strings, so we can't compute the required size up-front.
        if (typeof(T) == typeof(decimal))
        {
            return (List<T>)(object)DecodeDecimals(reader, count);
        }
 
        long requiredBytes = count;
        if (typeof(T) != typeof(char)) // the input is UTF8
        {
            requiredBytes *= Unsafe.SizeOf<T>();
        }
 
        bool? isDataAvailable = reader.IsDataAvailable(requiredBytes);
        if (!isDataAvailable.HasValue)
        {
            return DecodeFromNonSeekableStream(reader, count);
        }
 
        if (!isDataAvailable.Value)
        {
            // We are sure there is not enough data.
            ThrowHelper.ThrowEndOfStreamException();
        }
 
        if (typeof(T) == typeof(byte))
        {
            return (T[])(object)reader.ReadBytes(count);
        }
        else if (typeof(T) == typeof(char))
        {
            return (T[])(object)reader.ReadChars(count);
        }
 
        // It's safe to pre-allocate, as we have ensured there is enough bytes in the stream.
        T[] result = new T[count];
        Span<byte> resultAsBytes = MemoryMarshal.AsBytes<T>(result);
#if NET
        reader.BaseStream.ReadExactly(resultAsBytes);
#else
        byte[] bytes = ArrayPool<byte>.Shared.Rent(Math.Min(count * Unsafe.SizeOf<T>(), 256_000));
 
        while (!resultAsBytes.IsEmpty)
        {
            int bytesRead = reader.Read(bytes, 0, Math.Min(resultAsBytes.Length, bytes.Length));
            if (bytesRead <= 0)
            {
                ArrayPool<byte>.Shared.Return(bytes);
                ThrowHelper.ThrowEndOfStreamException();
            }
 
            bytes.AsSpan(0, bytesRead).CopyTo(resultAsBytes);
            resultAsBytes = resultAsBytes.Slice(bytesRead);
        }
 
        ArrayPool<byte>.Shared.Return(bytes);
#endif
 
        if (!BitConverter.IsLittleEndian)
        {
            if (typeof(T) == typeof(short) || typeof(T) == typeof(ushort))
            {
                Span<short> span = MemoryMarshal.Cast<T, short>(result);
#if NET
                BinaryPrimitives.ReverseEndianness(span, span);
#else
                for (int i = 0; i < span.Length; i++)
                {
                    span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
                }
#endif
            }
            else if (typeof(T) == typeof(int) || typeof(T) == typeof(uint) || typeof(T) == typeof(float))
            {
                Span<int> span = MemoryMarshal.Cast<T, int>(result);
#if NET
                BinaryPrimitives.ReverseEndianness(span, span);
#else
                for (int i = 0; i < span.Length; i++)
                {
                    span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
                }
#endif
            }
            else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)
                  || typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
            {
                Span<long> span = MemoryMarshal.Cast<T, long>(result);
#if NET
                BinaryPrimitives.ReverseEndianness(span, span);
#else
                for (int i = 0; i < span.Length; i++)
                {
                    span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
                }
#endif
            }
        }
 
        return result;
    }
 
    private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
    {
        List<decimal> values = new();
#if NET
        Span<byte> buffer = stackalloc byte[256];
        for (int i = 0; i < count; i++)
        {
            int stringLength = reader.Read7BitEncodedInt();
            if (!(stringLength > 0 && stringLength <= buffer.Length))
            {
                ThrowHelper.ThrowInvalidValue(stringLength);
            }
 
            reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength));
 
            values.Add(decimal.Parse(buffer.Slice(0, stringLength), CultureInfo.InvariantCulture));
        }
#else
        for (int i = 0; i < count; i++)
        {
            values.Add(decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture));
        }
#endif
        return values;
    }
 
    private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int count)
    {
        List<T> values = new List<T>(Math.Min(count, 4));
        for (int i = 0; i < count; i++)
        {
            if (typeof(T) == typeof(byte))
            {
                values.Add((T)(object)reader.ReadByte());
            }
            else if (typeof(T) == typeof(bool))
            {
                values.Add((T)(object)reader.ReadBoolean());
            }
            else if (typeof(T) == typeof(sbyte))
            {
                values.Add((T)(object)reader.ReadSByte());
            }
            else if (typeof(T) == typeof(char))
            {
                values.Add((T)(object)reader.ReadChar());
            }
            else if (typeof(T) == typeof(short))
            {
                values.Add((T)(object)reader.ReadInt16());
            }
            else if (typeof(T) == typeof(ushort))
            {
                values.Add((T)(object)reader.ReadUInt16());
            }
            else if (typeof(T) == typeof(int))
            {
                values.Add((T)(object)reader.ReadInt32());
            }
            else if (typeof(T) == typeof(uint))
            {
                values.Add((T)(object)reader.ReadUInt32());
            }
            else if (typeof(T) == typeof(long))
            {
                values.Add((T)(object)reader.ReadInt64());
            }
            else if (typeof(T) == typeof(ulong))
            {
                values.Add((T)(object)reader.ReadUInt64());
            }
            else if (typeof(T) == typeof(float))
            {
                values.Add((T)(object)reader.ReadSingle());
            }
            else if (typeof(T) == typeof(double))
            {
                values.Add((T)(object)reader.ReadDouble());
            }
            else if (typeof(T) == typeof(DateTime))
            {
                values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadInt64()));
            }
            else
            {
                Debug.Assert(typeof(T) == typeof(TimeSpan));
 
                values.Add((T)(object)new TimeSpan(reader.ReadInt64()));
            }
        }
 
        return values;
    }
}