File: Compiler\Dataflow\ReflectionMarker.cs
Web Access
Project: src\src\runtime\src\coreclr\tools\aot\ILCompiler.Compiler\ILCompiler.Compiler.csproj (ILCompiler.Compiler)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Reflection.Metadata;
using ILCompiler.DependencyAnalysis;
using ILCompiler.Logging;
using ILLink.Shared;
using ILLink.Shared.TrimAnalysis;
using Internal.TypeSystem;

using DependencyList = ILCompiler.DependencyAnalysisFramework.DependencyNodeCore<ILCompiler.DependencyAnalysis.NodeFactory>.DependencyList;

#nullable enable
#pragma warning disable IDE0060

namespace ILCompiler.Dataflow
{
    /// <summary>
    /// Records dependencies for members that are (or may be) accessed via reflection. It also reports
    /// analysis warnings when such access reaches members that are annotated with
    /// RequiresUnreferencedCode / RequiresAssemblyFiles / RequiresDynamicCode, or that are covered
    /// by DynamicallyAccessedMembers annotations.
    /// </summary>
    /// <remarks>
    /// Every marking/checking method returns immediately when the marker was constructed with
    /// <c>enabled: false</c>. Discovered dependencies accumulate in <see cref="Dependencies"/>.
    /// </remarks>
    public class ReflectionMarker
    {
        // Not readonly: it is passed by ref to the RootingHelpers.TryGetDependenciesForReflected* helpers.
        private DependencyList _dependencies = new DependencyList();
        private readonly Logger _logger;

        // When non-null, warnings are produced by the type-hierarchy path
        // (ReportWarningsForTypeHierarchyReflectionAccess) instead of the call-site path.
        private readonly MetadataType? _typeHierarchyDataFlowOrigin;

        // When false, all Mark*/Check* methods are no-ops.
        private readonly bool _enabled;

        public NodeFactory Factory { get; }
        public FlowAnnotations Annotations { get; }

        /// <summary>Dependencies accumulated by the Mark* methods of this instance.</summary>
        public DependencyList Dependencies => _dependencies;

        /// <summary>
        /// Runtime-determined dependencies paired with the method that owns them. This class never adds
        /// to the list itself; it is exposed for callers to populate and consume.
        /// </summary>
        public List<(MethodDesc OwningMethod, INodeWithRuntimeDeterminedDependencies Dependency)> RuntimeDeterminedDependencies { get; } = new List<(MethodDesc, INodeWithRuntimeDeterminedDependencies)>();

        /// <summary>How a member is being reached; this influences which warnings are produced.</summary>
        internal enum AccessKind
        {
            /// <summary>No specific access kind.</summary>
            Unspecified,

            /// <summary>Member is marked because a DynamicallyAccessedMembers requirement covers it.</summary>
            DynamicallyAccessedMembersMark,

            /// <summary>
            /// Member is referenced through a metadata token. Requires-warnings are produced for it even
            /// when the member is compiler-generated.
            /// </summary>
            TokenAccess
        }

        /// <summary>Creates a marker.</summary>
        /// <param name="logger">Logger used to report warnings and to query Requires* suppression scopes.</param>
        /// <param name="factory">Node factory used to create dependency nodes.</param>
        /// <param name="annotations">Flow annotations used to detect DynamicallyAccessedMembers-annotated members.</param>
        /// <param name="typeHierarchyDataFlowOrigin">
        /// Type whose DynamicallyAccessedMembers annotation triggered the marking, or <c>null</c> for call-site marking.
        /// </param>
        /// <param name="enabled">When <c>false</c>, all marking is skipped.</param>
        public ReflectionMarker(Logger logger, NodeFactory factory, FlowAnnotations annotations, MetadataType? typeHierarchyDataFlowOrigin, bool enabled)
        {
            _logger = logger;
            Factory = factory;
            Annotations = annotations;
            _typeHierarchyDataFlowOrigin = typeHierarchyDataFlowOrigin;
            _enabled = enabled;
        }

