|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers.Binary;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Formats.Nrbf.Utils;
using System.Text;
using System.Runtime.Serialization;
using System.Runtime.InteropServices;
namespace System.Formats.Nrbf;
/// <summary>
/// Provides stateless methods for decoding .NET Remoting Binary Format (NRBF) encoded data.
/// </summary>
public static class NrbfDecoder
{
    // UTF-8 encoding configured to throw on invalid byte sequences (second argument),
    // so malformed strings in the payload surface as exceptions.
    private static UTF8Encoding ThrowOnInvalidUtf8Encoding { get; } = new(false, throwOnInvalidBytes: true);

    // The header consists of:
    // - a byte that describes the record type (SerializationRecordType.SerializedStreamHeader)
    // - four 32 bit integers:
    //   - root Id (every value except 0 is valid)
    //   - header Id (value is ignored)
    //   - major version, it has to be equal to 1.
    //   - minor version, it has to be equal to 0.
    // HeaderSuffix holds the trailing eight bytes (major version followed by minor version).
    private static ReadOnlySpan<byte> HeaderSuffix => [1, 0, 0, 0, 0, 0, 0, 0];

    /// <summary>
    /// Checks if the given buffer starts with the <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/a7e578d3-400a-4249-9424-7529d10d1b3c">NRBF payload header</see>.
    /// </summary>
    /// <param name="bytes">The buffer to inspect.</param>
    /// <returns><see langword="true" /> if the buffer starts with the NRBF payload header; otherwise, <see langword="false" />.</returns>
    public static bool StartsWithPayloadHeader(ReadOnlySpan<byte> bytes)
        => bytes.Length >= SerializedStreamHeaderRecord.Size
        && bytes[0] == (byte)SerializationRecordType.SerializedStreamHeader
        && bytes.Slice(SerializedStreamHeaderRecord.Size - HeaderSuffix.Length, HeaderSuffix.Length).SequenceEqual(HeaderSuffix);

    /// <summary>
    /// Checks if the given stream starts with the <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/a7e578d3-400a-4249-9424-7529d10d1b3c">NRBF payload header</see>.
    /// </summary>
    /// <param name="stream">The stream to inspect. The stream must be both readable and seekable.</param>
    /// <returns><see langword="true" /> if the stream starts with the NRBF payload header; otherwise, <see langword="false" />.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="stream" /> is <see langword="null" />.</exception>
    /// <exception cref="ArgumentException"><paramref name="stream" /> does not support seeking.</exception>
    /// <exception cref="NotSupportedException">The stream does not support reading.</exception>
    /// <exception cref="ObjectDisposedException">The stream was closed.</exception>
    /// <exception cref="IOException">An I/O error occurred.</exception>
    /// <remarks>
    /// Once reading has started, the position of <paramref name="stream" /> is restored to its original value
    /// before this method exits, including when reading fails with an exception.
    /// </remarks>
    public static bool StartsWithPayloadHeader(Stream stream)
    {
#if NET
        ArgumentNullException.ThrowIfNull(stream);
#else
        if (stream is null)
        {
            throw new ArgumentNullException(nameof(stream));
        }
#endif
        if (!stream.CanSeek)
        {
            throw new ArgumentException(SR.Argument_NonSeekableStream, nameof(stream));
        }

        long beginning = stream.Position;

        // A stream with exactly Size bytes left can still hold a complete header,
        // which is consistent with the ReadOnlySpan<byte> overload (Length >= Size).
        if (stream.Length - beginning < SerializedStreamHeaderRecord.Size)
        {
            return false;
        }

        try
        {
            byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
            int offset = 0;
            while (offset < buffer.Length)
            {
                // Stream.Read may return fewer bytes than requested, hence the loop.
                int read = stream.Read(buffer, offset, buffer.Length - offset);
                if (read == 0)
                {
                    // The stream ended before a full header could be read.
                    return false;
                }

                offset += read;
            }

            return StartsWithPayloadHeader(buffer);
        }
        finally
        {
            // Restore the position on every exit path, including exceptions thrown by Read.
            stream.Position = beginning;
        }
    }

