// File: System\Formats\Nrbf\Utils\TypeNameHelpers.cs
// Project: src\src\libraries\System.Formats.Nrbf\src\System.Formats.Nrbf.csproj (System.Formats.Nrbf)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Diagnostics;
using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Runtime.Serialization;
using System.Text;
 
namespace System.Formats.Nrbf.Utils;
 
/// <summary>
/// Helpers for producing and parsing <see cref="TypeName"/> instances for NRBF records,
/// including the CoreLib-qualified names of primitive types.
/// </summary>
internal static class TypeNameHelpers
{
    // PrimitiveType has no named members for String, Object, IntPtr or UIntPtr,
    // so these out-of-band values stand in for them internally.
    internal const PrimitiveType StringPrimitiveType = (PrimitiveType)18;
    internal const PrimitiveType ObjectPrimitiveType = (PrimitiveType)19;
    internal const PrimitiveType IntPtrPrimitiveType = (PrimitiveType)20;
    internal const PrimitiveType UIntPtrPrimitiveType = (PrimitiveType)21;

    // Caches indexed by (int)PrimitiveType, populated lazily (without locking) on first request.
    private static readonly TypeName?[] s_primitiveTypeNames = new TypeName?[(int)UIntPtrPrimitiveType + 1];
    private static readonly TypeName?[] s_primitiveSZArrayTypeNames = new TypeName?[(int)UIntPtrPrimitiveType + 1];
    private static AssemblyNameInfo? s_coreLibAssemblyName;

    /// <summary>
    /// Returns the <see cref="TypeName"/> of <paramref name="primitiveType"/>, qualified with the
    /// CoreLib (mscorlib) assembly name. The result is cached for subsequent calls.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="primitiveType"/> does not map to a supported type.
    /// </exception>
    internal static TypeName GetPrimitiveTypeName(PrimitiveType primitiveType)
    {
        // Reject values outside the cache with the same exception the switch below uses for
        // unsupported values, instead of letting the array access throw IndexOutOfRangeException.
        if ((uint)primitiveType >= (uint)s_primitiveTypeNames.Length)
        {
            throw new InvalidOperationException();
        }

        TypeName? typeName = s_primitiveTypeNames[(int)primitiveType];
        if (typeName is null)
        {
            string fullName = primitiveType switch
            {
                PrimitiveType.Boolean => "System.Boolean",
                PrimitiveType.Byte => "System.Byte",
                PrimitiveType.SByte => "System.SByte",
                PrimitiveType.Char => "System.Char",
                PrimitiveType.Int16 => "System.Int16",
                PrimitiveType.UInt16 => "System.UInt16",
                PrimitiveType.Int32 => "System.Int32",
                PrimitiveType.UInt32 => "System.UInt32",
                PrimitiveType.Int64 => "System.Int64",
                PrimitiveType.UInt64 => "System.UInt64",
                PrimitiveType.Single => "System.Single",
                PrimitiveType.Double => "System.Double",
                PrimitiveType.Decimal => "System.Decimal",
                PrimitiveType.TimeSpan => "System.TimeSpan",
                PrimitiveType.DateTime => "System.DateTime",
                StringPrimitiveType => "System.String",
                ObjectPrimitiveType => "System.Object",
                IntPtrPrimitiveType => "System.IntPtr",
                UIntPtrPrimitiveType => "System.UIntPtr",
                _ => throw new InvalidOperationException()
            };

            s_primitiveTypeNames[(int)primitiveType] = typeName = TypeName.Parse(fullName.AsSpan()).WithCoreLibAssemblyName();
        }
        return typeName;
    }

    /// <summary>
    /// Returns the single-dimensional, zero-based array type name (<c>T[]</c>) whose element type is
    /// <paramref name="primitiveType"/>. The result is cached for subsequent calls.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="primitiveType"/> does not map to a supported type.
    /// </exception>
    internal static TypeName GetPrimitiveSZArrayTypeName(PrimitiveType primitiveType)
    {
        // Same guard as GetPrimitiveTypeName: the cache is indexed before that method is reached.
        if ((uint)primitiveType >= (uint)s_primitiveSZArrayTypeNames.Length)
        {
            throw new InvalidOperationException();
        }

        TypeName? typeName = s_primitiveSZArrayTypeNames[(int)primitiveType];
        if (typeName is null)
        {
            s_primitiveSZArrayTypeNames[(int)primitiveType] = typeName = GetPrimitiveTypeName(primitiveType).MakeSZArrayTypeName();
        }
        return typeName;
    }

