// File: System\Text\Json\Serialization\Metadata\FSharpCoreReflectionProxy.cs
// Project: src\src\libraries\System.Text.Json\src\System.Text.Json.csproj (System.Text.Json)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Text.Json.Serialization;
 
namespace System.Text.Json.Serialization.Metadata
{
    // Recognizing types emitted by the F# compiler requires consuming APIs from the FSharp.Core runtime library.
    // Every F# application ships with a copy of FSharp.Core, however it is not available statically to System.Text.Json.
    // The following class uses reflection to access the relevant APIs required to detect the various F# types we are looking to support.
 
    /// <summary>
    /// Proxy class used to access FSharp.Core metadata and reflection APIs that are not statically available to System.Text.Json.
    /// </summary>
    internal sealed class FSharpCoreReflectionProxy
    {
        /// <summary>
        /// The various categories of F# types that System.Text.Json supports.
        /// </summary>
        public enum FSharpKind
        {
            Unrecognized,
            Option,
            ValueOption,
            List,
            Set,
            Map,
            Record,
            Union
        }

        // Binding a struct getter method to a delegate requires that the struct parameter is passed byref.
        public delegate TResult StructGetter<TStruct, TResult>(ref TStruct @this) where TStruct : struct;

        public const string FSharpCoreUnreferencedCodeMessage = "Uses Reflection to access FSharp.Core components at runtime.";

        // Created by the first IsFSharpType call that recognizes an F# type; read back through Instance.
        private static FSharpCoreReflectionProxy? s_singletonInstance;

        // Every type generated by the F# compiler is annotated with the CompilationMappingAttribute
        // containing all relevant metadata required to determine its kind:
        // https://fsharp.github.io/fsharp-core-docs/reference/fsharp-core-compilationmappingattribute.html#SourceConstructFlags
        private const string CompilationMappingAttributeTypeName = "Microsoft.FSharp.Core.CompilationMappingAttribute";
        private readonly Type _compilationMappingAttributeType;
        private readonly MethodInfo? _sourceConstructFlagsGetter;

        private readonly Type? _fsharpOptionType;
        private readonly Type? _fsharpValueOptionType;
        private readonly Type? _fsharpListType;
        private readonly Type? _fsharpSetType;
        private readonly Type? _fsharpMapType;

        private readonly MethodInfo? _fsharpListCtor;
        private readonly MethodInfo? _fsharpSetCtor;
        private readonly MethodInfo? _fsharpMapCtor;

        // Union-related reflection members
        private readonly MethodInfo? _getUnionCases;
        private readonly MethodInfo? _preComputeUnionTagReader;
        private readonly MethodInfo? _preComputeUnionReader;
        private readonly MethodInfo? _preComputeUnionConstructor;
        private readonly MethodInfo? _unionCaseInfoNameGetter;
        private readonly MethodInfo? _unionCaseInfoTagGetter;
        private readonly MethodInfo? _unionCaseInfoGetFields;
        private readonly MethodInfo? _unionCaseInfoGetCustomAttributes;

        /// <summary>
        /// Checks if the provided System.Type instance is emitted by the F# compiler.
        /// If true, also initializes the proxy singleton for future use by other F# types.
        /// </summary>
        /// <param name="type">The type to inspect.</param>
        /// <returns>
        /// <see langword="true"/> if <paramref name="type"/> carries (directly or through inheritance)
        /// FSharp.Core's CompilationMappingAttribute; otherwise <see langword="false"/>.
        /// </returns>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public static bool IsFSharpType(Type type)
        {
            if (s_singletonInstance is null)
            {
                // No FSharp.Core assembly is known yet: match the attribute by name, because its Type is
                // only reachable through the assembly that defines it.
                if (GetFSharpCoreAssembly(type) is Assembly fsharpCoreAssembly)
                {
                    // Type is F# type, initialize the singleton instance.
                    s_singletonInstance ??= new FSharpCoreReflectionProxy(fsharpCoreAssembly);

                    return true;
                }

                return false;
            }

            return s_singletonInstance.GetFSharpCompilationMappingAttribute(type) is not null;
        }

