|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Formats.Nrbf.Utils;
namespace System.Formats.Nrbf;
/// <summary>
/// Represents a single-dimensional array of a primitive type.
/// </summary>
/// <remarks>
/// ArraySinglePrimitive records are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/3a50a305-5f32-48a1-a42a-c34054db310b">[MS-NRBF] 2.4.3.3</see>.
/// </remarks>
internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
    where T : unmanaged
{
    /// <summary>
    /// Creates a record over values that have already been decoded.
    /// </summary>
    /// <param name="arrayInfo">The array metadata, forwarded to the base class.</param>
    /// <param name="values">The decoded elements exposed through <see cref="Values"/>.</param>
    internal ArraySinglePrimitiveRecord(ArrayInfo arrayInfo, IReadOnlyList<T> values) : base(arrayInfo)
    {
        Values = values;
        ValuesToRead = 0; // there is nothing to read anymore
    }

    /// <summary>
    /// Gets the record type, which is always <see cref="SerializationRecordType.ArraySinglePrimitive"/>.
    /// </summary>
    public override SerializationRecordType RecordType => SerializationRecordType.ArraySinglePrimitive;

    /// <inheritdoc />
    public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.GetPrimitiveType<T>());

    /// <summary>
    /// Gets the decoded elements exactly as they were passed to the constructor.
    /// </summary>
    internal IReadOnlyList<T> Values { get; }

    /// <inheritdoc/>
    /// <remarks>
    /// <paramref name="allowNulls"/> is not consulted: <typeparamref name="T"/> is unmanaged, so elements can never be null.
    /// The result is cached in <c>_arrayNullsNotAllowed</c>. When <see cref="Values"/> is already a <c>T[]</c>,
    /// that very instance is cached and returned (no copy is made); otherwise the values are copied into a new array once.
    /// </remarks>
    public override T[] GetArray(bool allowNulls = true)
        => (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray()));

    /// <summary>
    /// Not supported by this record: always throws <see cref="InvalidOperationException"/>.
    /// </summary>
    internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException();

    /// <summary>
    /// Not supported by this record: always throws <see cref="InvalidOperationException"/>.
    /// </summary>
    private protected override void AddValue(object value) => throw new InvalidOperationException();

    /// <summary>
    /// Reads <paramref name="count"/> values of type <typeparamref name="T"/> from <paramref name="reader"/>.
    /// </summary>
    /// <param name="reader">The reader to consume the element data from.</param>
    /// <param name="count">
    /// The number of elements to decode. NOTE(review): negative values are not rejected here; presumably the caller validates it.
    /// </param>
    /// <returns>
    /// An empty array when <paramref name="count"/> is zero; a <see cref="List{T}"/> for decimals and when
    /// <c>reader.IsDataAvailable</c> returns <see langword="null"/>; a <c>T[]</c> in every other case.
    /// </returns>
    /// <remarks>
    /// When <c>reader.IsDataAvailable</c> reports <see langword="false"/>, or the stream ends while the array is being
    /// filled, the failure is raised through <c>ThrowHelper.ThrowEndOfStreamException</c>.
    /// </remarks>
    internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
    {
        if (count == 0)
        {
            return Array.Empty<T>(); // Empty arrays are allowed.
        }

        // For decimals, the input is provided as strings, so we can't compute the required size up-front.
        if (typeof(T) == typeof(decimal))
        {
            return (List<T>)(object)DecodeDecimals(reader, count);
        }

        // char[] has a unique representation in NRBF streams. Typical strings are transcoded
        // to UTF-8 and prefixed with the number of bytes in the UTF-8 representation. char[]
        // is also serialized as UTF-8, but it is instead prefixed with the number of chars
        // in the UTF-16 representation, not the number of bytes in the UTF-8 representation.
        // This number doesn't directly precede the UTF-8 contents in the NRBF stream; it's
        // instead contained within the ArrayInfo structure (passed to this method as the
        // 'count' argument).
        //
        // The practical consequence of this is that we don't actually know how many UTF-8
        // bytes we need to consume in order to ensure we've read 'count' chars. We know that
        // an n-length UTF-16 string turns into somewhere between [n .. 3n] UTF-8 bytes.
        // The best we can do is that when reading an n-element char[], we'll ensure that
        // there are at least n bytes remaining in the input stream. We still need to account
        // for the fact that, even with this check, we might hit EOF before fully populating
        // the char[]. But from a safety perspective, it does appropriately limit our
        // allocations to be proportional to the amount of data present in the input stream,
        // which is a sufficient defense against DoS.

        // We can't assume DateTime as represented by the runtime is 8 bytes.
        // The only assumption we can make is that it's 8 bytes on the wire.
        int sizeOfT = typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)
            ? 8
            : typeof(T) != typeof(char)
                ? Unsafe.SizeOf<T>()
                : 1;

        // Computed as long so that count * sizeOfT cannot overflow.
        long requiredBytes = (long)count * sizeOfT;

        bool? isDataAvailable = reader.IsDataAvailable(requiredBytes);
        if (!isDataAvailable.HasValue)
        {
            // Availability is unknown: decode element by element without pre-allocating 'count' items.
            return DecodeFromNonSeekableStream(reader, count);
        }

        if (!isDataAvailable.Value)
        {
            // We are sure there is not enough data.
            ThrowHelper.ThrowEndOfStreamException();
        }

        if (typeof(T) == typeof(byte))
        {
            return (T[])(object)reader.ReadBytes(count);
        }
        else if (typeof(T) == typeof(char))
        {
            return (T[])(object)reader.ParseChars(count);
        }
        else if (typeof(T) == typeof(TimeSpan) || typeof(T) == typeof(DateTime))
        {
            return DecodeTime(reader, count);
        }

        // It's safe to pre-allocate, as we have ensured there is enough bytes in the stream.
        T[] result = new T[count];

        // MemoryMarshal.AsBytes can fail for inputs that need more than int.MaxValue bytes.
        // To avoid OverflowException, we read the data in chunks.
        int maxChunkLength =