    /// <summary>
    /// Decodes the provided NRBF payload.
    /// </summary>
    /// <param name="payload">The NRBF payload.</param>
    /// <param name="options">Options to control behavior during parsing.</param>
    /// <param name="leaveOpen">
    /// <see langword="true" /> to leave <paramref name="payload"/> open
    /// after the reading is finished; otherwise, <see langword="false" />.
    /// </param>
    /// <returns>A <see cref="SerializationRecord"/> that represents the root object.
    /// It can be either <see cref="PrimitiveTypeRecord{T}"/>,
    /// a <see cref="ClassRecord"/>, or an <see cref="ArrayRecord"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="payload"/> is <see langword="null" />.</exception>
    /// <exception cref="ArgumentException"><paramref name="payload"/> does not support reading or is already closed.</exception>
    /// <exception cref="SerializationException">Reading from <paramref name="payload"/> encountered invalid NRBF data.</exception>
    /// <exception cref="IOException">An I/O error occurred.</exception>
    /// <exception cref="NotSupportedException">
    /// Reading from <paramref name="payload"/> encountered unsupported records,
    /// for example, arrays with non-zero offset or unsupported record types
    /// (<see cref="SerializationRecordType.ClassWithMembers"/>, <see cref="SerializationRecordType.SystemClassWithMembers"/>,
    /// <see cref="SerializationRecordType.MethodCall"/>, or <see cref="SerializationRecordType.MethodReturn"/>).
    /// </exception>
    /// <exception cref="DecoderFallbackException">Reading from <paramref name="payload"/>
    /// encountered an invalid UTF8 sequence.</exception>
    /// <exception cref="EndOfStreamException">The end of the stream was reached before reading <see cref="SerializationRecordType.MessageEnd"/> record.</exception>
    public static SerializationRecord Decode(Stream payload, PayloadOptions? options = default, bool leaveOpen = false)
        => Decode(payload, out _, options, leaveOpen);

    /// <param name="payload">The NRBF payload.</param>
    /// <param name="recordMap">
    /// When this method returns, contains a mapping of <see cref="SerializationRecordId" /> to the associated serialization record.
    /// This parameter is treated as uninitialized.
    /// </param>
    /// <param name="options">An object that describes optional <see cref="PayloadOptions"/> parameters to use.</param>
    /// <param name="leaveOpen">
    /// <see langword="true" /> to leave <paramref name="payload"/> open
    /// after the reading is finished; otherwise, <see langword="false" />.
    /// </param>
    /// <inheritdoc cref="Decode(Stream, PayloadOptions?, bool)"/>
    public static SerializationRecord Decode(Stream payload, out IReadOnlyDictionary<SerializationRecordId, SerializationRecord> recordMap, PayloadOptions? options = default, bool leaveOpen = false)
    {
#if NET
        ArgumentNullException.ThrowIfNull(payload);
#else
        if (payload is null)
        {
            throw new ArgumentNullException(nameof(payload));
        }
#endif
        // Disposing the reader closes the payload unless leaveOpen is true.
        using BinaryReader reader = new(payload, ThrowOnInvalidUtf8Encoding, leaveOpen: leaveOpen);
        try
        {
            return Decode(reader, options ?? new(), out recordMap);
        }
        catch (FormatException) // can be thrown by various BinaryReader methods
        {
            throw new SerializationException(SR.Serialization_InvalidFormat);
        }
    }

    /// <summary>
    /// Decodes the provided NRBF payload that is expected to contain an instance of any class (or struct) that is not an <see cref="Array"/> or a primitive type.
    /// </summary>
    /// <returns>A <see cref="ClassRecord"/> that represents the root object.</returns>
    /// <inheritdoc cref="Decode(Stream, PayloadOptions?, bool)"/>
    public static ClassRecord DecodeClassRecord(Stream payload, PayloadOptions? options = default, bool leaveOpen = false)
        => (ClassRecord)Decode(payload, options, leaveOpen);