        /// <summary>
        /// Gets the singleton proxy instance; requires a prior successful IsFSharpType call to have initialized it.
        /// </summary>
        public static FSharpCoreReflectionProxy Instance
        {
            get
            {
                Debug.Assert(s_singletonInstance is not null, "should be initialized via a successful IsFSharpType call.");
                return s_singletonInstance;
            }
        }

        /// <summary>
        /// Resolves and caches the FSharp.Core types and methods used by this proxy.
        /// Lookups other than CompilationMappingAttribute are allowed to fail (stored as null);
        /// the consuming methods report the missing member via EnsureMemberExists.
        /// </summary>
        /// <param name="fsharpCoreAssembly">The FSharp.Core assembly loaded by the application.</param>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        private FSharpCoreReflectionProxy(Assembly fsharpCoreAssembly)
        {
            Debug.Assert(fsharpCoreAssembly.GetName().Name == "FSharp.Core");

            Type compilationMappingAttributeType = fsharpCoreAssembly.GetType(CompilationMappingAttributeTypeName)!;
            _sourceConstructFlagsGetter = compilationMappingAttributeType.GetMethod("get_SourceConstructFlags", BindingFlags.Public | BindingFlags.Instance);
            _compilationMappingAttributeType = compilationMappingAttributeType;

            _fsharpOptionType = fsharpCoreAssembly.GetType("Microsoft.FSharp.Core.FSharpOption`1");
            _fsharpValueOptionType = fsharpCoreAssembly.GetType("Microsoft.FSharp.Core.FSharpValueOption`1");
            _fsharpListType = fsharpCoreAssembly.GetType("Microsoft.FSharp.Collections.FSharpList`1");
            _fsharpSetType = fsharpCoreAssembly.GetType("Microsoft.FSharp.Collections.FSharpSet`1");
            _fsharpMapType = fsharpCoreAssembly.GetType("Microsoft.FSharp.Collections.FSharpMap`2");

            _fsharpListCtor = fsharpCoreAssembly.GetType("Microsoft.FSharp.Collections.ListModule")?.GetMethod("OfSeq", BindingFlags.Public | BindingFlags.Static);
            _fsharpSetCtor = fsharpCoreAssembly.GetType("Microsoft.FSharp.Collections.SetModule")?.GetMethod("OfSeq", BindingFlags.Public | BindingFlags.Static);
            _fsharpMapCtor = fsharpCoreAssembly.GetType("Microsoft.FSharp.Collections.MapModule")?.GetMethod("OfSeq", BindingFlags.Public | BindingFlags.Static);

            // Union reflection APIs from Microsoft.FSharp.Reflection namespace
            Type? fsharpType = fsharpCoreAssembly.GetType("Microsoft.FSharp.Reflection.FSharpType");
            Type? fsharpValue = fsharpCoreAssembly.GetType("Microsoft.FSharp.Reflection.FSharpValue");
            Type? unionCaseInfoType = fsharpCoreAssembly.GetType("Microsoft.FSharp.Reflection.UnionCaseInfo");

            _getUnionCases = fsharpType?.GetMethod("GetUnionCases", BindingFlags.Public | BindingFlags.Static);
            _preComputeUnionTagReader = fsharpValue?.GetMethod("PreComputeUnionTagReader", BindingFlags.Public | BindingFlags.Static);

            if (unionCaseInfoType is not null)
            {
                _preComputeUnionReader = fsharpValue?.GetMethod("PreComputeUnionReader", BindingFlags.Public | BindingFlags.Static);
                _preComputeUnionConstructor = fsharpValue?.GetMethod("PreComputeUnionConstructor", BindingFlags.Public | BindingFlags.Static);
                _unionCaseInfoNameGetter = unionCaseInfoType.GetMethod("get_Name", BindingFlags.Public | BindingFlags.Instance);
                _unionCaseInfoTagGetter = unionCaseInfoType.GetMethod("get_Tag", BindingFlags.Public | BindingFlags.Instance);
                _unionCaseInfoGetFields = unionCaseInfoType.GetMethod("GetFields", BindingFlags.Public | BindingFlags.Instance, binder: null, Type.EmptyTypes, modifiers: null);
                _unionCaseInfoGetCustomAttributes = unionCaseInfoType.GetMethod("GetCustomAttributes", BindingFlags.Public | BindingFlags.Instance, binder: null, Type.EmptyTypes, modifiers: null);
            }
        }

