File: Linker.Dataflow\CompilerGeneratedState.cs
Web Access
Project: src\src\tools\illink\src\linker\Mono.Linker.csproj (illink)
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using ILLink.Shared;
using ILLink.Shared.DataFlow;
using Mono.Cecil;
using Mono.Cecil.Cil;
 
namespace Mono.Linker.Dataflow
{
    // Currently this is implemented using heuristics
    public class CompilerGeneratedState
    {
        readonly LinkContext _context;
        readonly Dictionary<TypeDefinition, MethodDefinition> _compilerGeneratedTypeToUserCodeMethod;
        readonly Dictionary<TypeDefinition, TypeArgumentInfo> _generatedTypeToTypeArgumentInfo;
        readonly record struct TypeArgumentInfo(
            /// <summary>The method which calls the ctor for the given type</summary>
            MethodDefinition CreatingMethod,
            /// <summary>Attributes for the type, pulled from the creators type arguments</summary>
            IReadOnlyList<ICustomAttributeProvider>? OriginalAttributes);
 
        readonly Dictionary<MethodDefinition, MethodDefinition> _compilerGeneratedMethodToUserCodeMethod;
 
        // For each type that has had its cache populated, stores a map of methods which have corresponding
        // compiler-generated members (either methods or state machine types) to those compiler-generated members,
        // or null if the type has no methods with compiler-generated members.
        readonly Dictionary<TypeDefinition, Dictionary<MethodDefinition, List<IMemberDefinition>>?> _cachedTypeToCompilerGeneratedMembers;
 
        public CompilerGeneratedState(LinkContext context)
        {
            _context = context;
            _compilerGeneratedTypeToUserCodeMethod = new Dictionary<TypeDefinition, MethodDefinition>();
            _generatedTypeToTypeArgumentInfo = new Dictionary<TypeDefinition, TypeArgumentInfo>();
            _compilerGeneratedMethodToUserCodeMethod = new Dictionary<MethodDefinition, MethodDefinition>();
            _cachedTypeToCompilerGeneratedMembers = new Dictionary<TypeDefinition, Dictionary<MethodDefinition, List<IMemberDefinition>>?>();
        }
 
        static IEnumerable<TypeDefinition> GetCompilerGeneratedNestedTypes(TypeDefinition type)
        {
            foreach (var nestedType in type.NestedTypes)
            {
                if (!CompilerGeneratedNames.IsStateMachineOrDisplayClass(nestedType.Name))
                    continue;
 
                yield return nestedType;
 
                foreach (var recursiveNestedType in GetCompilerGeneratedNestedTypes(nestedType))
                    yield return recursiveNestedType;
            }
        }
 
        public static bool IsHoistedLocal(FieldReference field)
        {
            if (CompilerGeneratedNames.IsLambdaDisplayClass(field.DeclaringType.Name))
                return true;
 
            if (CompilerGeneratedNames.IsStateMachineType(field.DeclaringType.Name))
            {
                // Don't track the "current" field which is used for state machine return values,
                // because this can be expensive to track.
                return !CompilerGeneratedNames.IsStateMachineCurrentField(field.Name);
            }
 
            return false;
        }
 
        // "Nested function" refers to lambdas and local functions.
        public static bool IsNestedFunctionOrStateMachineMember(IMemberDefinition member)
        {
            if (member is MethodDefinition method && CompilerGeneratedNames.IsLambdaOrLocalFunction(method.Name))
                return true;
 
            if (member.DeclaringType is not TypeDefinition declaringType)
                return false;
 
            return CompilerGeneratedNames.IsStateMachineType(declaringType.Name);
        }
 