    /// <summary>
    /// Maps the generic argument <typeparamref name="T"/> to its <see cref="PrimitiveType"/> value
    /// (or one of the out-of-band constants above for string, IntPtr and UIntPtr).
    /// Note that <see cref="object"/> is not mapped here.
    /// </summary>
    /// <exception cref="InvalidOperationException"><typeparamref name="T"/> is not a supported type.</exception>
    internal static PrimitiveType GetPrimitiveType<T>()
    {
        if (typeof(T) == typeof(bool))
            return PrimitiveType.Boolean;
        else if (typeof(T) == typeof(byte))
            return PrimitiveType.Byte;
        else if (typeof(T) == typeof(sbyte))
            return PrimitiveType.SByte;
        else if (typeof(T) == typeof(char))
            return PrimitiveType.Char;
        else if (typeof(T) == typeof(short))
            return PrimitiveType.Int16;
        else if (typeof(T) == typeof(ushort))
            return PrimitiveType.UInt16;
        else if (typeof(T) == typeof(int))
            return PrimitiveType.Int32;
        else if (typeof(T) == typeof(uint))
            return PrimitiveType.UInt32;
        else if (typeof(T) == typeof(long))
            return PrimitiveType.Int64;
        else if (typeof(T) == typeof(ulong))
            return PrimitiveType.UInt64;
        else if (typeof(T) == typeof(float))
            return PrimitiveType.Single;
        else if (typeof(T) == typeof(double))
            return PrimitiveType.Double;
        else if (typeof(T) == typeof(decimal))
            return PrimitiveType.Decimal;
        else if (typeof(T) == typeof(DateTime))
            return PrimitiveType.DateTime;
        else if (typeof(T) == typeof(TimeSpan))
            return PrimitiveType.TimeSpan;
        else if (typeof(T) == typeof(string))
            return StringPrimitiveType;
        else if (typeof(T) == typeof(IntPtr))
            return IntPtrPrimitiveType;
        else if (typeof(T) == typeof(UIntPtr))
            return UIntPtrPrimitiveType;
        else
            throw new InvalidOperationException();
    }

    /// <summary>
    /// Parses the type name of a non-system class record and binds it to the assembly described
    /// by <paramref name="libraryRecord"/>.
    /// </summary>
    /// <remarks>
    /// When the library record has a parsed <c>LibraryName</c>, <paramref name="rawName"/> must not
    /// contain an assembly name itself. Otherwise (only expected when
    /// <c>PayloadOptions.UndoTruncatedTypeNames</c> is enabled) the raw type and library names are
    /// concatenated and parsed together; if no assembly name results, CoreLib is assumed.
    /// </remarks>
    /// <exception cref="SerializationException">The type or assembly name is invalid.</exception>
    internal static TypeName ParseNonSystemClassRecordTypeName(this string rawName, BinaryLibraryRecord libraryRecord, PayloadOptions payloadOptions)
    {
        if (libraryRecord.LibraryName is not null)
        {
            return ParseWithoutAssemblyName(rawName, payloadOptions).With(libraryRecord.LibraryName);
        }

        Debug.Assert(payloadOptions.UndoTruncatedTypeNames);
        Debug.Assert(libraryRecord.RawLibraryName is not null);

        // This is potentially a DoS vector, as somebody could submit:
        // [1] BinaryLibraryRecord = <really long string>
        // [2] ClassRecord (lib = [1])
        // [3] ClassRecord (lib = [1])
        // ...
        // [n] ClassRecord (lib = [1])
        //
        // Which means somebody submits a payload of length O(long + n) and tricks us into
        // performing O(long * n) work. For this reason, we have marked the UndoTruncatedTypeNames
        // property as "keep this disabled unless you trust the input."

        // Combining the type and library names lets us handle truncated generic type names that may be present in resources.
        ArraySegment<char> assemblyQualifiedName = RentAssemblyQualifiedName(rawName, libraryRecord.RawLibraryName);
        TypeName? typeName;
        try
        {
            TypeName.TryParse(assemblyQualifiedName.AsSpan(), out typeName, payloadOptions.TypeNameParseOptions);
        }
        finally
        {
            // Always hand the rented buffer back, even if parsing throws unexpectedly.
            // The parsed TypeName does not reference the span, so returning it here is safe.
            ArrayPool<char>.Shared.Return(assemblyQualifiedName.Array!);
        }

        if (typeName is null)
        {
            throw new SerializationException(SR.Serialization_InvalidTypeOrAssemblyName);
        }

        if (typeName.AssemblyName is null)
        {
            // Sample invalid input that could lead us here:
            // TypeName: System.Collections.Generic.List`1[[System.String
            // LibraryName: 'mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089]]
            // Since the flag is ON, we know it's mangling and we provide missing information.
            typeName = typeName.WithCoreLibAssemblyName();
        }

        return typeName;
    }