        /// <summary>
        /// Classifies an F# type. Option/ValueOption/List/Set/Map are matched by generic type definition;
        /// otherwise records and unions are identified from the CompilationMappingAttribute's SourceConstructFlags.
        /// </summary>
        /// <param name="type">The type to classify.</param>
        /// <returns>The detected kind, or <see cref="FSharpKind.Unrecognized"/>.</returns>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public FSharpKind DetectFSharpKind(Type type)
        {
            Attribute? compilationMappingAttribute = GetFSharpCompilationMappingAttribute(type);

            if (compilationMappingAttribute is null)
            {
                return FSharpKind.Unrecognized;
            }

            if (type.IsGenericType)
            {
                Type genericType = type.GetGenericTypeDefinition();
                if (genericType == _fsharpOptionType) return FSharpKind.Option;
                if (genericType == _fsharpValueOptionType) return FSharpKind.ValueOption;
                if (genericType == _fsharpListType) return FSharpKind.List;
                if (genericType == _fsharpSetType) return FSharpKind.Set;
                if (genericType == _fsharpMapType) return FSharpKind.Map;
            }

            return (GetSourceConstructFlags(compilationMappingAttribute) & SourceConstructFlags.KindMask) switch
            {
                SourceConstructFlags.RecordType => FSharpKind.Record,
                SourceConstructFlags.SumType => FSharpKind.Union,
                _ => FSharpKind.Unrecognized
            };
        }

        /// <summary>
        /// Gets the union case metadata for the specified F# discriminated union type.
        /// Returns an array of case descriptors with pre-computed delegates for
        /// field reading and case construction (tag reading is provided by <see cref="CreateUnionTagReader"/>).
        /// </summary>
        /// <param name="unionType">The F# union type.</param>
        /// <returns>One <see cref="FSharpUnionCaseInfo"/> per case, in the order returned by FSharpType.GetUnionCases.</returns>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public FSharpUnionCaseInfo[] GetUnionCaseInfos(Type unionType)
        {
            MethodInfo getUnionCases = EnsureMemberExists(_getUnionCases, "Microsoft.FSharp.Reflection.FSharpType.GetUnionCases(Type, BindingFlags?)");
            MethodInfo preComputeUnionReader = EnsureMemberExists(_preComputeUnionReader, "Microsoft.FSharp.Reflection.FSharpValue.PreComputeUnionReader(UnionCaseInfo, BindingFlags?)");
            MethodInfo preComputeUnionConstructor = EnsureMemberExists(_preComputeUnionConstructor, "Microsoft.FSharp.Reflection.FSharpValue.PreComputeUnionConstructor(UnionCaseInfo, BindingFlags?)");
            MethodInfo nameGetter = EnsureMemberExists(_unionCaseInfoNameGetter, "Microsoft.FSharp.Reflection.UnionCaseInfo.get_Name()");
            MethodInfo tagGetter = EnsureMemberExists(_unionCaseInfoTagGetter, "Microsoft.FSharp.Reflection.UnionCaseInfo.get_Tag()");
            MethodInfo getFields = EnsureMemberExists(_unionCaseInfoGetFields, "Microsoft.FSharp.Reflection.UnionCaseInfo.GetFields()");
            MethodInfo getCustomAttributes = EnsureMemberExists(_unionCaseInfoGetCustomAttributes, "Microsoft.FSharp.Reflection.UnionCaseInfo.GetCustomAttributes()");

            // FSharpType.GetUnionCases(type, bindingFlags: null) returns UnionCaseInfo[]
            // (the null argument is passed for the optional BindingFlags parameter).
            object[] cases = (object[])getUnionCases.Invoke(null, new object?[] { unionType, null })!;
            var result = new FSharpUnionCaseInfo[cases.Length];

            for (int i = 0; i < cases.Length; i++)
            {
                object caseInfo = cases[i];
                string caseName = (string)nameGetter.Invoke(caseInfo, null)!;
                int caseTag = (int)tagGetter.Invoke(caseInfo, null)!;
                PropertyInfo[] fields = (PropertyInfo[])getFields.Invoke(caseInfo, null)!;

                // Read custom attributes to check for JsonPropertyNameAttribute; the first one found wins.
                object[] customAttributes = (object[])getCustomAttributes.Invoke(caseInfo, null)!;
                string? jsonPropertyName = null;
                foreach (object attr in customAttributes)
                {
                    if (attr is JsonPropertyNameAttribute jpn)
                    {
                        jsonPropertyName = jpn.Name;
                        break;
                    }
                }

                // PreComputeUnionReader returns FSharpFunc<obj, obj[]>
                Func<object, object[]> fieldReader = ConvertFSharpFunc<object, object[]>(
                    preComputeUnionReader.Invoke(null, new object?[] { caseInfo, null })!);

                // PreComputeUnionConstructor returns FSharpFunc<obj[], obj>
                Func<object[], object> constructor = ConvertFSharpFunc<object[], object>(
                    preComputeUnionConstructor.Invoke(null, new object?[] { caseInfo, null })!);

                result[i] = new FSharpUnionCaseInfo(caseName, caseTag, fields, jsonPropertyName, fieldReader, constructor);
            }

            return result;
        }