        /// <summary>
        /// Marks every member of <paramref name="typeDefinition"/> selected by <paramref name="requiredMemberTypes"/>.
        /// Each member is marked with <see cref="AccessKind.DynamicallyAccessedMembersMark"/>.
        /// </summary>
        internal void MarkTypeForDynamicallyAccessedMembers(in MessageOrigin origin, TypeDesc typeDefinition, DynamicallyAccessedMemberTypes requiredMemberTypes, TypeSystemEntity reason, bool declaredOnly = false)
        {
            if (!_enabled)
                return;

            string displayName = reason.GetDisplayName();
            foreach (var member in typeDefinition.GetDynamicallyAccessedMembers(requiredMemberTypes, declaredOnly))
            {
                MarkTypeSystemEntity(origin, member, displayName, AccessKind.DynamicallyAccessedMembersMark);
            }
        }

        /// <summary>
        /// Dispatches to the Mark* overload matching the runtime kind of <paramref name="entity"/>.
        /// Entities of any other kind are ignored.
        /// </summary>
        internal void MarkTypeSystemEntity(in MessageOrigin origin, TypeSystemEntity entity, string reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            switch (entity)
            {
                case MethodDesc method:
                    MarkMethod(origin, method, reason, accessKind);
                    break;
                case FieldDesc field:
                    MarkField(origin, field, reason, accessKind);
                    break;
                case MetadataType nestedType:
                    MarkType(origin, nestedType, reason, accessKind);
                    break;
                case PropertyPseudoDesc property:
                    MarkProperty(origin, property, reason, accessKind);
                    break;
                case EventPseudoDesc @event:
                    MarkEvent(origin, @event, reason, accessKind);
                    break;
                    // case InterfaceImplementation
                    //  This is handled in the MetadataType case above
            }
        }

        /// <summary>
        /// Resolves <paramref name="typeName"/> relative to the module of the origin member's owning type
        /// (falling back to CoreLib) and, if enabled, marks the resolved type.
        /// </summary>
        /// <returns>
        /// <c>true</c> with <paramref name="type"/> set when the name resolved. The type is returned even
        /// when marking is disabled. <c>false</c> otherwise; if resolution failed only because the name
        /// was not assembly-qualified, a TypeNameIsNotAssemblyQualified diagnostic is added.
        /// </returns>
        internal bool TryResolveTypeNameAndMark(string typeName, in DiagnosticContext diagnosticContext, bool needsAssemblyName, TypeSystemEntity reason, [NotNullWhen(true)] out TypeDesc? type)
        {
            ModuleDesc? callingModule = (diagnosticContext.Origin.MemberDefinition.GetOwningType() as MetadataType)?.Module;

            List<ModuleDesc> referencedModules = new();
            TypeDesc? foundType = CustomAttributeTypeNameParser.GetTypeByCustomAttributeTypeNameForDataFlow(typeName, callingModule, diagnosticContext.Origin.MemberDefinition!.Context,
                referencedModules, needsAssemblyName, fallbackToCoreLib: true, out bool failedBecauseNotFullyQualified);
            if (foundType == null)
            {
                if (failedBecauseNotFullyQualified)
                {
                    diagnosticContext.AddDiagnostic(DiagnosticId.TypeNameIsNotAssemblyQualified, typeName);
                }
                type = default;
                return false;
            }

            if (_enabled)
            {
                string displayName = reason.GetDisplayName();
                AddReferencedModuleMetadata(referencedModules, displayName);
                MarkType(diagnosticContext.Origin, foundType, displayName);
            }

            type = foundType;
            return true;
        }

        /// <summary>
        /// Resolves <paramref name="typeName"/> relative to <paramref name="assembly"/> and, if enabled, marks the resolved type.
        /// </summary>
        /// <returns>
        /// <c>true</c> with <paramref name="type"/> set when the name resolved; the type is returned even
        /// when marking is disabled. <c>false</c> otherwise.
        /// </returns>
        internal bool TryResolveTypeNameAndMark(ModuleDesc assembly, string typeName, in DiagnosticContext diagnosticContext, string reason, bool fallbackToCoreLib, [NotNullWhen(true)] out TypeDesc? type)
        {
            List<ModuleDesc> referencedModules = new();
            TypeDesc? foundType = CustomAttributeTypeNameParser.GetTypeByCustomAttributeTypeNameForDataFlow(typeName, assembly, assembly.Context,
                referencedModules, needsAssemblyName: false, fallbackToCoreLib, out _);
            if (foundType == null)
            {
                type = default;
                return false;
            }

            if (_enabled)
            {
                AddReferencedModuleMetadata(referencedModules, reason);
                MarkType(diagnosticContext.Origin, foundType, reason);
            }

            type = foundType;
            return true;
        }