        public static bool TryGetStateMachineType(MethodDefinition method, [NotNullWhen(true)] out TypeDefinition? stateMachineType)
        {
            stateMachineType = null;
            // Discover state machine methods.
            if (!method.HasCustomAttributes)
                return false;
 
            foreach (var attribute in method.CustomAttributes)
            {
                if (attribute.AttributeType.Namespace != "System.Runtime.CompilerServices")
                    continue;
 
                switch (attribute.AttributeType.Name)
                {
                    case "AsyncIteratorStateMachineAttribute":
                    case "AsyncStateMachineAttribute":
                    case "IteratorStateMachineAttribute":
                        stateMachineType = GetFirstConstructorArgumentAsType(attribute);
                        return stateMachineType != null;
                }
            }
            return false;
        }
 
        /// <summary>
        /// Walks the type and its descendents to find Roslyn-compiler generated
        /// code and gather information to map it back to original user code. If
        /// a compiler-generated type is passed in directly, this method will walk
        /// up and find the nearest containing user type. Returns the nearest user type,
        /// or null if none was found.
        /// </summary>
        TypeDefinition? GetCompilerGeneratedStateForType(TypeDefinition type)
        {
            // Look in the declaring type if this is a compiler-generated type (state machine or display class).
            // State machines can be emitted into display classes, so we may also need to go one more level up.
            // To avoid depending on implementation details, we go up until we see a non-compiler-generated type.
            // This is the counterpart to GetCompilerGeneratedNestedTypes.
            while (type != null && CompilerGeneratedNames.IsStateMachineOrDisplayClass(type.Name))
                type = type.DeclaringType;
 
            if (type is null)
                return null;
 
            // Avoid repeat scans of the same type
            if (_cachedTypeToCompilerGeneratedMembers.ContainsKey(type))
                return type;
 
            var callGraph = new CompilerGeneratedCallGraph();
            var userDefinedMethods = new HashSet<MethodDefinition>();
            var generatedTypeToTypeArgs = new Dictionary<TypeDefinition, TypeArgumentInfo>();
 
            List<(MessageOrigin, DiagnosticId, string[])>? _warnings = null;
 
            // We delay actually logging the warnings until the compiler-generated type info is
            // populated for this type, because the type info is needed to determine whether a warning
            // is suppressed.
            void AddWarning(MessageOrigin origin, DiagnosticId id, params string[] messageArgs)
            {
                _warnings ??= new List<(MessageOrigin, DiagnosticId, string[])>();
                _warnings.Add((origin, id, messageArgs));
            }
 
            void ProcessMethod(MethodDefinition method)
            {
                bool isStateMachineMember = CompilerGeneratedNames.IsStateMachineType(method.DeclaringType.Name);
                if (!CompilerGeneratedNames.IsLambdaOrLocalFunction(method.Name))
                {
                    if (!isStateMachineMember)
                    {
                        // If it's not a nested function, track as an entry point to the call graph.
                        var added = userDefinedMethods.Add(method);
                        Debug.Assert(added);
                    }
                }
                else
                {
                    // We don't expect lambdas or local functions to be emitted directly into
                    // state machine types.
                    Debug.Assert(!isStateMachineMember);
                }
 
                // Discover calls or references to lambdas or local functions. This includes
                // calls to local functions, and lambda assignments (which use ldftn).
                if (method.Body != null)
                {
                    foreach (var instruction in method.Body.Instructions(_context))
                    {
                        switch (instruction.OpCode.OperandType)
                        {
                            case OperandType.InlineMethod:
                            {
                                MethodDefinition? referencedMethod = _context.TryResolve((MethodReference)instruction.Operand);
                                if (referencedMethod == null)
                                    continue;
 
                                // Find calls to state machine constructors that occur outside the type
                                if (referencedMethod.IsConstructor &&
                                    referencedMethod.DeclaringType is var generatedType &&
                                    // Don't consider calls in the same type, like inside a static constructor
                                    method.DeclaringType != generatedType &&
                                    CompilerGeneratedNames.IsLambdaDisplayClass(generatedType.Name))
                                {
                                    // fill in null for now, attribute providers will be filled in later
                                    if (!generatedTypeToTypeArgs.TryAdd(generatedType, new TypeArgumentInfo(method, null)))
                                    {
                                        var alreadyAssociatedMethod = generatedTypeToTypeArgs[generatedType].CreatingMethod;
                                        AddWarning(new MessageOrigin(method), DiagnosticId.MethodsAreAssociatedWithUserMethod, method.GetDisplayName(), alreadyAssociatedMethod.GetDisplayName(), generatedType.GetDisplayName());
                                    }
                                    continue;
                                }
 
                                if (!CompilerGeneratedNames.IsLambdaOrLocalFunction(referencedMethod.Name))
                                    continue;
 
                                if (isStateMachineMember)
                                {
                                    callGraph.TrackCall(method.DeclaringType, referencedMethod);
                                }
                                else
                                {
                                    callGraph.TrackCall(method, referencedMethod);
                                }
                            }
                            break;
 
                            case OperandType.InlineField:
                            {
                                // Same as above, but stsfld instead of a call to the constructor
                                // Ldsfld may also trigger a cctor that creates a closure environment
                                if (instruction.OpCode.Code is not (Code.Stsfld or Code.Ldsfld))
                                    continue;
 
                                FieldDefinition? field = _context.TryResolve((FieldReference)instruction.Operand);
                                if (field == null)
                                    continue;
 
                                if (field.DeclaringType is var generatedType &&
                                    // Don't consider field accesses in the same type, like inside a static constructor
                                    method.DeclaringType != generatedType &&
                                    CompilerGeneratedNames.IsLambdaDisplayClass(generatedType.Name))
                                {
                                    if (!generatedTypeToTypeArgs.TryAdd(generatedType, new TypeArgumentInfo(method, null)))
                                    {
                                        // It's expected that there may be multiple methods associated with the same static closure environment.
                                        // All of these methods will substitute the same type arguments into the closure environment
                                        // (if it is generic). Don't warn.
                                    }
                                    continue;
                                }
                            }
                            break;
                        }
                    }
                }
 
                if (TryGetStateMachineType(method, out TypeDefinition? stateMachineType))
                {
                    Debug.Assert(stateMachineType.DeclaringType == type ||
                        CompilerGeneratedNames.IsStateMachineOrDisplayClass(stateMachineType.DeclaringType.Name) &&
                         stateMachineType.DeclaringType.DeclaringType == type);
                    callGraph.TrackCall(method, stateMachineType);
 
                    if (!_compilerGeneratedTypeToUserCodeMethod.TryAdd(stateMachineType, method))
                    {
                        var alreadyAssociatedMethod = _compilerGeneratedTypeToUserCodeMethod[stateMachineType];
                        AddWarning(new MessageOrigin(method), DiagnosticId.MethodsAreAssociatedWithStateMachine, method.GetDisplayName(), alreadyAssociatedMethod.GetDisplayName(), stateMachineType.GetDisplayName());
                    }
                    // Already warned above if multiple methods map to the same type
                    // Fill in null for argument providers now, the real providers will be filled in later
                    generatedTypeToTypeArgs[stateMachineType] = new TypeArgumentInfo(method, null);
                }
            }
 
            // Look for state machine methods, and methods which call local functions.
            foreach (MethodDefinition method in type.Methods)
                ProcessMethod(method);
 
            // Also scan compiler-generated state machine methods (in case they have calls to nested functions),
            // and nested functions inside compiler-generated closures (in case they call other nested functions).
 
            // State machines can be emitted into lambda display classes, so we need to go down at least two
            // levels to find calls from iterator nested functions to other nested functions. We just recurse into
            // all compiler-generated nested types to avoid depending on implementation details.
 
            foreach (var nestedType in GetCompilerGeneratedNestedTypes(type))
            {
                foreach (var method in nestedType.Methods)
                    ProcessMethod(method);
            }
 
            // Now we've discovered the call graphs for calls to nested functions.
            // Use this to map back from nested functions to the declaring user methods.
 
            // Note: This maps all nested functions back to the user code, not to the immediately
            // declaring local function. The IL doesn't contain enough information in general for
            // us to determine the nesting of local functions and lambdas.
 
            // Note: this only discovers nested functions which are referenced from the user
            // code or its referenced nested functions. There is no reliable way to determine from
            // IL which user code an unused nested function belongs to.
 
            Dictionary<MethodDefinition, List<IMemberDefinition>>? compilerGeneratedCallees = null;
            foreach (var userDefinedMethod in userDefinedMethods)
            {
                var callees = callGraph.GetReachableMembers(userDefinedMethod);
                if (!callees.Any())
                    continue;
 
                compilerGeneratedCallees ??= new Dictionary<MethodDefinition, List<IMemberDefinition>>();
                compilerGeneratedCallees.Add(userDefinedMethod, new List<IMemberDefinition>(callees));
 
                foreach (var compilerGeneratedMember in callees)
                {
                    switch (compilerGeneratedMember)
                    {
                        case MethodDefinition nestedFunction:
                            Debug.Assert(CompilerGeneratedNames.IsLambdaOrLocalFunction(nestedFunction.Name));
                            // Nested functions get suppressions from the user method only.
                            if (!_compilerGeneratedMethodToUserCodeMethod.TryAdd(nestedFunction, userDefinedMethod))
                            {
                                var alreadyAssociatedMethod = _compilerGeneratedMethodToUserCodeMethod[nestedFunction];
                                AddWarning(new MessageOrigin(userDefinedMethod), DiagnosticId.MethodsAreAssociatedWithUserMethod, userDefinedMethod.GetDisplayName(), alreadyAssociatedMethod.GetDisplayName(), nestedFunction.GetDisplayName());
                            }
                            break;
                        case TypeDefinition stateMachineType:
                            // Types in the call graph are always state machine types
                            // For those all their methods are not tracked explicitly in the call graph; instead, they
                            // are represented by the state machine type itself.
                            // We are already tracking the association of the state machine type to the user code method
                            // above, so no need to track it here.
                            Debug.Assert(CompilerGeneratedNames.IsStateMachineType(stateMachineType.Name));
                            break;
                        default:
                            throw new InvalidOperationException();
                    }
                }
            }
 
            // Now that we have instantiating methods fully filled out, walk the generated types and fill in the attribute
            // providers
            foreach (var generatedType in generatedTypeToTypeArgs.Keys)
            {
                if (generatedType.HasGenericParameters)
                {
                    MapGeneratedTypeTypeParameters(generatedType, generatedTypeToTypeArgs, _context);
                    // Finally, add resolved type arguments to the cache
                    var info = generatedTypeToTypeArgs[generatedType];
                    if (!_generatedTypeToTypeArgumentInfo.TryAdd(generatedType, info))
                    {
                        var method = info.CreatingMethod;
                        var alreadyAssociatedMethod = _generatedTypeToTypeArgumentInfo[generatedType].CreatingMethod;
                        AddWarning(new MessageOrigin(method), DiagnosticId.MethodsAreAssociatedWithUserMethod, method.GetDisplayName(), alreadyAssociatedMethod.GetDisplayName(), generatedType.GetDisplayName());
                    }
                }
            }
 
            _cachedTypeToCompilerGeneratedMembers.Add(type, compilerGeneratedCallees);
 
            if (_warnings != null)
            {
                foreach (var (origin, id, messageArgs) in _warnings)
                {
                    _context.LogWarning(origin, id, messageArgs);
                }
            }
 
            return type;
 
            /// <summary>
            /// Attempts to reverse the process of the compiler's alpha renaming. So if the original code was
            /// something like this:
            /// <code>
            /// void M&lt;T&gt; () {
            ///     Action a = () => { Console.WriteLine (typeof (T)); };
            /// }
            /// </code>
            /// The compiler will generate a nested class like this:
            /// <code>
            /// class &lt;&gt;c__DisplayClass0&lt;T&gt; {
            ///     public void &lt;M&gt;b__0 () {
            ///         Console.WriteLine (typeof (T));
            ///     }
            /// }
            /// </code>
            /// The task of this method is to figure out that the type parameter T in the nested class is the same
            /// as the type parameter T in the parent method M.
            /// <paramref name="generatedTypeToTypeArgs"/> acts as a memoization table to avoid recalculating the
            /// mapping multiple times.
            /// </summary>
            static void MapGeneratedTypeTypeParameters(
                TypeDefinition generatedType,
                Dictionary<TypeDefinition, TypeArgumentInfo> generatedTypeToTypeArgs,
                LinkContext context)
            {
                Debug.Assert(CompilerGeneratedNames.IsStateMachineOrDisplayClass(generatedType.Name));
 
                var typeInfo = generatedTypeToTypeArgs[generatedType];
                if (typeInfo.OriginalAttributes is not null)
                {
                    return;
                }
                var method = typeInfo.CreatingMethod;
                if (method.Body is { } body)
                {
                    var typeArgs = new ICustomAttributeProvider[generatedType.GenericParameters.Count];
                    var typeRef = ScanForInit(generatedType, body, context);
                    if (typeRef is null)
                    {
                        return;
                    }
 
                    for (int i = 0; i < typeRef.GenericArguments.Count; i++)
                    {
                        var typeArg = typeRef.GenericArguments[i];
                        // Start with the existing parameters, in case we can't find the mapped one
                        ICustomAttributeProvider userAttrs = generatedType.GenericParameters[i];
                        // The type parameters of the state machine types are alpha renames of the
                        // the method parameters, so the type ref should always be a GenericParameter. However,
                        // in the case of nesting, there may be multiple renames, so if the parameter is a method
                        // we know we're done, but if it's another state machine, we have to keep looking to find
                        // the original owner of that state machine.
                        if (typeArg is GenericParameter { Owner: { } owner } param)
                        {
                            if (owner is MethodReference)
                            {
                                userAttrs = param;
                            }
                            else
                            {
                                // Must be a type ref
                                var owningRef = (TypeReference)owner;
                                if (!CompilerGeneratedNames.IsStateMachineOrDisplayClass(owningRef.Name))
                                {
                                    userAttrs = param;
                                }
                                else if (context.TryResolve((TypeReference)param.Owner) is { } owningType)
                                {
                                    MapGeneratedTypeTypeParameters(owningType, generatedTypeToTypeArgs, context);
                                    if (generatedTypeToTypeArgs[owningType].OriginalAttributes is { } owningAttrs)
                                    {
                                        userAttrs = owningAttrs[param.Position];
                                    }
                                    else
                                    {
                                        Debug.Assert(false, "This should be impossible in valid code");
                                    }
                                }
                            }
                        }
 
                        typeArgs[i] = userAttrs;
                    }
 
                    generatedTypeToTypeArgs[generatedType] = typeInfo with { OriginalAttributes = typeArgs };
                }
            }
 
            static GenericInstanceType? ScanForInit(
                TypeDefinition compilerGeneratedType,
                MethodBody body,
                LinkContext context)
            {
                foreach (var instr in context.GetMethodIL(body).Instructions)
                {
                    bool handled = false;
                    switch (instr.OpCode.Code)
                    {
                        case Code.Initobj:
                        case Code.Newobj:
                        {
                            if (instr.Operand is MethodReference { DeclaringType: GenericInstanceType typeRef }
                                && compilerGeneratedType == context.TryResolve(typeRef))
                            {
                                return typeRef;
                            }
                            handled = true;
                        }
                        break;
                        case Code.Stsfld:
                        case Code.Ldsfld:
                        {
                            if (instr.Operand is FieldReference { DeclaringType: GenericInstanceType typeRef }
                                && compilerGeneratedType == context.TryResolve(typeRef))
                            {
                                return typeRef;
                            }
                            handled = true;
                        }
                        break;
                    }
 
                    // Also look for type substitutions into generic methods
                    // (such as AsyncTaskMethodBuilder::Start<TStateMachine>).
                    if (!handled && instr.OpCode.OperandType is OperandType.InlineMethod)
                    {
                        if (instr.Operand is GenericInstanceMethod gim)
                        {
                            foreach (var tr in gim.GenericArguments)
                            {
                                if (tr is GenericInstanceType git && compilerGeneratedType == context.TryResolve(git))
                                {
                                    return git;
                                }
                            }
                        }
                    }
                }
                return null;
            }
        }
 