        /// <summary>
        /// Creates a tag reader delegate for the specified F# union type.
        /// The delegate takes a union value (boxed) and returns the integer tag.
        /// </summary>
        /// <param name="unionType">The F# union type.</param>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public Func<object, int> CreateUnionTagReader(Type unionType)
        {
            MethodInfo preComputeUnionTagReader = EnsureMemberExists(_preComputeUnionTagReader, "Microsoft.FSharp.Reflection.FSharpValue.PreComputeUnionTagReader(Type, BindingFlags?)");

            // PreComputeUnionTagReader returns FSharpFunc<obj, int>
            return ConvertFSharpFunc<object, int>(
                preComputeUnionTagReader.Invoke(null, new object?[] { unionType, null })!);
        }

        /// <summary>
        /// Creates a delegate bound to the <c>get_Value</c> instance method of the closed FSharpOption type.
        /// </summary>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public Func<TFSharpOption, T> CreateFSharpOptionValueGetter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)] TFSharpOption, T>()
        {
            Debug.Assert(typeof(TFSharpOption).GetGenericTypeDefinition() == _fsharpOptionType);
            MethodInfo valueGetter = EnsureMemberExists(typeof(TFSharpOption).GetMethod("get_Value", BindingFlags.Public | BindingFlags.Instance), "Microsoft.FSharp.Core.FSharpOption<T>.get_Value()");
            return CreateDelegate<Func<TFSharpOption, T>>(valueGetter);
        }

        /// <summary>
        /// Creates a delegate bound to the static <c>Some</c> method of the closed FSharpOption type.
        /// </summary>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public Func<TElement?, TFSharpOption> CreateFSharpOptionSomeConstructor<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)] TFSharpOption, TElement>()
        {
            Debug.Assert(typeof(TFSharpOption).GetGenericTypeDefinition() == _fsharpOptionType);
            MethodInfo methodInfo = EnsureMemberExists(typeof(TFSharpOption).GetMethod("Some", BindingFlags.Public | BindingFlags.Static), "Microsoft.FSharp.Core.FSharpOption<T>.Some(T value)");
            return CreateDelegate<Func<TElement?, TFSharpOption>>(methodInfo);
        }

        /// <summary>
        /// Creates a by-ref delegate bound to the <c>get_Value</c> instance method of the closed FSharpValueOption struct.
        /// </summary>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public StructGetter<TFSharpValueOption, TElement> CreateFSharpValueOptionValueGetter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)] TFSharpValueOption, TElement>()
            where TFSharpValueOption : struct
        {
            Debug.Assert(typeof(TFSharpValueOption).GetGenericTypeDefinition() == _fsharpValueOptionType);
            MethodInfo valueGetter = EnsureMemberExists(typeof(TFSharpValueOption).GetMethod("get_Value", BindingFlags.Public | BindingFlags.Instance), "Microsoft.FSharp.Core.FSharpValueOption<T>.get_Value()");
            return CreateDelegate<StructGetter<TFSharpValueOption, TElement>>(valueGetter);
        }