        /// <summary>
        /// Adds a module-metadata dependency for each module that type-name resolution touched and for
        /// which metadata can be generated. This covers references that went through a type forward.
        /// </summary>
        private void AddReferencedModuleMetadata(List<ModuleDesc> referencedModules, string reason)
        {
            // TODO-ILTRIM: add handling of type forwards
#if !ILTRIM
            foreach (ModuleDesc referencedModule in referencedModules)
            {
                if (Factory.MetadataManager.CanGenerateMetadata(referencedModule.GetGlobalModuleType()))
                    _dependencies.Add(Factory.ModuleMetadata(referencedModule), reason);
            }
#endif
        }

        /// <summary>Marks <paramref name="type"/> as reflected, using the display name of <paramref name="reason"/>.</summary>
        internal void MarkType(in MessageOrigin origin, TypeDesc type, TypeSystemEntity reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            MarkType(origin, type, reason.GetDisplayName(), accessKind);
        }

        /// <summary>
        /// Adds the dependencies of a reflected type. No warnings are checked for types, so
        /// <paramref name="origin"/> and <paramref name="accessKind"/> are currently unused.
        /// </summary>
        internal void MarkType(in MessageOrigin origin, TypeDesc type, string reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            RootingHelpers.TryGetDependenciesForReflectedType(ref _dependencies, Factory, type, reason);
        }

        /// <summary>Marks <paramref name="method"/> as reflected, using the display name of <paramref name="reason"/>.</summary>
        internal void MarkMethod(in MessageOrigin origin, MethodDesc method, TypeSystemEntity reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            MarkMethod(origin, method, reason.GetDisplayName(), accessKind);
        }

        /// <summary>Reports any reflection-access warnings for <paramref name="method"/>, then adds its dependencies.</summary>
        internal void MarkMethod(in MessageOrigin origin, MethodDesc method, string reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            CheckAndWarnOnReflectionAccess(origin, method, accessKind);

            RootingHelpers.TryGetDependenciesForReflectedMethod(ref _dependencies, Factory, method, reason);
        }

        /// <summary>Reports any reflection-access warnings for <paramref name="field"/>, then adds its dependencies.</summary>
        internal void MarkField(in MessageOrigin origin, FieldDesc field, string reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            CheckAndWarnOnReflectionAccess(origin, field, accessKind);

            RootingHelpers.TryGetDependenciesForReflectedField(ref _dependencies, Factory, field, reason);
        }

        /// <summary>Marks the accessors of <paramref name="property"/>, using the display name of <paramref name="reason"/>.</summary>
        internal void MarkProperty(in MessageOrigin origin, PropertyPseudoDesc property, TypeSystemEntity reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            MarkProperty(origin, property, reason.GetDisplayName(), accessKind);
        }

        /// <summary>
        /// Marks the getter and setter (whichever exist) of <paramref name="property"/>.
        /// <paramref name="accessKind"/> is forwarded to each accessor.
        /// </summary>
        internal void MarkProperty(in MessageOrigin origin, PropertyPseudoDesc property, string reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            if (property.GetMethod != null)
                MarkMethod(origin, property.GetMethod, reason, accessKind);
            if (property.SetMethod != null)
                MarkMethod(origin, property.SetMethod, reason, accessKind);
        }

        /// <summary>
        /// Marks the add and remove accessors (whichever exist) of <paramref name="event"/>.
        /// <paramref name="accessKind"/> is forwarded to each accessor.
        /// </summary>
        private void MarkEvent(in MessageOrigin origin, EventPseudoDesc @event, string reason, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            if (@event.AddMethod != null)
                MarkMethod(origin, @event.AddMethod, reason, accessKind);
            if (@event.RemoveMethod != null)
                MarkMethod(origin, @event.RemoveMethod, reason, accessKind);
        }

