// File: System\Windows\Nrbf\SerializationRecordExtensions.cs
// Web Access
// Project: src\src\Microsoft.DotNet.Wpf\src\PresentationCore\PresentationCore.csproj (PresentationCore)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
 
using System.Collections;
using System.Drawing;
using System.Diagnostics.CodeAnalysis;
 
#nullable enable
namespace System.Formats.Nrbf
{
    /// <summary>
    ///  Extension methods that materialize a limited set of framework types (primitives, strings, primitive
    ///  arrays/lists/<see cref="ArrayList"/>/<see cref="Hashtable"/>, <see cref="PointF"/>, <see cref="RectangleF"/>
    ///  and <see cref="NotSupportedException"/>) from NRBF <see cref="SerializationRecord"/> instances.
    /// </summary>
    internal static class SerializationRecordExtensions
    {
        private delegate bool TryGetDelegate(SerializationRecord record, [NotNullWhen(true)] out object? value);

        /// <summary>
        ///  Invokes <paramref name="get"/> for <paramref name="record"/>, turning a <see cref="KeyNotFoundException"/>
        ///  or <see cref="InvalidCastException"/> thrown by it into a <see langword="false"/> result.
        /// </summary>
        /// <remarks>
        ///  <para>
        ///   Those two exceptions are only expected with corrupted data, so they also trigger <c>Debug.Fail</c>.
        ///   Any other exception type propagates to the caller, which is why the individual readers below validate
        ///   sizes up front instead of relying on this method.
        ///  </para>
        /// </remarks>
        private static bool TryGet(TryGetDelegate get, SerializationRecord record, [NotNullWhen(true)] out object? value)
        {
            try
            {
                return get(record, out value);
            }
            catch (Exception ex) when (ex is KeyNotFoundException or InvalidCastException)
            {
                // This should only really happen with corrupted data.
                Debug.Fail(ex.Message);
                value = default;
                return false;
            }
        }

        /// <summary>
        ///  Tries to get this object as a <see cref="PointF"/>.
        /// </summary>
        /// <returns>
        ///  <see langword="true"/> if the record is a <see cref="PointF"/> class record with <c>x</c> and <c>y</c> members.
        /// </returns>
        public static bool TryGetPointF(this SerializationRecord record, [NotNullWhen(true)] out object? value)
        {
            return TryGet(Get, record, out value);

            static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? value)
            {
                value = default;

                if (record is not ClassRecord classInfo
                    || !classInfo.TypeNameMatches(typeof(PointF))
                    || !classInfo.HasMember("x")
                    || !classInfo.HasMember("y"))
                {
                    return false;
                }

                value = new PointF(classInfo.GetSingle("x"), classInfo.GetSingle("y"));

                return true;
            }
        }