        /// <summary>
        /// Creates a delegate bound to the static <c>Some</c> method of the closed FSharpValueOption type.
        /// </summary>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public Func<TElement?, TFSharpOption> CreateFSharpValueOptionSomeConstructor<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods)] TFSharpOption, TElement>()
        {
            Debug.Assert(typeof(TFSharpOption).GetGenericTypeDefinition() == _fsharpValueOptionType);
            // The member looked up is the static "Some" method, so the diagnostic names that method.
            MethodInfo methodInfo = EnsureMemberExists(typeof(TFSharpOption).GetMethod("Some", BindingFlags.Public | BindingFlags.Static), "Microsoft.FSharp.Core.FSharpValueOption<T>.Some(T value)");
            return CreateDelegate<Func<TElement?, TFSharpOption>>(methodInfo);
        }

        /// <summary>
        /// Creates a delegate that builds an F# list from a sequence via ListModule.OfSeq.
        /// </summary>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public Func<IEnumerable<TElement>, TFSharpList> CreateFSharpListConstructor<TFSharpList, TElement>()
        {
            Debug.Assert(typeof(TFSharpList).GetGenericTypeDefinition() == _fsharpListType);
            return CreateDelegate<Func<IEnumerable<TElement>, TFSharpList>>(EnsureMemberExists(_fsharpListCtor, "Microsoft.FSharp.Collections.ListModule.OfSeq<T>(IEnumerable<T> source)").MakeGenericMethod(typeof(TElement)));
        }

        /// <summary>
        /// Creates a delegate that builds an F# set from a sequence via SetModule.OfSeq.
        /// </summary>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public Func<IEnumerable<TElement>, TFSharpSet> CreateFSharpSetConstructor<TFSharpSet, TElement>()
        {
            Debug.Assert(typeof(TFSharpSet).GetGenericTypeDefinition() == _fsharpSetType);
            return CreateDelegate<Func<IEnumerable<TElement>, TFSharpSet>>(EnsureMemberExists(_fsharpSetCtor, "Microsoft.FSharp.Collections.SetModule.OfSeq<T>(IEnumerable<T> source)").MakeGenericMethod(typeof(TElement)));
        }

        /// <summary>
        /// Creates a delegate that builds an F# map from a sequence of key/value tuples via MapModule.OfSeq.
        /// </summary>
        [RequiresUnreferencedCode(FSharpCoreUnreferencedCodeMessage)]
        [RequiresDynamicCode(FSharpCoreUnreferencedCodeMessage)]
        public Func<IEnumerable<Tuple<TKey, TValue>>, TFSharpMap> CreateFSharpMapConstructor<TFSharpMap, TKey, TValue>()
        {
            Debug.Assert(typeof(TFSharpMap).GetGenericTypeDefinition() == _fsharpMapType);
            return CreateDelegate<Func<IEnumerable<Tuple<TKey, TValue>>, TFSharpMap>>(EnsureMemberExists(_fsharpMapCtor, "Microsoft.FSharp.Collections.MapModule.OfSeq<TKey, TValue>(IEnumerable<Tuple<TKey, TValue>> source)").MakeGenericMethod(typeof(TKey), typeof(TValue)));
        }

        // Returns the first CompilationMappingAttribute (including inherited ones) on the type, or null.
        private Attribute? GetFSharpCompilationMappingAttribute(Type type)
        {
            object[] attributes = type.GetCustomAttributes(_compilationMappingAttributeType, inherit: true);
            return attributes.Length == 0 ? null : (Attribute)attributes[0];
        }

        // Reads CompilationMappingAttribute.SourceConstructFlags; None if the getter could not be resolved.
        private SourceConstructFlags GetSourceConstructFlags(Attribute compilationMappingAttribute)
            => _sourceConstructFlagsGetter is null ? SourceConstructFlags.None : (SourceConstructFlags)_sourceConstructFlagsGetter.Invoke(compilationMappingAttribute, null)!;

        // If the provided type is generated by the F# compiler, returns the runtime FSharp.Core assembly.
        private static Assembly? GetFSharpCoreAssembly(Type type)
        {
            foreach (Attribute attr in type.GetCustomAttributes(inherit: true))
            {
                Type attributeType = attr.GetType();
                if (attributeType.FullName == CompilationMappingAttributeTypeName)
                {
                    return attributeType.Assembly;
                }
            }

            return null;
        }