        /// <summary>Marks the constructors of <paramref name="type"/> that pass <paramref name="filter"/> and <paramref name="bindingFlags"/>.</summary>
        internal void MarkConstructorsOnType(in MessageOrigin origin, TypeDesc type, Func<MethodDesc, bool>? filter, TypeSystemEntity reason, BindingFlags? bindingFlags = null)
        {
            if (!_enabled)
                return;

            string displayName = reason.GetDisplayName();
            foreach (var ctor in type.GetConstructorsOnType(filter, bindingFlags))
                MarkMethod(origin, ctor, displayName);
        }

        /// <summary>Marks the fields on the type hierarchy of <paramref name="type"/> that pass <paramref name="filter"/>.</summary>
        internal void MarkFieldsOnTypeHierarchy(in MessageOrigin origin, TypeDesc type, Func<FieldDesc, bool> filter, TypeSystemEntity reason, BindingFlags? bindingFlags = BindingFlags.Default)
        {
            if (!_enabled)
                return;

            string displayName = reason.GetDisplayName();
            foreach (var field in type.GetFieldsOnTypeHierarchy(filter, bindingFlags))
                MarkField(origin, field, displayName);
        }

        /// <summary>Marks the properties on the type hierarchy of <paramref name="type"/> that pass <paramref name="filter"/>.</summary>
        internal void MarkPropertiesOnTypeHierarchy(in MessageOrigin origin, TypeDesc type, Func<PropertyPseudoDesc, bool> filter, TypeSystemEntity reason, BindingFlags? bindingFlags = BindingFlags.Default)
        {
            if (!_enabled)
                return;

            string displayName = reason.GetDisplayName();
            foreach (var property in type.GetPropertiesOnTypeHierarchy(filter, bindingFlags))
                MarkProperty(origin, property, displayName);
        }

        /// <summary>Marks the events on the type hierarchy of <paramref name="type"/> that pass <paramref name="filter"/>.</summary>
        internal void MarkEventsOnTypeHierarchy(in MessageOrigin origin, TypeDesc type, Func<EventPseudoDesc, bool> filter, TypeSystemEntity reason, BindingFlags? bindingFlags = BindingFlags.Default)
        {
            if (!_enabled)
                return;

            string displayName = reason.GetDisplayName();
            foreach (var @event in type.GetEventsOnTypeHierarchy(filter, bindingFlags))
                MarkEvent(origin, @event, displayName);
        }

        /// <summary>Marks the static constructor of <paramref name="type"/>, if it has one.</summary>
        internal void MarkStaticConstructor(in MessageOrigin origin, TypeDesc type, TypeSystemEntity reason)
        {
            if (!_enabled)
                return;

            MethodDesc cctor = type.GetStaticConstructor();
            if (cctor != null)
                MarkMethod(origin, cctor, reason.GetDisplayName());
        }

        /// <summary>
        /// Reports warnings for reflection access to <paramref name="entity"/>, which is expected to be a
        /// method or a field. Uses the type-hierarchy reporting path when this marker has a
        /// type-hierarchy origin, and the call-site path otherwise.
        /// </summary>
        internal void CheckAndWarnOnReflectionAccess(in MessageOrigin origin, TypeSystemEntity entity, AccessKind accessKind = AccessKind.Unspecified)
        {
            if (!_enabled)
                return;

            if (_typeHierarchyDataFlowOrigin is not null)
            {
                ReportWarningsForTypeHierarchyReflectionAccess(origin, entity);
            }
            else
            {
                ReportWarningsForReflectionAccess(origin, entity, accessKind);
            }
        }

