File: System\Text\Json\Serialization\Metadata\DefaultJsonTypeInfoResolver.Union.cs
Web Access
Project: src\src\runtime\src\libraries\System.Text.Json\src\System.Text.Json.csproj (System.Text.Json)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text.Json.Reflection;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Converters;

namespace System.Text.Json.Serialization.Metadata
{
    public partial class DefaultJsonTypeInfoResolver
    {
        /// <summary>
        /// Populates the union-specific metadata of a mutable <see cref="JsonTypeInfo"/> whose kind is
        /// <see cref="JsonTypeInfoKind.Union"/>.
        /// </summary>
        /// <remarks>
        /// The work is delegated to a <c>UnionMetadataBuilder&lt;TUnion&gt;</c> closed over the union type at
        /// run time, so that the builder can produce strongly-typed delegates for that type.
        /// </remarks>
        [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
        [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
        internal static void PopulateUnionMetadata(JsonTypeInfo typeInfo)
        {
            Debug.Assert(!typeInfo.IsReadOnly);
            Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Union);

            Type closedBuilderType = typeof(UnionMetadataBuilder<>).MakeGenericType(typeInfo.Type);
            object? builderInstance = Activator.CreateInstance(closedBuilderType, nonPublic: true);
            ((UnionMetadataBuilder)builderInstance!).Build(typeInfo);
        }

        /// <summary>
        /// Non-generic base of <c>UnionMetadataBuilder&lt;TUnion&gt;</c>. It lets a builder that is instantiated
        /// over a run-time union type be invoked without knowing <c>TUnion</c> statically.
        /// </summary>
        private abstract class UnionMetadataBuilder
        {
            /// <summary>Populates union metadata on <paramref name="typeInfo"/>.</summary>
            /// <param name="typeInfo">The union type info to populate.</param>
            [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
            [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
            public abstract void Build(JsonTypeInfo typeInfo);
        }

        private sealed class UnionMetadataBuilder<TUnion> : UnionMetadataBuilder
        {
            // Discovered cases in constructor declaration order; updated in lock-step with JsonTypeInfo.UnionCases.
            private readonly List<UnionCaseEntry> _caseEntries = new();

            // Case type -> entry. Used to de-duplicate constructor overloads of the same case type and to map
            // sorted case types back to their entries.
            private readonly Dictionary<Type, UnionCaseEntry> _entryByCaseType = new();

            /// <summary>
            /// Discovers the union cases of <typeparamref name="TUnion"/>. When at least one case exists, it also
            /// configures the type classifier and the constructor/deconstructor delegates.
            /// </summary>
            /// <param name="typeInfo">The type info to populate; must be a <c>JsonTypeInfo&lt;TUnion&gt;</c>.</param>
            [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
            [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
            public override void Build(JsonTypeInfo typeInfo)
            {
                var unionTypeInfo = (JsonTypeInfo<TUnion>)typeInfo;

                PopulateUnionCases(unionTypeInfo);

                // With no discoverable case constructor there is nothing further to wire up. The type info keeps an
                // empty case list and null delegates, which contract customization can still fill in.
                bool hasCases = _caseEntries.Count > 0;
                if (hasCases)
                {
                    // The classifier must be configured only after the cases have been populated.
                    PopulateUnionTypeClassifier(unionTypeInfo);
                    PopulateUnionDelegates(unionTypeInfo);
                }
            }

            /// <summary>
            /// Seeds the type-classifier configuration of the union type info.
            /// </summary>
            /// <remarks>
            /// When <typeparamref name="TUnion"/> is annotated with a <see cref="JsonUnionAttribute"/> whose
            /// <c>TypeClassifier</c> is set, that type is validated and instantiated. The instance is stored as the
            /// type info's <c>TypeClassifierFactory</c>. In all cases the classifier itself is not resolved here:
            /// <c>TypeClassifierResolutionPending</c> is set so resolution happens on first read of
            /// <c>TypeClassifier</c>.
            /// </remarks>
            /// <param name="typeInfo">The union type info being built.</param>
            /// <exception cref="InvalidOperationException">
            /// Raised via <c>ThrowHelper</c> when the attribute's <c>TypeClassifier</c> type does not derive from
            /// <see cref="JsonTypeClassifierFactory"/>.
            /// </exception>
            private static void PopulateUnionTypeClassifier(JsonTypeInfo<TUnion> typeInfo)
            {
                // This check runs before TypeClassifierResolutionPending is set below.
                Debug.Assert(typeInfo.TypeClassifier is null,
                    "PopulateUnionTypeClassifier is only invoked from the built-in resolver, before any contract customization. " +
                    "TypeClassifier must therefore not be set yet.");

                JsonUnionAttribute? attr = typeof(TUnion).GetCustomAttribute<JsonUnionAttribute>();
                if (attr?.TypeClassifier is { } attrClassifierType)
                {
                    // The attribute accepts an arbitrary Type, so validate it before instantiating it.
                    if (!typeof(JsonTypeClassifierFactory).IsAssignableFrom(attrClassifierType))
                    {
                        ThrowHelper.ThrowInvalidOperationException_TypeClassifierMustDeriveFromJsonTypeClassifierFactory(attrClassifierType, typeof(TUnion));
                    }

                    typeInfo.TypeClassifierFactory = (JsonTypeClassifierFactory)Activator.CreateInstance(attrClassifierType)!;
                }

                // Resolution is deferred to first read of
                // JsonTypeInfo.TypeClassifier — at that point the typeInfo is in the
                // per-options cache, so re-entrant lookups for the union type itself find the
                // partial typeInfo instead of recursing into a fresh resolution.
                typeInfo.TypeClassifierResolutionPending = true;
            }

            /// <summary>
            /// Discovers union cases from the public instance constructors of <typeparamref name="TUnion"/> and
            /// appends them to <see cref="JsonTypeInfo.UnionCases"/> in constructor enumeration order.
            /// </summary>
            /// <remarks>
            /// Only single-parameter constructors are considered. A constructor is skipped when
            /// <c>TryGetCaseType</c> rejects its parameter, or when the resulting case type carries
            /// <see cref="CompilerGeneratedAttribute"/>. A case type reached through several constructors is
            /// recorded once. Its <see cref="JsonUnionCaseInfo.IsNullable"/> flag is the OR across those
            /// constructors: if a later overload accepts null while the recorded one does not, the entry is
            /// replaced in place so that the null-accepting overload is used for the case.
            /// </remarks>
            /// <param name="typeInfo">The union type info whose case list is populated; expected to be empty.</param>
            [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
            [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
            private void PopulateUnionCases(JsonTypeInfo<TUnion> typeInfo)
            {
                Debug.Assert(typeInfo.UnionCases.Count == 0,
                    "PopulateUnionCases is only invoked from the built-in resolver, before any contract customization. " +
                    "UnionCases must therefore not be populated yet.");

                var nullabilityContext = new NullabilityInfoContext();
                IList<JsonUnionCaseInfo> declaredCases = typeInfo.UnionCases;

                // Position of each case type within _caseEntries / declaredCases, used for in-place replacement.
                var slotByCaseType = new Dictionary<Type, int>();

                ConstructorInfo[] candidates = typeof(TUnion).GetConstructors(BindingFlags.Public | BindingFlags.Instance);
                foreach (ConstructorInfo candidate in candidates)
                {
                    ParameterInfo[] ctorParameters = candidate.GetParameters();
                    if (ctorParameters.Length != 1)
                    {
                        continue;
                    }

                    if (!TryGetCaseType(ctorParameters[0], nullabilityContext, out Type? caseType, out bool acceptsNull))
                    {
                        continue;
                    }

                    if (caseType.GetCustomAttribute<CompilerGeneratedAttribute>() is not null)
                    {
                        continue;
                    }

                    if (!_entryByCaseType.TryGetValue(caseType, out UnionCaseEntry? existing))
                    {
                        // First constructor seen for this case type: append a new case.
                        UnionCaseEntry added = CreateUnionCaseEntry(caseType, candidate, acceptsNull);
                        _entryByCaseType.Add(caseType, added);
                        slotByCaseType.Add(caseType, _caseEntries.Count);
                        _caseEntries.Add(added);
                        declaredCases.Add(added.CaseInfo);
                        continue;
                    }

                    // A duplicate is only possible for a value-type case that has both `Foo(T)` and `Foo(T?)` overloads.
                    Debug.Assert(caseType.IsValueType);

                    bool upgradeToNullable = acceptsNull && !existing.CaseInfo.IsNullable;
                    if (upgradeToNullable)
                    {
                        int slot = slotByCaseType[caseType];
                        UnionCaseEntry replacement = CreateUnionCaseEntry(caseType, candidate, isNullable: true);
                        _entryByCaseType[caseType] = replacement;
                        _caseEntries[slot] = replacement;
                        declaredCases[slot] = replacement.CaseInfo;
                    }
                }
            }

            /// <summary>
            /// Builds the convention-based <see cref="JsonTypeInfo.UnionDeconstructor"/> and
            /// <see cref="JsonTypeInfo.UnionConstructor"/> delegates.
            /// </summary>
            /// <remarks>
            /// The convention reads the union value through a public instance <c>object Value</c> property. When no
            /// such property is found, this method returns without setting either delegate. The user must then
            /// supply both through contract customization.
            /// </remarks>
            /// <param name="typeInfo">The union type info; at least one case must already have been discovered.</param>
            [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
            [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
            private void PopulateUnionDelegates(JsonTypeInfo<TUnion> typeInfo)
            {
                Debug.Assert(_caseEntries.Count > 0);

                if (GetUnionValueProperty() is not { } valueProperty)
                {
                    return;
                }

                // Most-derived-first ordering lets the nearest-ancestor walks performed by both delegates
                // encounter the most specific declared case before any of its bases.
                UnionCaseEntry[] orderedCases = BuildTopologicallySortedCaseEntries();
                ConcurrentDictionary<Type, UnionCaseEntry?> caseIndex = CreateCaseIndex(orderedCases);

                // The first null-accepting case in sorted order represents a null union value when deconstructing.
                UnionCaseEntry? nullableCase = Array.Find(orderedCases, static candidate => candidate.CaseInfo.IsNullable);

                PopulateUnionDeconstructor(typeInfo, orderedCases, caseIndex, valueProperty, nullableCase);
                PopulateUnionConstructor(typeInfo, orderedCases, caseIndex);
            }

            /// <summary>
            /// Returns the discovered case entries ordered most-derived first. Unrelated cases are ordered according
            /// to <c>SortTypesByInheritanceHierarchy</c>, which is fed the declaration-ordered case types.
            /// </summary>
            private UnionCaseEntry[] BuildTopologicallySortedCaseEntries()
            {
                var declaredTypes = new Type[_caseEntries.Count];
                int next = 0;
                foreach (UnionCaseEntry caseEntry in _caseEntries)
                {
                    declaredTypes[next++] = caseEntry.CaseType;
                }

                Type[] sortedTypes = SortTypesByInheritanceHierarchy(declaredTypes, mostDerivedTypesFirst: true);

                // Map each sorted case type back to its entry.
                return Array.ConvertAll(sortedTypes, sortedType => _entryByCaseType[sortedType]);
            }

            /// <summary>
            /// Creates the case-resolution cache, seeded so that each declared case type maps to its own entry.
            /// <c>ResolveUnionCase</c> later adds run-time types (and misses, as <see langword="null"/>) on demand.
            /// </summary>
            private static ConcurrentDictionary<Type, UnionCaseEntry?> CreateCaseIndex(UnionCaseEntry[] orderedCases)
            {
                var index = new ConcurrentDictionary<Type, UnionCaseEntry?>();
                for (int i = 0; i < orderedCases.Length; i++)
                {
                    UnionCaseEntry current = orderedCases[i];
                    index[current.CaseType] = current;
                }

                return index;
            }

            /// <summary>
            /// Builds the convention-based <see cref="JsonTypeInfo.UnionDeconstructor"/>. It maps a union instance
            /// to its (case type, case value) pair.
            /// </summary>
            /// <param name="typeInfo">The union type info; its deconstructor must not be set yet.</param>
            /// <param name="orderedCases">Declared cases, most-derived first.</param>
            /// <param name="caseIndex">Case-resolution cache shared with the constructor delegate.</param>
            /// <param name="valueProperty">The public <c>object Value</c> property used to read the union value.</param>
            /// <param name="nullableCase">The case reported for a null value, or <see langword="null"/> when no case accepts null.</param>
            [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
            [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
            private static void PopulateUnionDeconstructor(
                JsonTypeInfo<TUnion> typeInfo,
                UnionCaseEntry[] orderedCases,
                ConcurrentDictionary<Type, UnionCaseEntry?> caseIndex,
                PropertyInfo valueProperty,
                UnionCaseEntry? nullableCase)
            {
                Debug.Assert(typeInfo.UnionDeconstructor is null);

                Func<TUnion, object?> getValue = MemberAccessor.CreatePropertyGetter<TUnion, object?>(valueProperty);
                UnionTryGetValueAccessor<TUnion>? tryGetValueChain = CreateTryGetValueChain();

                typeInfo.UnionDeconstructor = (TUnion union) =>
                {
                    // A null reference-type union is reported as the canonical (no case, no value) state.
                    if (!typeof(TUnion).IsValueType && (object?)union is null)
                    {
                        return (null, null);
                    }

                    // Preferred path: the union's own 'bool TryGetValue(out CaseType)' overloads, tried
                    // most-derived first with the first success winning, mirroring the C# compiler's
                    // pattern-matching lowering. A false result falls through to run-time type dispatch so
                    // that unions with only partial TryGetValue coverage still resolve their remaining cases.
                    if (tryGetValueChain is not null && tryGetValueChain(union, out Type? matchedCaseType, out object? matchedValue))
                    {
                        return (matchedCaseType, matchedValue);
                    }

                    object? value = getValue(union);
                    if (value is not null)
                    {
                        Type runtimeType = value.GetType();
                        UnionCaseEntry? match = ResolveUnionCase(caseIndex, orderedCases, runtimeType);
                        if (match is null)
                        {
                            ThrowHelper.ThrowJsonException_UnionRuntimeTypeNotMatchedToCase(typeof(TUnion), runtimeType);
                        }

                        return (match.CaseType, value);
                    }

                    if (nullableCase is null)
                    {
                        ThrowHelper.ThrowJsonException_UnionDoesNotAcceptNull(typeof(TUnion));
                    }

                    return (nullableCase.CaseType, null);
                };

                // Collects public instance 'bool TryGetValue(out CaseType)' overloads whose CaseType is a declared
                // case. Only the first overload per case type is kept. The collected overloads are ordered
                // most-derived first and folded into a single chained accessor. Returns null when no overload
                // qualifies. The literal "TryGetValue" name in the reflection query narrows the candidate set and
                // gives the IL trimmer a static signal about which methods are reflected over, rather than every
                // public instance method of TUnion.
                [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
                [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
                UnionTryGetValueAccessor<TUnion>? CreateTryGetValueChain()
                {
                    Dictionary<Type, MethodInfo>? overloadByCaseType = null;
                    foreach (MemberInfo candidate in typeof(TUnion).GetMember("TryGetValue", MemberTypes.Method, BindingFlags.Public | BindingFlags.Instance))
                    {
                        var method = (MethodInfo)candidate;
                        if (method.ReturnType != typeof(bool) || method.IsGenericMethodDefinition)
                        {
                            continue;
                        }

                        ParameterInfo[] methodParameters = method.GetParameters();
                        if (methodParameters.Length != 1)
                        {
                            continue;
                        }

                        ParameterInfo outParameter = methodParameters[0];
                        if (!outParameter.IsOut || !outParameter.ParameterType.IsByRef)
                        {
                            continue;
                        }

                        Type caseType = outParameter.ParameterType.GetElementType()!;
                        if (caseIndex.ContainsKey(caseType))
                        {
                            // When several overloads target the same case type, the first one is kept.
                            overloadByCaseType ??= new Dictionary<Type, MethodInfo>();
                            overloadByCaseType.TryAdd(caseType, method);
                        }
                    }

                    if (overloadByCaseType is null)
                    {
                        return null;
                    }

                    var unsortedCaseTypes = new Type[overloadByCaseType.Count];
                    overloadByCaseType.Keys.CopyTo(unsortedCaseTypes, 0);
                    Type[] sortedCaseTypes = SortTypesByInheritanceHierarchy(unsortedCaseTypes, mostDerivedTypesFirst: true);

                    var chain = new KeyValuePair<Type, MethodInfo>[sortedCaseTypes.Length];
                    for (int i = 0; i < chain.Length; i++)
                    {
                        chain[i] = new KeyValuePair<Type, MethodInfo>(sortedCaseTypes[i], overloadByCaseType[sortedCaseTypes[i]]);
                    }

                    return MemberAccessor.Instance.CreateUnionTryGetValueAccessor<TUnion>(chain);
                }
            }

            /// <summary>
            /// Builds the convention-based <see cref="JsonTypeInfo.UnionConstructor"/>. It creates a
            /// <typeparamref name="TUnion"/> from a (case type, case value) pair by invoking a case constructor.
            /// </summary>
            /// <remarks>
            /// A non-null value is routed to the nearest declared case of <c>caseType</c>, as found by
            /// <c>ResolveUnionCase</c>. A null value always uses the first null-accepting case in declaration order,
            /// regardless of <c>caseType</c>. A JsonException is thrown when no case matches or no case accepts null.
            /// </remarks>
            /// <param name="typeInfo">The union type info; its constructor must not be set yet.</param>
            /// <param name="orderedCases">Declared cases, most-derived first.</param>
            /// <param name="caseIndex">Case-resolution cache shared with the deconstructor delegate.</param>
            private void PopulateUnionConstructor(
                JsonTypeInfo<TUnion> typeInfo,
                UnionCaseEntry[] orderedCases,
                ConcurrentDictionary<Type, UnionCaseEntry?> caseIndex)
            {
                Debug.Assert(typeInfo.UnionConstructor is null,
                    "PopulateUnionConstructor is only invoked from the built-in resolver, before any contract customization.");

                // NOTE(review): null is routed to the first nullable case in *declaration* order here, whereas
                // PopulateUnionDelegates picks the deconstructor's null case from the most-derived-first order.
                // The two can disagree when several related cases accept null. Confirm this is intended.
                UnionCaseEntry? nullCase = _caseEntries.Find(static candidate => candidate.CaseInfo.IsNullable);
                Func<object?, TUnion>? nullConstructor = nullCase?.Constructor;

                typeInfo.UnionConstructor = (Type caseType, object? value) =>
                {
                    if (value is not null)
                    {
                        UnionCaseEntry? target = ResolveUnionCase(caseIndex, orderedCases, caseType);
                        if (target is null)
                        {
                            ThrowHelper.ThrowJsonException_UnionRuntimeTypeNotMatchedToCase(typeof(TUnion), caseType);
                        }

                        return target.Constructor(value);
                    }

                    if (nullConstructor is null)
                    {
                        ThrowHelper.ThrowJsonException_UnionDoesNotAcceptNull(typeof(TUnion));
                    }

                    return nullConstructor(null);
                };
            }

            /// <summary>
            /// Maps <paramref name="runtimeType"/> to a declared case. Cached results are returned directly.
            /// Otherwise the first case in <paramref name="orderedCases"/> that is assignable from the type is chosen,
            /// and the result (including a miss, stored as <see langword="null"/>) is memoized in
            /// <paramref name="caseIndex"/>.
            /// </summary>
            /// <param name="caseIndex">Resolution cache; written to on a miss.</param>
            /// <param name="orderedCases">Declared cases, expected most-derived first so the first match is the nearest ancestor.</param>
            /// <param name="runtimeType">The type to resolve.</param>
            /// <returns>The matching case entry, or <see langword="null"/> when no declared case is assignable from <paramref name="runtimeType"/>.</returns>
            private static UnionCaseEntry? ResolveUnionCase(
                ConcurrentDictionary<Type, UnionCaseEntry?> caseIndex,
                UnionCaseEntry[] orderedCases,
                Type runtimeType)
            {
                if (!caseIndex.TryGetValue(runtimeType, out UnionCaseEntry? resolved))
                {
                    resolved = null;
                    for (int i = 0; i < orderedCases.Length && resolved is null; i++)
                    {
                        if (orderedCases[i].CaseType.IsAssignableFrom(runtimeType))
                        {
                            resolved = orderedCases[i];
                        }
                    }

                    caseIndex[runtimeType] = resolved;
                }

                return resolved;
            }

            /// <summary>
            /// Finds the conventional public instance <c>Value</c> property of <typeparamref name="TUnion"/>.
            /// </summary>
            /// <returns>
            /// The property when it is typed <see cref="object"/>, has a public getter and is not indexed;
            /// otherwise <see langword="null"/>.
            /// </returns>
            [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
            [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
            private static PropertyInfo? GetUnionValueProperty()
            {
                PropertyInfo? candidate = typeof(TUnion).GetProperty("Value", BindingFlags.Public | BindingFlags.Instance);
                if (candidate is null || candidate.PropertyType != typeof(object))
                {
                    return null;
                }

                // Matching BindingFlags.Public does not by itself guarantee that the getter is public, so check it
                // explicitly. Indexed properties are excluded.
                bool hasPublicGetter = candidate.GetMethod is { IsPublic: true };
                return hasPublicGetter && candidate.GetIndexParameters().Length == 0 ? candidate : null;
            }

            /// <summary>
            /// Creates the entry for a case, compiling <paramref name="constructorInfo"/> into a delegate that
            /// builds <typeparamref name="TUnion"/> from an <see cref="object"/> case value.
            /// </summary>
            [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)]
            [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)]
            private static UnionCaseEntry CreateUnionCaseEntry(Type caseType, ConstructorInfo constructorInfo, bool isNullable)
                => new UnionCaseEntry(
                    caseType,
                    MemberAccessor.CreateSingleParameterConstructor<TUnion>(constructorInfo),
                    isNullable);

            /// <summary>
            /// Internal record of one union case. It holds the case type, the compiled case constructor, and the
            /// <see cref="JsonUnionCaseInfo"/> that is exposed through <see cref="JsonTypeInfo.UnionCases"/>.
            /// </summary>
            private sealed class UnionCaseEntry
            {
                public UnionCaseEntry(Type caseType, Func<object?, TUnion> constructor, bool isNullable)
                {
                    CaseInfo = new JsonUnionCaseInfo(caseType) { IsNullable = isNullable };
                    Constructor = constructor;
                    CaseType = caseType;
                }

                /// <summary>Public metadata for the case.</summary>
                public JsonUnionCaseInfo CaseInfo { get; }

                /// <summary>Delegate that builds the union from a case value.</summary>
                public Func<object?, TUnion> Constructor { get; }

                /// <summary>The declared case type (with <see cref="Nullable{T}"/> already unwrapped).</summary>
                public Type CaseType { get; }
            }
        }

        /// <summary>
        /// Orders <paramref name="types"/> by their assignability relationships using a topological sort.
        /// </summary>
        /// <param name="types">Distinct types to order; the array is not modified.</param>
        /// <param name="mostDerivedTypesFirst">
        /// When <see langword="true"/>, a type is placed before the input types it is assignable to. Otherwise it is
        /// placed before the input types that are assignable to it.
        /// </param>
        /// <returns>
        /// The ordered types. When <paramref name="types"/> has fewer than two elements, the input array itself is
        /// returned.
        /// </returns>
        private static Type[] SortTypesByInheritanceHierarchy(Type[] types, bool mostDerivedTypesFirst)
        {
            if (types.Length < 2)
            {
                return types;
            }

            // typeof(void) is a synthetic root: it can never be a case type and is not assignability-related to any
            // valid case type.
            Type syntheticRoot = typeof(void);
            Debug.Assert(Array.IndexOf(types, syntheticRoot) < 0);

            Type[] sortedWithRoot = JsonHelpers.TraverseGraphWithTopologicalSort(syntheticRoot, GetEdges);
            Debug.Assert(sortedWithRoot.Length == types.Length + 1);
            Debug.Assert(sortedWithRoot[0] == syntheticRoot);

            // Strip the synthetic root from the front of the result.
            var result = new Type[types.Length];
            Array.Copy(sortedWithRoot, 1, result, 0, result.Length);
            return result;

            ICollection<Type> GetEdges(Type node)
            {
                if (node == syntheticRoot)
                {
                    // TraverseGraphWithTopologicalSort writes childless nodes from the end of its result. The root's
                    // children are therefore handed over in reverse so unrelated types keep their input order.
                    var reversed = (Type[])types.Clone();
                    Array.Reverse(reversed);
                    return reversed;
                }

                List<Type>? edges = null;
                foreach (Type other in types)
                {
                    if (other == node)
                    {
                        continue;
                    }

                    bool related = mostDerivedTypesFirst
                        ? other.IsAssignableFrom(node)
                        : node.IsAssignableFrom(other);

                    if (related)
                    {
                        (edges ??= new List<Type>()).Add(other);
                    }
                }

                return (ICollection<Type>?)edges ?? Array.Empty<Type>();
            }
        }

        /// <summary>
        /// Determines whether a constructor parameter can denote a union case and, if so, which case type it denotes.
        /// </summary>
        /// <param name="parameter">The single parameter of a candidate case constructor.</param>
        /// <param name="nullabilityCtx">Context used to read the parameter's nullability annotations.</param>
        /// <param name="caseType">
        /// On success, the parameter type with <see cref="Nullable{T}"/> unwrapped to its underlying type;
        /// otherwise <see langword="null"/>.
        /// </param>
        /// <param name="acceptsNull">
        /// <see langword="true"/> when <c>IsNullableType()</c> reports the parameter type as nullable and the
        /// parameter's write nullability state is not <see cref="NullabilityState.NotNull"/>; otherwise
        /// <see langword="false"/>.
        /// </param>
        /// <returns>
        /// <see langword="false"/> for by-ref and pointer parameters; otherwise <see langword="true"/>.
        /// </returns>
        private static bool TryGetCaseType(
            ParameterInfo parameter,
            NullabilityInfoContext nullabilityCtx,
            [NotNullWhen(true)] out Type? caseType,
            out bool acceptsNull)
        {
            acceptsNull = false;
            Type parameterType = parameter.ParameterType;

            // Case values flow through object-typed delegates (Func<object?, TUnion>). By-ref and pointer types
            // cannot be carried that way and are not serializable case types, so such constructors are ignored
            // rather than surfacing as unusable cases.
            if (parameterType.IsByRef || parameterType.IsPointer)
            {
                caseType = null;
                return false;
            }

            caseType = Nullable.GetUnderlyingType(parameterType) ?? parameterType;

            if (parameterType.IsNullableType())
            {
                NullabilityInfo nullability = nullabilityCtx.Create(parameter);
                acceptsNull = nullability.WriteState is not NullabilityState.NotNull;
            }

            return true;
        }
    }
}