        private static TDelegate CreateDelegate<TDelegate>(MethodInfo methodInfo) where TDelegate : Delegate
            => (TDelegate)Delegate.CreateDelegate(typeof(TDelegate), methodInfo, throwOnBindFailure: true)!;

        // Converts an FSharpFunc<TArg, TResult> (which is not statically known) into a Func<TArg, TResult>.
        [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2075:GetMethod",
            Justification = "FSharpFunc<TArg, TResult>.Invoke is always available. Callers are marked RequiresUnreferencedCode.")]
        private static Func<TArg, TResult> ConvertFSharpFunc<TArg, TResult>(object fsharpFunc)
        {
            // FSharpFunc<TArg, TResult> has an Invoke(TArg) method. Looking it up by its parameter type rather
            // than by name alone avoids an AmbiguousMatchException in case the concrete closure type exposes
            // other public Invoke overloads, and a failed lookup is reported like every other missing member.
            MethodInfo invokeMethod = EnsureMemberExists(
                fsharpFunc.GetType().GetMethod("Invoke", BindingFlags.Public | BindingFlags.Instance, binder: null, new[] { typeof(TArg) }, modifiers: null),
                "Microsoft.FSharp.Core.FSharpFunc<T, TResult>.Invoke(T)");

            // Create a closed delegate to avoid MethodInfo.Invoke overhead and object[] allocation per call.
            return (Func<TArg, TResult>)Delegate.CreateDelegate(typeof(Func<TArg, TResult>), fsharpFunc, invokeMethod);
        }

        // Returns memberInfo unchanged, or reports the named FSharp.Core member as missing when it is null.
        private static TMemberInfo EnsureMemberExists<TMemberInfo>(TMemberInfo? memberInfo, string memberName) where TMemberInfo : MemberInfo
        {
            if (memberInfo is null)
            {
                ThrowHelper.ThrowMissingMemberException_MissingFSharpCoreMember(memberName);
            }

            return memberInfo;
        }

        // Replicates the F# source construct flags enum
        // https://fsharp.github.io/fsharp-core-docs/reference/fsharp-core-sourceconstructflags.html
        private enum SourceConstructFlags
        {
            None = 0,
            SumType = 1,
            RecordType = 2,
            ObjectType = 3,
            Field = 4,
            Exception = 5,
            Closure = 6,
            Module = 7,
            UnionCase = 8,
            Value = 9,
            KindMask = 31,
            NonPublicRepresentation = 32
        }

        /// <summary>
        /// Represents metadata for a single F# discriminated union case.
        /// </summary>
        internal sealed class FSharpUnionCaseInfo
        {
            public FSharpUnionCaseInfo(
                string name,
                int tag,
                PropertyInfo[] fields,
                string? jsonPropertyName,
                Func<object, object[]> fieldReader,
                Func<object[], object> constructor)
            {
                Name = name;
                Tag = tag;
                Fields = fields;
                JsonPropertyName = jsonPropertyName;
                FieldReader = fieldReader;
                Constructor = constructor;
            }

            /// <summary>The case name as reported by UnionCaseInfo.Name.</summary>
            public string Name { get; }

            /// <summary>The case tag as reported by UnionCaseInfo.Tag.</summary>
            public int Tag { get; }

            /// <summary>The case fields as reported by UnionCaseInfo.GetFields().</summary>
            public PropertyInfo[] Fields { get; }

            /// <summary>The name from a JsonPropertyNameAttribute on the case, or null if none was found.</summary>
            public string? JsonPropertyName { get; }

            /// <summary>Reads the field values of a boxed case instance (from FSharpValue.PreComputeUnionReader).</summary>
            public Func<object, object[]> FieldReader { get; }

            /// <summary>Builds a boxed case instance from field values (from FSharpValue.PreComputeUnionConstructor).</summary>
            public Func<object[], object> Constructor { get; }

            /// <summary>True when the case declares no fields.</summary>
            public bool IsFieldless => Fields.Length == 0;
        }
    }
}