        /// <summary>
        /// Call-site reporting: warns when <paramref name="entity"/> is in a RequiresUnreferencedCode,
        /// RequiresAssemblyFiles or RequiresDynamicCode scope. Also warns when it is covered by a
        /// DynamicallyAccessedMembers annotation, unless the origin itself is in a
        /// RequiresUnreferencedCode scope.
        /// </summary>
        private void ReportWarningsForReflectionAccess(in MessageOrigin origin, TypeSystemEntity entity, AccessKind accessKind)
        {
            Debug.Assert(entity is MethodDesc or FieldDesc);

            // Note that we're using `ShouldSuppressAnalysisWarningsForRequires` instead of `DoesMemberRequire`.
            // This is because reflection access is actually problematic on all members which are in a "requires" scope
            // so for example even instance methods. See for example https://github.com/dotnet/linker/issues/3140 - it's possible
            // to call a method on a "null" instance via reflection.
            if (_logger.ShouldSuppressAnalysisWarningsForRequires(entity, DiagnosticUtilities.RequiresUnreferencedCodeAttribute, out CustomAttributeValue<TypeDesc>? requiresAttribute) &&
                ShouldProduceRequiresWarningForReflectionAccess(entity, accessKind))
                ReportRequires(origin, entity, DiagnosticUtilities.RequiresUnreferencedCodeAttribute, requiresAttribute.Value);

            if (_logger.ShouldSuppressAnalysisWarningsForRequires(entity, DiagnosticUtilities.RequiresAssemblyFilesAttribute, out requiresAttribute) &&
                ShouldProduceRequiresWarningForReflectionAccess(entity, accessKind))
                ReportRequires(origin, entity, DiagnosticUtilities.RequiresAssemblyFilesAttribute, requiresAttribute.Value);

            if (_logger.ShouldSuppressAnalysisWarningsForRequires(entity, DiagnosticUtilities.RequiresDynamicCodeAttribute, out requiresAttribute) &&
                ShouldProduceRequiresWarningForReflectionAccess(entity, accessKind))
                ReportRequires(origin, entity, DiagnosticUtilities.RequiresDynamicCodeAttribute, requiresAttribute.Value);

            // Below is about accessing DAM annotated members, so only RUC is applicable as a suppression scope
            if (_logger.ShouldSuppressAnalysisWarningsForRequires(origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute))
                return;

            bool isReflectionAccessCoveredByDAM = Annotations.ShouldWarnWhenAccessedForReflection(entity);
            if (isReflectionAccessCoveredByDAM && ShouldProduceRequiresWarningForReflectionAccess(entity, accessKind))
            {
                if (entity is MethodDesc)
                    _logger.LogWarning(origin, DiagnosticId.DynamicallyAccessedMembersMethodAccessedViaReflection, entity.GetDisplayName());
                else
                    _logger.LogWarning(origin, DiagnosticId.DynamicallyAccessedMembersFieldAccessedViaReflection, entity.GetDisplayName());
            }

            // We decided to not warn on reflection access to compiler-generated methods:
            // https://github.com/dotnet/runtime/issues/85042
            static bool ShouldProduceRequiresWarningForReflectionAccess(TypeSystemEntity entity, AccessKind accessKind)
            {
                bool isCompilerGenerated = CompilerGeneratedState.IsNestedFunctionOrStateMachineMember(entity);

                // Compiler generated code accessed via a token is considered a "hard" reference
                // even though we also have to treat it as reflection access.
                // So we need to enforce RUC check/warn in this case.
                bool forceRequiresWarning = accessKind == AccessKind.TokenAccess;

                return !isCompilerGenerated || forceRequiresWarning;
            }
        }