#if !DEBUG
            int.MaxValue / sizeOfT;
#else
            // Let's use a different value for non-release builds to ensure this code path
            // is covered with tests without the need of decoding enormous payloads.
            8_000;
#endif

        Span<T> valuesToRead = result.AsSpan();

#if !NET
        // On this target the data is staged through a pooled byte[] (see the #else branch below).
        // The try/finally guarantees the buffer goes back to the pool even when reading throws.
        byte[] rented = ArrayPool<byte>.Shared.Rent((int)Math.Min(requiredBytes, 256_000));
        try
        {
#endif
        while (!valuesToRead.IsEmpty)
        {
            int sliceSize = Math.Min(valuesToRead.Length, maxChunkLength);

            Span<byte> resultAsBytes = MemoryMarshal.AsBytes<T>(valuesToRead.Slice(0, sliceSize));
#if NET
            reader.BaseStream.ReadExactly(resultAsBytes);
#else
            while (!resultAsBytes.IsEmpty)
            {
                int bytesRead = reader.Read(rented, 0, Math.Min(resultAsBytes.Length, rented.Length));
                if (bytesRead <= 0)
                {
                    // The rented buffer is returned by the enclosing finally block.
                    ThrowHelper.ThrowEndOfStreamException();
                }

                rented.AsSpan(0, bytesRead).CopyTo(resultAsBytes);
                resultAsBytes = resultAsBytes.Slice(bytesRead);
            }
#endif
            valuesToRead = valuesToRead.Slice(sliceSize);
        }
#if !NET
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(rented);
        }