    /// <summary>
    /// Reads the stream header and then keeps reading records until a
    /// <see cref="SerializationRecordType.MessageEnd"/> record is read.
    /// Nested records and values are tracked with an explicit stack of <see cref="NextInfo"/>
    /// rather than recursive calls.
    /// </summary>
    /// <param name="reader">The reader positioned at the start of the payload.</param>
    /// <param name="options">The options forwarded to the record decoders.</param>
    /// <param name="readOnlyRecordMap">When this method returns, the map of all records that were added while decoding.</param>
    /// <returns>The root record, as resolved by the record map from the header.</returns>
    private static SerializationRecord Decode(BinaryReader reader, PayloadOptions options, out IReadOnlyDictionary<SerializationRecordId, SerializationRecord> readOnlyRecordMap)
    {
        Stack<NextInfo> readStack = new();
        RecordMap recordMap = new();

        // Everything has to start with a header
        var header = (SerializedStreamHeaderRecord)DecodeNext(reader, recordMap, AllowedRecordTypes.SerializedStreamHeader, options, out _);
        // and can be followed by any Object, BinaryLibrary and a MessageEnd.
        const AllowedRecordTypes Allowed = AllowedRecordTypes.AnyObject
            | AllowedRecordTypes.BinaryLibrary | AllowedRecordTypes.MessageEnd;

        SerializationRecordType recordType;
        SerializationRecord nextRecord;
        do
        {
            // Drain everything that the previously read record still expects
            // before reading the next top-level record.
            while (readStack.Count > 0)
            {
                NextInfo nextInfo = readStack.Pop();

                if (nextInfo.Allowed != AllowedRecordTypes.None)
                {
                    // Decode the next Record
                    do
                    {
                        nextRecord = DecodeNext(reader, recordMap, nextInfo.Allowed, options, out _);
                        // BinaryLibrary often precedes class records.
                        // It has been already added to the RecordMap and it must not be added
                        // to the array record, so simply read next record.
                        // It's possible to read multiple BinaryLibraryRecord in a row, hence the loop.
                    }
                    while (nextRecord is BinaryLibraryRecord);

                    // Handle it:
                    // - add to the parent records list,
                    // - push next info if there are remaining nested records to read.
                    nextInfo.Parent.HandleNextRecord(nextRecord, nextInfo);
                    // Push on the top of the stack the first nested record of last read record,
                    // so it gets read as next record.
                    PushFirstNestedRecordInfo(nextRecord, readStack);
                }
                else
                {
                    // No record type is allowed here: the parent expects a raw primitive value.
                    object value = reader.ReadPrimitiveValue(nextInfo.PrimitiveType);

                    nextInfo.Parent.HandleNextValue(value, nextInfo);
                }
            }

            nextRecord = DecodeNext(reader, recordMap, Allowed, options, out recordType);
            PushFirstNestedRecordInfo(nextRecord, readStack);
        }
        while (recordType != SerializationRecordType.MessageEnd);

        readOnlyRecordMap = recordMap;
        return recordMap.GetRootRecord(header);
    }