        static TypeDefinition? GetFirstConstructorArgumentAsType(CustomAttribute attribute)
        {
            if (!attribute.HasConstructorArguments)
                return null;
 
            return attribute.ConstructorArguments[0].Value as TypeDefinition;
        }
 
        public bool TryGetCompilerGeneratedCalleesForUserMethod(MethodDefinition method, [NotNullWhen(true)] out List<IMemberDefinition>? callees)
        {
            callees = null;
            if (IsNestedFunctionOrStateMachineMember(method))
                return false;
 
            var typeToCache = GetCompilerGeneratedStateForType(method.DeclaringType);
            if (typeToCache is null)
                return false;
 
            return _cachedTypeToCompilerGeneratedMembers[typeToCache]?.TryGetValue(method, out callees) == true;
        }
 
        /// <summary>
        /// Gets the attributes on the "original" method of a generated type, i.e. the
        /// attributes on the corresponding type parameters from the owning method.
        /// </summary>
        public IReadOnlyList<ICustomAttributeProvider>? GetGeneratedTypeAttributes(TypeDefinition generatedType)
        {
            Debug.Assert(CompilerGeneratedNames.IsStateMachineOrDisplayClass(generatedType.Name));
 
            var typeToCache = GetCompilerGeneratedStateForType(generatedType);
            if (typeToCache is null)
                return null;
 
            if (_generatedTypeToTypeArgumentInfo.TryGetValue(generatedType, out var typeInfo))
            {
                return typeInfo.OriginalAttributes;
            }
            return null;
        }
 