    /// <summary>
    /// Parses the type name of a system record (which must not carry an assembly name) and binds it to CoreLib.
    /// </summary>
    /// <exception cref="SerializationException"><paramref name="rawName"/> is invalid or contains an assembly name.</exception>
    internal static TypeName ParseSystemRecordTypeName(this string rawName, PayloadOptions payloadOptions)
        => ParseWithoutAssemblyName(rawName, payloadOptions)
                .WithCoreLibAssemblyName(); // We know it's a System Record, so we set the LibraryName to CoreLib

    /// <summary>
    /// Binds <paramref name="systemType"/> to the mscorlib 4.0.0.0 assembly identity; the parsed
    /// <see cref="AssemblyNameInfo"/> is created lazily and reused.
    /// </summary>
    internal static TypeName WithCoreLibAssemblyName(this TypeName systemType)
        => systemType.With(s_coreLibAssemblyName ??= AssemblyNameInfo.Parse("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089".AsSpan()));

    /// <summary>
    /// Returns a copy of <paramref name="typeName"/> bound to <paramref name="assemblyName"/>.
    /// Array element types and generic type definitions are rewritten recursively; generic
    /// arguments keep their original assembly names. Other non-simple names (pointers and
    /// references) are rejected through <c>ThrowHelper.ThrowInvalidTypeName</c>.
    /// </summary>
    private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyName)
    {
        // This is a recursive method over potentially hostile TypeName arguments.
        // We assume the complexity of the TypeName arg was appropriately bounded.
        // See comment in TypeName.FullName property getter for more info.

        if (!typeName.IsSimple)
        {
            if (typeName.IsArray)
            {
                TypeName newElementType = typeName.GetElementType().With(assemblyName);

                return typeName.IsSZArray
                    ? newElementType.MakeSZArrayTypeName()
                    : newElementType.MakeArrayTypeName(typeName.GetArrayRank());
            }
            else if (typeName.IsConstructedGenericType)
            {
                TypeName newGenericTypeDefinition = typeName.GetGenericTypeDefinition().With(assemblyName);

                // We don't change the assembly name of generic arguments on purpose.
                return newGenericTypeDefinition.MakeGenericTypeName(typeName.GetGenericArguments());
            }
            else
            {
                // BinaryFormatter can not serialize pointers or references.
                ThrowHelper.ThrowInvalidTypeName();
            }
        }

        return typeName.WithAssemblyName(assemblyName);
    }

    /// <summary>
    /// Parses <paramref name="rawName"/>, requiring that it carries no assembly name
    /// (the type and library names should be provided separately).
    /// </summary>
    /// <exception cref="SerializationException">Parsing fails or an assembly name is present.</exception>
    private static TypeName ParseWithoutAssemblyName(string rawName, PayloadOptions payloadOptions)
    {
        if (!TypeName.TryParse(rawName.AsSpan(), out TypeName? typeName, payloadOptions.TypeNameParseOptions)
            || typeName.AssemblyName is not null) // the type and library names should be provided separately
        {
            throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, rawName));
        }

        return typeName;
    }

    /// <summary>
    /// Writes <c>"{typeName},{libraryName}"</c> into a buffer rented from <see cref="ArrayPool{T}.Shared"/>.
    /// The caller must return the segment's underlying array to the pool.
    /// </summary>
    // Complexity is O(typeName.Length + libraryName.Length)
    private static ArraySegment<char> RentAssemblyQualifiedName(string typeName, string libraryName)
    {
        int length = typeName.Length + 1 + libraryName.Length;

        // Rent may return a larger array; the segment exposes only the first 'length' chars.
        char[] rented = ArrayPool<char>.Shared.Rent(length);

        typeName.AsSpan().CopyTo(rented);
        rented[typeName.Length] = ',';
        libraryName.AsSpan().CopyTo(rented.AsSpan(typeName.Length + 1));

        return new ArraySegment<char>(rented, 0, length);
    }
}