    /// <summary>
    /// Reads the type of the next record (restricted to <paramref name="allowed"/>),
    /// decodes the matching record and adds it to <paramref name="recordMap"/>.
    /// </summary>
    /// <param name="reader">The reader to read from.</param>
    /// <param name="recordMap">The map that the decoded record is added to.</param>
    /// <param name="allowed">The record types that are acceptable at this position.</param>
    /// <param name="options">The options forwarded to the record decoders that take them.</param>
    /// <param name="recordType">When this method returns, the type of the record that was read.</param>
    /// <returns>The decoded record.</returns>
    private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap recordMap,
        AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType)
    {
        recordType = reader.ReadSerializationRecordType(allowed);

        SerializationRecord record = recordType switch
        {
            SerializationRecordType.ArraySingleObject => ArraySingleObjectRecord.Decode(reader),
            SerializationRecordType.ArraySinglePrimitive => DecodeArraySinglePrimitiveRecord(reader),
            SerializationRecordType.ArraySingleString => ArraySingleStringRecord.Decode(reader),
            SerializationRecordType.BinaryArray => BinaryArrayRecord.Decode(reader, recordMap, options),
            SerializationRecordType.BinaryLibrary => BinaryLibraryRecord.Decode(reader, options),
            SerializationRecordType.BinaryObjectString => BinaryObjectStringRecord.Decode(reader),
            SerializationRecordType.ClassWithId => ClassWithIdRecord.Decode(reader, recordMap),
            SerializationRecordType.ClassWithMembersAndTypes => ClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
            SerializationRecordType.MemberPrimitiveTyped => DecodeMemberPrimitiveTypedRecord(reader),
            SerializationRecordType.MemberReference => MemberReferenceRecord.Decode(reader, recordMap),
            SerializationRecordType.MessageEnd => MessageEndRecord.Singleton,
            SerializationRecordType.ObjectNull => ObjectNullRecord.Instance,
            SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader),
            SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader),
            SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader),
            SerializationRecordType.SystemClassWithMembersAndTypes => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
            // NOTE(review): presumably unreachable because ReadSerializationRecordType is expected
            // to reject unsupported record types before this switch - confirm against that helper.
            _ => throw new InvalidOperationException()
        };

        recordMap.Add(record);

        return record;
    }

    /// <summary>
    /// Reads a primitive type marker followed by a single value of that type
    /// and wraps the value in the matching <see cref="MemberPrimitiveTypedRecord{T}"/>.
    /// </summary>
    /// <param name="reader">The reader to read from.</param>
    /// <returns>The record holding the primitive value.</returns>
    private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader reader)
    {
        PrimitiveType primitiveType = reader.ReadPrimitiveType();

        return primitiveType switch
        {
            PrimitiveType.Boolean => new MemberPrimitiveTypedRecord<bool>(reader.ReadBoolean()),
            PrimitiveType.Byte => new MemberPrimitiveTypedRecord<byte>(reader.ReadByte()),
            PrimitiveType.SByte => new MemberPrimitiveTypedRecord<sbyte>(reader.ReadSByte()),
            PrimitiveType.Char => new MemberPrimitiveTypedRecord<char>(reader.ParseChar()),
            PrimitiveType.Int16 => new MemberPrimitiveTypedRecord<short>(reader.ReadInt16()),
            PrimitiveType.UInt16 => new MemberPrimitiveTypedRecord<ushort>(reader.ReadUInt16()),
            PrimitiveType.Int32 => new MemberPrimitiveTypedRecord<int>(reader.ReadInt32()),
            PrimitiveType.UInt32 => new MemberPrimitiveTypedRecord<uint>(reader.ReadUInt32()),
            PrimitiveType.Int64 => new MemberPrimitiveTypedRecord<long>(reader.ReadInt64()),
            PrimitiveType.UInt64 => new MemberPrimitiveTypedRecord<ulong>(reader.ReadUInt64()),
            PrimitiveType.Single => new MemberPrimitiveTypedRecord<float>(reader.ReadSingle()),
            PrimitiveType.Double => new MemberPrimitiveTypedRecord<double>(reader.ReadDouble()),
            PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(reader.ParseDecimal()),
            PrimitiveType.DateTime => new MemberPrimitiveTypedRecord<DateTime>(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())),
            PrimitiveType.TimeSpan => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
            // NOTE(review): presumably ReadPrimitiveType only returns the values handled above - confirm.
            _ => throw new InvalidOperationException()
        };
    }

    /// <summary>
    /// Reads the array info and the primitive element type, then decodes the elements
    /// into an <see cref="ArraySinglePrimitiveRecord{T}"/> of the matching element type.
    /// </summary>
    /// <param name="reader">The reader to read from.</param>
    /// <returns>The record holding the decoded primitive array.</returns>
    private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader)
    {
        ArrayInfo info = ArrayInfo.Decode(reader);
        PrimitiveType primitiveType = reader.ReadPrimitiveType();

        return primitiveType switch
        {
            PrimitiveType.Boolean => Decode<bool>(info, reader),
            PrimitiveType.Byte => Decode<byte>(info, reader),
            PrimitiveType.SByte => Decode<sbyte>(info, reader),
            PrimitiveType.Char => Decode<char>(info, reader),
            PrimitiveType.Int16 => Decode<short>(info, reader),
            PrimitiveType.UInt16 => Decode<ushort>(info, reader),
            PrimitiveType.Int32 => Decode<int>(info, reader),
            PrimitiveType.UInt32 => Decode<uint>(info, reader),
            PrimitiveType.Int64 => Decode<long>(info, reader),
            PrimitiveType.UInt64 => Decode<ulong>(info, reader),
            PrimitiveType.Single => Decode<float>(info, reader),
            PrimitiveType.Double => Decode<double>(info, reader),
            PrimitiveType.Decimal => Decode<decimal>(info, reader),
            PrimitiveType.DateTime => Decode<DateTime>(info, reader),
            PrimitiveType.TimeSpan => Decode<TimeSpan>(info, reader),
            // NOTE(review): presumably ReadPrimitiveType only returns the values handled above - confirm.
            _ => throw new InvalidOperationException()
        };

        // Shared generic helper so every switch arm stays a one-liner.
        static SerializationRecord Decode<T>(ArrayInfo info, BinaryReader reader) where T : unmanaged
            => new ArraySinglePrimitiveRecord<T>(info, ArraySinglePrimitiveRecord<T>.DecodePrimitiveTypes(reader, info.GetSZArrayLength()));
    }

    /// <summary>
    /// This method is responsible for pushing only the FIRST read info
    /// of the NESTED record into the <paramref name="readStack"/>.
    /// It's not pushing all of them, because it could be used as a vector of attack.
    /// Example: BinaryArrayRecord with Array.MaxLength length,
    /// where first item turns out to be <see cref="ObjectNullMultipleRecord"/>
    /// that provides Array.MaxLength nulls.
    /// </summary>
    private static void PushFirstNestedRecordInfo(SerializationRecord record, Stack<NextInfo> readStack)
    {
        if (record is ClassRecord classRecord)
        {
            if (classRecord.ExpectedValuesCount > 0)
            {
                (AllowedRecordTypes allowed, PrimitiveType primitiveType) = classRecord.GetNextAllowedRecordType();

                readStack.Push(new(allowed, classRecord, readStack, primitiveType));
            }
        }
        else if (record is ArrayRecord arrayRecord && arrayRecord.ValuesToRead > 0)
        {
            (AllowedRecordTypes allowed, PrimitiveType primitiveType) = arrayRecord.GetAllowedRecordType();

            readStack.Push(new(allowed, arrayRecord, readStack, primitiveType));
        }
    }
}
|