#endif

        // The bytes were copied verbatim from the stream; on big-endian machines every
        // multi-byte element has to be byte-swapped in place.
        if (!BitConverter.IsLittleEndian)
        {
            if (typeof(T) == typeof(short) || typeof(T) == typeof(ushort))
            {
                Span<short> span = MemoryMarshal.Cast<T, short>(result.AsSpan());
#if NET
                BinaryPrimitives.ReverseEndianness(span, span);
#else
                for (int i = 0; i < span.Length; i++)
                {
                    span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
                }
#endif
            }
            else if (typeof(T) == typeof(int) || typeof(T) == typeof(uint) || typeof(T) == typeof(float))
            {
                Span<int> span = MemoryMarshal.Cast<T, int>(result.AsSpan());
#if NET
                BinaryPrimitives.ReverseEndianness(span, span);
#else
                for (int i = 0; i < span.Length; i++)
                {
                    span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
                }
#endif
            }
            else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double))
            {
                Span<long> span = MemoryMarshal.Cast<T, long>(result.AsSpan());
#if NET
                BinaryPrimitives.ReverseEndianness(span, span);
#else
                for (int i = 0; i < span.Length; i++)
                {
                    span[i] = BinaryPrimitives.ReverseEndianness(span[i]);
                }
#endif
            }
        }

        if (typeof(T) == typeof(bool))
        {
            // See DontCastBytesToBooleans test to see what could go wrong.
            bool[] booleans = (bool[])(object)result;
            Span<byte> resultAsBytes = MemoryMarshal.AsBytes<T>(result.AsSpan());
            for (int i = 0; i < booleans.Length; i++)
            {
                // We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this.
                if (resultAsBytes[i] != 0) // it can be any byte different than 0
                {
                    booleans[i] = true; // set it to 1 in explicit way
                }
            }
        }

        return result;
    }

    /// <summary>
    /// Reads <paramref name="count"/> decimals, one <c>reader.ParseDecimal()</c> call per element.
    /// </summary>
    /// <param name="reader">The reader to consume the element data from.</param>
    /// <param name="count">The number of decimals to read.</param>
    /// <returns>A list holding the parsed values in stream order.</returns>
    private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
    {
        // No capacity is requested up-front; the list grows as values are actually parsed.
        List<decimal> values = new();
        for (int i = 0; i < count; i++)
        {
            values.Add(reader.ParseDecimal());
        }

        return values;
    }

    /// <summary>
    /// Reads <paramref name="count"/> <see cref="DateTime"/> or <see cref="TimeSpan"/> values, each from one 64-bit integer.
    /// </summary>
    /// <param name="reader">The reader to consume the element data from.</param>
    /// <param name="count">The number of elements; the result array is allocated with exactly this length.</param>
    /// <returns>The decoded values.</returns>
    /// <exception cref="InvalidOperationException"><typeparamref name="T"/> is neither <see cref="DateTime"/> nor <see cref="TimeSpan"/>.</exception>
    private static T[] DecodeTime(BinaryReader reader, int count)
    {
        T[] values = new T[count];
        for (int i = 0; i < values.Length; i++)
        {
            if (typeof(T) == typeof(DateTime))
            {
                values[i] = (T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64());
            }
            else if (typeof(T) == typeof(TimeSpan))
            {
                values[i] = (T)(object)new TimeSpan(reader.ReadInt64());
            }
            else
            {
                throw new InvalidOperationException();
            }
        }

        return values;
    }

    /// <summary>
    /// Reads <paramref name="count"/> elements one at a time, used when the amount of available data is unknown.
    /// </summary>
    /// <param name="reader">The reader to consume the element data from.</param>
    /// <param name="count">The number of elements to read.</param>
    /// <returns>A list holding the decoded values in stream order.</returns>
    /// <exception cref="InvalidOperationException"><typeparamref name="T"/> is not one of the types handled below.</exception>
    private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int count)
    {
        // The count arg could originate from untrusted input, so we shouldn't
        // pass it as-is to the ctor's capacity arg. We'll instead rely on
        // List<T>.Add's O(1) amortization to keep the entire loop O(count).
        List<T> values = new List<T>(Math.Min(count, 4));
        for (int i = 0; i < count; i++)
        {
            if (typeof(T) == typeof(byte))
            {
                values.Add((T)(object)reader.ReadByte());
            }
            else if (typeof(T) == typeof(bool))
            {
                values.Add((T)(object)reader.ReadBoolean());
            }
            else if (typeof(T) == typeof(sbyte))
            {
                values.Add((T)(object)reader.ReadSByte());
            }
            else if (typeof(T) == typeof(char))
            {
                values.Add((T)(object)reader.ParseChar());
            }
            else if (typeof(T) == typeof(short))
            {
                values.Add((T)(object)reader.ReadInt16());
            }
            else if (typeof(T) == typeof(ushort))
            {
                values.Add((T)(object)reader.ReadUInt16());
            }
            else if (typeof(T) == typeof(int))
            {
                values.Add((T)(object)reader.ReadInt32());
            }
            else if (typeof(T) == typeof(uint))
            {
                values.Add((T)(object)reader.ReadUInt32());
            }
            else if (typeof(T) == typeof(long))
            {
                values.Add((T)(object)reader.ReadInt64());
            }
            else if (typeof(T) == typeof(ulong))
            {
                values.Add((T)(object)reader.ReadUInt64());
            }
            else if (typeof(T) == typeof(float))
            {
                values.Add((T)(object)reader.ReadSingle());
            }
            else if (typeof(T) == typeof(double))
            {
                values.Add((T)(object)reader.ReadDouble());
            }
            else if (typeof(T) == typeof(DateTime))
            {
                values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64()));
            }
            else if (typeof(T) == typeof(TimeSpan))
            {
                values.Add((T)(object)new TimeSpan(reader.ReadInt64()));
            }
            else
            {
                throw new InvalidOperationException();
            }
        }

        return values;
    }
}
|