        /// <summary>
        /// Type-hierarchy reporting: warns when a DynamicallyAccessedMembers annotation on
        /// <see cref="_typeHierarchyDataFlowOrigin"/> reaches a member that is in a
        /// RequiresUnreferencedCode scope or covered by DynamicallyAccessedMembers. Compiler-generated
        /// members are skipped. If the member is declared within the annotated type (or a type nested
        /// in it), the warning is reported on the member itself.
        /// </summary>
        private void ReportWarningsForTypeHierarchyReflectionAccess(MessageOrigin origin, TypeSystemEntity entity)
        {
            Debug.Assert(entity is MethodDesc or FieldDesc);

            // Don't check whether the current scope is a RUC type or RUC method because these warnings
            // are not suppressed in RUC scopes. Here the scope represents the DynamicallyAccessedMembers
            // annotation on a type, not a callsite which uses the annotation. We always want to warn about
            // possible reflection access indicated by these annotations.

            Debug.Assert(_typeHierarchyDataFlowOrigin != null);

            // True when `member` is owned by `type` directly or by any type nested (at any depth) in it.
            static bool IsDeclaredWithinType(TypeSystemEntity member, TypeDesc type)
            {
                TypeDesc owningType = member.GetOwningType();
                while (owningType != null)
                {
                    if (owningType == type)
                        return true;

                    owningType = owningType.GetOwningType();
                }
                return false;
            }

            var reportOnMember = IsDeclaredWithinType(entity, _typeHierarchyDataFlowOrigin);
            if (reportOnMember)
                origin = new MessageOrigin(entity);

            // For now we decided to not report single-file or dynamic-code warnings due to type hierarchy marking.
            // It is considered too complex to figure out for the user and the likelihood of this
            // causing problems is pretty low.

            bool isReflectionAccessCoveredByRUC = _logger.ShouldSuppressAnalysisWarningsForRequires(entity, DiagnosticUtilities.RequiresUnreferencedCodeAttribute, out CustomAttributeValue<TypeDesc>? requiresUnreferencedCodeAttribute);
            bool isCompilerGenerated = CompilerGeneratedState.IsNestedFunctionOrStateMachineMember(entity);
            if (isReflectionAccessCoveredByRUC && !isCompilerGenerated)
            {
                var id = reportOnMember ? DiagnosticId.DynamicallyAccessedMembersOnTypeReferencesMemberWithRequiresUnreferencedCode : DiagnosticId.DynamicallyAccessedMembersOnTypeReferencesMemberOnBaseWithRequiresUnreferencedCode;
                _logger.LogWarning(origin, id, _typeHierarchyDataFlowOrigin.GetDisplayName(),
                    entity.GetDisplayName(),
                    MessageFormat.FormatRequiresAttributeMessageArg(DiagnosticUtilities.GetRequiresAttributeMessage(requiresUnreferencedCodeAttribute!.Value)),
                    MessageFormat.FormatRequiresAttributeMessageArg(DiagnosticUtilities.GetRequiresAttributeUrl(requiresUnreferencedCodeAttribute!.Value)));
            }

            bool isReflectionAccessCoveredByDAM = Annotations.ShouldWarnWhenAccessedForReflection(entity);
            if (isReflectionAccessCoveredByDAM && !isCompilerGenerated)
            {
                var id = reportOnMember ? DiagnosticId.DynamicallyAccessedMembersOnTypeReferencesMemberWithDynamicallyAccessedMembers : DiagnosticId.DynamicallyAccessedMembersOnTypeReferencesMemberOnBaseWithDynamicallyAccessedMembers;
                _logger.LogWarning(origin, id, _typeHierarchyDataFlowOrigin.GetDisplayName(), entity.GetDisplayName());
            }
        }

        /// <summary>
        /// Reports a Requires* warning for reflection access to <paramref name="entity"/>. A diagnostic
        /// context is built whose RUC / RDC / RAF suppression flags reflect whether the origin member is
        /// itself in the corresponding Requires* scope.
        /// </summary>
        private void ReportRequires(in MessageOrigin origin, TypeSystemEntity entity, string requiresAttributeName, in CustomAttributeValue<TypeDesc> requiresAttribute)
        {
            var diagnosticContext = new DiagnosticContext(
                origin,
                _logger.ShouldSuppressAnalysisWarningsForRequires(origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute),
                _logger.ShouldSuppressAnalysisWarningsForRequires(origin.MemberDefinition, DiagnosticUtilities.RequiresDynamicCodeAttribute),
                _logger.ShouldSuppressAnalysisWarningsForRequires(origin.MemberDefinition, DiagnosticUtilities.RequiresAssemblyFilesAttribute),
                _logger);

            ReflectionMethodBodyScanner.ReportRequires(diagnosticContext, entity, requiresAttributeName, requiresAttribute);
        }
    }
}