        /// <summary>
        ///  Tries to get this object as a <see cref="RectangleF"/>.
        /// </summary>
        /// <returns>
        ///  <see langword="true"/> if the record is a <see cref="RectangleF"/> class record with <c>x</c>, <c>y</c>,
        ///  <c>width</c> and <c>height</c> members.
        /// </returns>
        public static bool TryGetRectangleF(this SerializationRecord record, [NotNullWhen(true)] out object? value)
        {
            return TryGet(Get, record, out value);

            static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? value)
            {
                value = default;

                if (record is not ClassRecord classInfo
                    || !classInfo.TypeNameMatches(typeof(RectangleF))
                    || !classInfo.HasMember("x")
                    || !classInfo.HasMember("y")
                    || !classInfo.HasMember("width")
                    || !classInfo.HasMember("height"))
                {
                    return false;
                }

                value = new RectangleF(
                    classInfo.GetSingle("x"),
                    classInfo.GetSingle("y"),
                    classInfo.GetSingle("width"),
                    classInfo.GetSingle("height"));

                return true;
            }
        }

        /// <summary>
        ///  Tries to get this object as a primitive type or string.
        /// </summary>
        /// <returns><see langword="true"/> if this represented a primitive type or string.</returns>
        public static bool TryGetPrimitiveType(this SerializationRecord record, [NotNullWhen(true)] out object? value)
        {
            return TryGet(Get, record, out value);

            static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? value)
            {
                if (record.RecordType is SerializationRecordType.BinaryObjectString)
                {
                    value = ((PrimitiveTypeRecord<string>)record).Value;
                    return true;
                }
                else if (record.RecordType is SerializationRecordType.MemberPrimitiveTyped)
                {
                    value = ((PrimitiveTypeRecord)record).Value;
                    return true;
                }

                value = null;
                return false;
            }
        }

        /// <summary>
        ///  Tries to get this object as a <see cref="List{T}"/> of <see cref="PrimitiveType"/>.
        /// </summary>
        /// <returns>
        ///  <see langword="true"/> if the record is a <c>System.Collections.Generic.List`1</c> whose backing
        ///  <c>_items</c> array is a single-dimensional primitive or string array and whose <c>_size</c> fits in it.
        /// </returns>
        public static bool TryGetPrimitiveList(this SerializationRecord record, [NotNullWhen(true)] out object? list)
        {
            return TryGet(Get, record, out list);

            static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? list)
            {
                list = null;

                if (record is not ClassRecord classInfo
                    || !classInfo.HasMember("_items")
                    || !classInfo.HasMember("_size")
                    || classInfo.GetRawValue("_size") is not int size
                    || !classInfo.TypeName.IsConstructedGenericType
                    // Compare the namespace-qualified name so a List`1 declared elsewhere is not mistaken for List<T>.
                    || classInfo.TypeName.GetGenericTypeDefinition().FullName != typeof(List<>).FullName
                    || classInfo.TypeName.GetGenericArguments().Length != 1
                    || classInfo.GetRawValue("_items") is not ArrayRecord arrayRecord
                    || !IsPrimitiveArrayRecord(arrayRecord)
                    // A corrupted _size outside the backing array bounds would make CreateTrimmedList throw
                    // ArgumentOutOfRangeException, which TryGet does not translate into a false result.
                    || size < 0
                    || size > arrayRecord.Lengths[0])
                {
                    return false;
                }

                // BinaryFormatter serializes the entire backing array, so we need to trim it down to the size of the list.
                list = arrayRecord switch
                {
                    SZArrayRecord<string> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<bool> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<byte> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<sbyte> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<char> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<short> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<ushort> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<int> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<uint> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<long> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<ulong> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<float> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<double> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<decimal> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<TimeSpan> ar => ar.GetArray().CreateTrimmedList(size),
                    SZArrayRecord<DateTime> ar => ar.GetArray().CreateTrimmedList(size),
                    _ => throw new InvalidOperationException()
                };

                return true;
            }
        }

        /// <summary>
        ///  Tries to get this object as a <see cref="ArrayList"/> of <see cref="PrimitiveType"/> values.
        /// </summary>
        /// <returns>
        ///  <see langword="true"/> if the record is an <see cref="ArrayList"/> whose first <c>_size</c> items are all
        ///  <see langword="null"/> or primitive records; <see langword="false"/> if any of them is a complex record.
        /// </returns>
        public static bool TryGetPrimitiveArrayList(this SerializationRecord record, [NotNullWhen(true)] out object? value)
        {
            return TryGet(Get, record, out value);

            static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? value)
            {
                value = null;

                if (record is not ClassRecord classInfo
                    || !classInfo.TypeNameMatches(typeof(ArrayList))
                    || !classInfo.HasMember("_items")
                    || !classInfo.HasMember("_size")
                    || classInfo.GetRawValue("_size") is not int size
                    || classInfo.GetRawValue("_items") is not SZArrayRecord<SerializationRecord> arrayRecord
                    // A negative _size would make the ArrayList constructor throw ArgumentOutOfRangeException,
                    // which TryGet does not translate into a false result.
                    || size < 0
                    || size > arrayRecord.Length)
                {
                    return false;
                }

                ArrayList arrayList = new(size);
                SerializationRecord?[] array = arrayRecord.GetArray();
                for (int i = 0; i < size; i++)
                {
                    SerializationRecord? elementRecord = array[i];
                    if (elementRecord is null)
                    {
                        arrayList.Add(null);
                    }
                    else if (elementRecord is PrimitiveTypeRecord primitiveTypeRecord)
                    {
                        arrayList.Add(primitiveTypeRecord.Value);
                    }
                    else
                    {
                        // It was a complex type (represented as a ClassRecord or an ArrayRecord)
                        return false;
                    }
                }

                value = arrayList;
                return true;
            }
        }

        /// <summary>
        ///  Tries to get this object as an <see cref="Array"/> of primitive types.
        /// </summary>
        /// <returns>
        ///  <see langword="true"/> if the record is a single-dimensional array of strings or of one of the
        ///  primitive element types listed below.
        /// </returns>
        public static bool TryGetPrimitiveArray(this SerializationRecord record, [NotNullWhen(true)] out object? value)
        {
            return TryGet(Get, record, out value);

            static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? value)
            {
                if (!IsPrimitiveArrayRecord(record))
                {
                    value = null;
                    return false;
                }

                value = record switch
                {
                    SZArrayRecord<string> ar => ar.GetArray(),
                    SZArrayRecord<bool> ar => ar.GetArray(),
                    SZArrayRecord<byte> ar => ar.GetArray(),
                    SZArrayRecord<sbyte> ar => ar.GetArray(),
                    SZArrayRecord<char> ar => ar.GetArray(),
                    SZArrayRecord<short> ar => ar.GetArray(),
                    SZArrayRecord<ushort> ar => ar.GetArray(),
                    SZArrayRecord<int> ar => ar.GetArray(),
                    SZArrayRecord<uint> ar => ar.GetArray(),
                    SZArrayRecord<long> ar => ar.GetArray(),
                    SZArrayRecord<ulong> ar => ar.GetArray(),
                    SZArrayRecord<float> ar => ar.GetArray(),
                    SZArrayRecord<double> ar => ar.GetArray(),
                    SZArrayRecord<decimal> ar => ar.GetArray(),
                    SZArrayRecord<TimeSpan> ar => ar.GetArray(),
                    SZArrayRecord<DateTime> ar => ar.GetArray(),
                    _ => throw new InvalidOperationException()
                };

                return value is not null;
            }
        }

        /// <summary>
        ///  Tries to get this object as a binary formatted <see cref="Hashtable"/> of <see cref="PrimitiveType"/> keys and values.
        /// </summary>
        /// <returns>
        ///  <see langword="true"/> if the record is a <see cref="Hashtable"/> with default key semantics (no comparer
        ///  or hash code provider), every key is a primitive record and every value is <see langword="null"/> or a
        ///  primitive record.
        /// </returns>
        public static bool TryGetPrimitiveHashtable(this SerializationRecord record, [NotNullWhen(true)] out object? hashtable)
        {
            return TryGet(Get, record, out hashtable);

            static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? hashtable)
            {
                hashtable = null;

                if (record.RecordType != SerializationRecordType.SystemClassWithMembersAndTypes
                    || record is not ClassRecord classInfo
                    || !classInfo.TypeNameMatches(typeof(Hashtable))
                    || !classInfo.HasMember("Keys")
                    || !classInfo.HasMember("Values")
                    // Hashtables created with a custom IEqualityComparer are (per Hashtable.GetObjectData) serialized
                    // with a "KeyComparer" member instead of "Comparer"; reject them instead of throwing on lookup.
                    || !classInfo.HasMember("Comparer")
                    // Note that hashtables with custom comparers and/or hash code providers will have a non null
                    // Comparer and/or HashCodeProvider; a default Hashtable would not reproduce their key semantics.
                    || classInfo.GetSerializationRecord("Comparer") is not null
                    || (classInfo.HasMember("HashCodeProvider") && classInfo.GetSerializationRecord("HashCodeProvider") is not null)
                    || classInfo.GetSerializationRecord("Keys") is not SZArrayRecord<SerializationRecord?> keysRecord
                    || classInfo.GetSerializationRecord("Values") is not SZArrayRecord<SerializationRecord?> valuesRecord
                    || keysRecord.Length != valuesRecord.Length)
                {
                    return false;
                }

                Hashtable temp = new(keysRecord.Length);
                SerializationRecord?[] keys = keysRecord.GetArray();
                SerializationRecord?[] values = valuesRecord.GetArray();
                for (int i = 0; i < keys.Length; i++)
                {
                    // Keys must be primitive; a null key fails this pattern as well.
                    if (keys[i] is not PrimitiveTypeRecord primitiveKey)
                    {
                        return false;
                    }

                    SerializationRecord? value = values[i];

                    if (value is null)
                    {
                        temp[primitiveKey.Value] = null; // null values are allowed
                    }
                    else if (value is PrimitiveTypeRecord primitiveValue)
                    {
                        temp[primitiveKey.Value] = primitiveValue.Value;
                    }
                    else
                    {
                        // It was a complex type (represented as a ClassRecord or an ArrayRecord)
                        return false;
                    }
                }

                hashtable = temp;
                return true;
            }
        }

        /// <summary>
        ///  Tries to get this object as a binary formatted <see cref="NotSupportedException"/>.
        /// </summary>
        /// <returns>
        ///  <see langword="true"/> if the record is a <see cref="NotSupportedException"/> class record; the result
        ///  is a new exception carrying only the serialized <c>Message</c>.
        /// </returns>
        public static bool TryGetNotSupportedException(
            this SerializationRecord record,
            [NotNullWhen(true)] out object? exception)
        {
            return TryGet(Get, record, out exception);

            static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? exception)
            {
                exception = null;

                if (record is not ClassRecord classInfo
                    || !classInfo.TypeNameMatches(typeof(NotSupportedException)))
                {
                    return false;
                }

                exception = new NotSupportedException(classInfo.GetString("Message"));
                return true;
            }
        }

        /// <summary>
        ///  Try to get a supported .NET type object (not WinForms).
        /// </summary>
        /// <remarks>
        ///  <para>
        ///   Attempts, in order: primitive/string, <see cref="List{T}"/> of primitives, primitive array,
        ///   <see cref="ArrayList"/> of primitives, <see cref="Hashtable"/> of primitives, <see cref="RectangleF"/>,
        ///   <see cref="PointF"/> and <see cref="NotSupportedException"/>. The first successful reader wins.
        ///  </para>
        /// </remarks>
        public static bool TryGetFrameworkObject(
            this SerializationRecord record,
            [NotNullWhen(true)] out object? value)
            => record.TryGetPrimitiveType(out value)
                || record.TryGetPrimitiveList(out value)
                || record.TryGetPrimitiveArray(out value)
                || record.TryGetPrimitiveArrayList(out value)
                || record.TryGetPrimitiveHashtable(out value)
                || record.TryGetRectangleF(out value)
                || record.TryGetPointF(out value)
                || record.TryGetNotSupportedException(out value);

        /// <summary>
        ///  Returns <see langword="true"/> if <paramref name="serializationRecord"/> is a single-dimensional array of
        ///  strings (<see cref="SerializationRecordType.ArraySingleString"/>) or of a primitive type
        ///  (<see cref="SerializationRecordType.ArraySinglePrimitive"/>).
        /// </summary>
        private static bool IsPrimitiveArrayRecord(SerializationRecord serializationRecord)
            => serializationRecord.RecordType is SerializationRecordType.ArraySingleString or SerializationRecordType.ArraySinglePrimitive;

        /// <summary>
        ///  Creates a list containing the first <paramref name="count"/> items of <paramref name="readOnlyList"/>.
        /// </summary>
        /// <remarks>
        ///  <para>
        ///   This is an optimized implementation that avoids iterating over the entire list when possible.
        ///  </para>
        /// </remarks>
        /// <exception cref="ArgumentOutOfRangeException">
        ///  <paramref name="count"/> is negative or greater than the number of items in <paramref name="readOnlyList"/>.
        /// </exception>
        internal static List<T> CreateTrimmedList<T>(this IReadOnlyList<T> readOnlyList, int count)
        {
            ArgumentOutOfRangeException.ThrowIfNegative(count);
            ArgumentOutOfRangeException.ThrowIfGreaterThan(count, readOnlyList.Count);

            // List<T> will use ICollection<T>.CopyTo if it's available, which is faster than iterating over the list.
            // If we just have an array this can be done easily with ArraySegment<T>.
            if (readOnlyList is T[] array)
            {
                return new List<T>(new ArraySegment<T>(array, 0, count));
            }

            // Otherwise copy only the items we keep, rather than copying everything and removing the tail.
            List<T> list = new(count);
            for (int i = 0; i < count; i++)
            {
                list.Add(readOnlyList[i]);
            }

            return list;
        }
    }
}