        // For state machine types/members, maps back to the state machine method.
        // For local functions and lambdas, maps back to the owning method in user code (not the declaring
        // lambda or local function, because the IL doesn't contain enough information to figure this out).
        public bool TryGetOwningMethodForCompilerGeneratedMember(IMemberDefinition sourceMember, [NotNullWhen(true)] out MethodDefinition? owningMethod)
        {
            owningMethod = null;
            if (sourceMember == null)
                return false;
 
            MethodDefinition? compilerGeneratedMethod = sourceMember as MethodDefinition;
            if (compilerGeneratedMethod != null)
            {
                if (_compilerGeneratedMethodToUserCodeMethod.TryGetValue(compilerGeneratedMethod, out owningMethod))
                    return true;
            }
 
            TypeDefinition sourceType = sourceMember as TypeDefinition ?? sourceMember.DeclaringType;
 
            if (_compilerGeneratedTypeToUserCodeMethod.TryGetValue(sourceType, out owningMethod))
                return true;
 
            if (!IsNestedFunctionOrStateMachineMember(sourceMember))
                return false;
 
            // sourceType is a state machine type, or the type containing a lambda or local function.
            // Search all methods to find the one which points to the type as its
            // state machine implementation.
            var typeToCache = GetCompilerGeneratedStateForType(sourceType);
            if (typeToCache is null)
                return false;
 
            if (compilerGeneratedMethod != null)
            {
                if (_compilerGeneratedMethodToUserCodeMethod.TryGetValue(compilerGeneratedMethod, out owningMethod))
                    return true;
            }
 
            if (_compilerGeneratedTypeToUserCodeMethod.TryGetValue(sourceType, out owningMethod))
                return true;
 
            return false;
        }
    }
}